dragen1860 / maml-tensorflow

A faster, elegant TensorFlow implementation of the paper: Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks

Python 100.00%
machine-learning metalearning tensorflow

maml-tensorflow's Introduction

MAML-TensorFlow

An elegant and efficient implementation of the ICML 2017 paper: Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks

Highlights

  • adapted from cbfinn's official implementation, with equivalent performance on mini-Imagenet
  • clean, tiny code style that is very easy to follow, with comments on almost every line
  • faster, thanks to trivial improvements: e.g. 0.335s per epoch compared with 0.563s per epoch, saving up to 3.8 hours over the full 60,000-iteration training process
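
The 3.8-hour saving quoted in the last highlight can be checked with a little arithmetic. The snippet below is only a sanity check: it assumes the two per-epoch timings stay constant across all 60,000 iterations, and the variable names are illustrative.

```python
# Timings quoted in the highlights (seconds per epoch, i.e. per training iteration).
baseline_s = 0.563   # reference implementation
this_repo_s = 0.335  # this implementation
iterations = 60_000  # total training iterations quoted above

# Time saved over the whole run, converted from seconds to hours.
saved_hours = (baseline_s - this_repo_s) * iterations / 3600
print(f"saved: {saved_hours:.1f} hours")
```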

How-To

  1. Download mini-Imagenet from here and extract it as:
	miniimagenet/
	├── images
	│   ├── n0210891500001298.jpg
	│   ├── n0287152500001298.jpg
	│   ...
	├── test.csv
	├── val.csv
	├── train.csv
	└── proc_images.py
	

Then replace the default paths with your actual paths in data_generator.py:

		metatrain_folder = config.get('metatrain_folder', '/hdd1/liangqu/datasets/miniimagenet/train')
		if True:
			metaval_folder = config.get('metaval_folder', '/hdd1/liangqu/datasets/miniimagenet/test')
		else:
			metaval_folder = config.get('metaval_folder', '/hdd1/liangqu/datasets/miniimagenet/val')
  2. Resize all raw images to 84x84 by running:
python proc_images.py
  3. Train:
python main.py

Since tf.summary is time-consuming, I turn it off by default. Uncomment the following 2 lines to turn it on:

	# write graph to tensorboard
	# tb = tf.summary.FileWriter(os.path.join('logs', 'mini'), sess.graph)

	...
	# summ_op
	# tb.add_summary(result[1], iteration)

and then monitor the training process with:

tensorboard --logdir logs
  4. Test:
python main.py --test

MAML needs to generate 200,000 train/eval episodes before training, which usually takes 6~8 minutes. I therefore use a cache file, filelist.pkl, to dump all these episodes the first time; on later runs the program loads them from the cached file, which only takes several seconds.

generating episodes: 100%|█████████████████████████████████████████| 200000/200000 [04:38<00:00, 717.27it/s]
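
The caching pattern described above can be sketched as follows. This is a minimal illustration, not the repo's actual code: `make_episodes` and `load_or_generate` are hypothetical names, and the stand-in generator simply picks `nway` classes per episode.

```python
import os
import pickle
import random


def make_episodes(num_episodes, classes, nway, seed=0):
    """Hypothetical stand-in for the slow episode generator."""
    rng = random.Random(seed)
    # Each episode is just a list of nway distinct classes in this sketch.
    return [rng.sample(classes, nway) for _ in range(num_episodes)]


def load_or_generate(cache_path, num_episodes, classes, nway):
    """Load episodes from the cache file if present, else generate and dump them."""
    if os.path.exists(cache_path):
        with open(cache_path, "rb") as f:
            return pickle.load(f)
    episodes = make_episodes(num_episodes, classes, nway)
    with open(cache_path, "wb") as f:
        pickle.dump(episodes, f)
    return episodes
```

One thing to keep in mind with this pattern: the cache is keyed only by file name, so if the dataset path or episode settings change, the stale cache file has to be deleted by hand before the next run.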


maml-tensorflow's Issues

Examples from different classes share same label.

Hello, I find that you give a label in the range (0, 80) to each example in a meta_batch. But it makes me wonder whether this will confuse the CNN, because you give some examples of the same class different labels, and you also give some examples of different classes the same label.

train iterations problem

The number of train iterations in main.py is 600000, but the guide you present says 60000? I cannot understand.

Accuracy not matching original MAML paper/repo

It looks like the repo is set up to reproduce the bottom-left-most cell in Table 1 of the MAML paper, which has 1-shot 5-way miniimagenet test accuracy of 48.70 ± 1.84%. Is that correct? When I run the code from the official MAML repo I get a similar number, but when I run this code I'm getting around 40%. Are there some settings I need to change? I like the simplicity of this code and would like to use it as long as I can reproduce the original results.

What's the meaning of kshot and kquery?

Hello, thank you for your contribution. I have some confusion about the definition of some parameters. n-way means we sample from n classes, k-shot means we sample k examples from each class (am I right?), but I don't understand what "k-query" is.
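
For readers with the same question: in the usual few-shot convention, each episode has a support set of k-shot examples per class, used for adaptation, and a separate query set, used for evaluation. The sketch below assumes that kquery is the number of query examples per class; check data_generator.py for how this repo actually uses it. `sample_episode` and `images_by_class` are hypothetical names, not part of the repo.

```python
import random


def sample_episode(images_by_class, nway, kshot, kquery, seed=None):
    """Sample one n-way episode: kshot support and kquery query examples per class."""
    rng = random.Random(seed)
    # Pick nway distinct classes for this episode.
    classes = rng.sample(sorted(images_by_class), nway)
    support, query = [], []
    # Labels are episode-relative: 0 .. nway-1, in the order the classes were drawn.
    for label, cls in enumerate(classes):
        # Draw kshot + kquery distinct images so support and query never overlap.
        picked = rng.sample(images_by_class[cls], kshot + kquery)
        support += [(img, label) for img in picked[:kshot]]
        query += [(img, label) for img in picked[kshot:]]
    return support, query
```

With nway=5, kshot=1 and kquery=15, this yields 5 support examples and 75 query examples per episode.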

TF version

What was the correct version of TensorFlow?

Testing-Strategy

Hello, thanks for the code.
I have a question regarding the testing strategy. I trained the model and ran it with the --test parameter.

At test time I see output like this:
[support_t0, query_t0 - K]
mean: [ 0.20175 0.37785035 0.41440004 0.43647221 0.44063893 0.44136062
0.44179988 0.44246089 0.44228852 0.44238317 0.44273853] .....

Here the query results are the trained model's accuracies. We are using query_y to refine our predictions and computing the accuracy for 10 iterations. This approach doesn't make sense to me in a few-shot learning setting. At test time, I expected that we would train the model on K samples from classes [C1, C2, ..., Cn] and test on the other samples from these same classes. Here we are still training the model on all test samples.
