oc-fewshot-public

Code for our ICLR 2021 paper Wandering Within a World: Online Contextualized Few-Shot Learning [arxiv]

RoamingRooms Dataset

Although our code base is MIT licensed, the RoamingRooms dataset is not, since it is derived from the Matterport3D dataset.

To download the RoamingRooms dataset, you first need to sign the agreement for a non-commercial license here.

Then, you need to submit a request here. We will manually approve your request afterwards.

For inquiries, please email: [email protected]

The whole dataset is around 60 GB. It has 1.2M video frames with 7k unique object instance classes. Please refer to our paper for more statistics of the dataset.

System Requirements

Our code is tested on Ubuntu 18.04 with GPU support. We provide Docker files for reproducible environments. We recommend at least 20 GB of CPU memory and 11 GB of GPU memory; 2-4 GPUs are required for the multi-GPU experiments. Our code is based on TensorFlow 2.

Installation Using Docker (Recommended)

  1. Install protoc from here.

  2. Run make to build proto buffer configuration files.

  3. Install docker and nvidia-docker.

  4. Build the docker container using ./build_docker.sh.

  5. Modify the environment paths. You need to change DATA_DIR and OUTPUT_DIR in setup_environ.sh. DATA_DIR is the main folder where datasets are placed and OUTPUT_DIR is the main folder where trained models are saved (see the sketch below).
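
The exact contents of setup_environ.sh depend on your machine. As a minimal sketch, assuming the file simply exports shell variables (the paths below are placeholders, not from the original repo):

# Sketch of setup_environ.sh -- placeholder paths, adjust to your system.
export DATA_DIR=/path/to/datasets        # where Omniglot, Uppsala, RoamingRooms, etc. are stored
export OUTPUT_DIR=/path/to/checkpoints   # where trained models are saved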

Installation Using Conda

  1. Install protoc from here.

  2. Run make to build proto buffer configuration files.

  3. Modify the environment paths. You need to change DATA_DIR and OUTPUT_DIR in setup_environ.sh, as in the Docker setup above. DATA_DIR is the main folder where datasets are placed and OUTPUT_DIR is the main folder where trained models are saved.

  4. Create a conda environment:

conda create -n oc-fewshot python=3.6
conda activate oc-fewshot
conda install pip
  5. Install CUDA 10.1

  6. Install OpenMPI 4.0.0

  7. Install NCCL 2.6.4 for CUDA 10.1

  8. Modify installation paths in install.sh

  9. Run install.sh
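
After install.sh finishes, a quick sanity check (not part of the original instructions; assumes TensorFlow 2.1 or later) is to confirm that TensorFlow can see the GPU:

python -c "import tensorflow as tf; print(tf.config.list_physical_devices('GPU'))"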

Setup Datasets

  1. To set up the Omniglot dataset, run script/download_omniglot.sh. This script will download the Omniglot dataset to DATA_DIR.

  2. To set up the Uppsala texture dataset (for spatiotemporal cue experiments), run script/download_uppsala.sh. This script will download the Uppsala texture dataset to DATA_DIR.
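
For example, assuming DATA_DIR has already been set in setup_environ.sh and that both downloaders are plain bash scripts, they can be run back to back:

bash script/download_omniglot.sh
bash script/download_uppsala.sh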

RoamingOmniglot Experiments

To run training on your own, use the following command.

./run_docker.sh {GPU_ID} python -m fewshot.experiments.oc_fewshot \
  --config {MODEL_CONFIG_PROTOTXT} \
  --data {EPISODE_CONFIG_PROTOTXT} \
  --env configs/environ/roaming-omniglot-docker.prototxt \
  --tag {TAG} \
  [--eval]
  • MODEL_CONFIG_PROTOTXT can be found in configs/models.
  • EPISODE_CONFIG_PROTOTXT can be found in configs/episodes.
  • TAG is the name of the saved checkpoint folder.
  • When the model finishes training, add the --eval flag to evaluate.

For example, to train CPM on the semisupervised benchmark:

./run_docker.sh 0 python -m fewshot.experiments.oc_fewshot \
  --config configs/models/roaming-omniglot/cpm.prototxt \
  --data configs/episodes/roaming-omniglot/roaming-omniglot-150-ssl.prototxt \
  --env configs/environ/roaming-omniglot-docker.prototxt \
  --tag roaming-omniglot-ssl-cpm
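
Once training completes, appending --eval to the same command (with the same tag) evaluates the saved checkpoint, for example:

./run_docker.sh 0 python -m fewshot.experiments.oc_fewshot \
  --config configs/models/roaming-omniglot/cpm.prototxt \
  --data configs/episodes/roaming-omniglot/roaming-omniglot-150-ssl.prototxt \
  --env configs/environ/roaming-omniglot-docker.prototxt \
  --tag roaming-omniglot-ssl-cpm \
  --eval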

All of our code is tested using a GTX 1080 Ti with 11 GB of GPU memory. Note that the above command uses a single GPU. The original experiments in the paper were performed using two GPUs, with twice the batch size and double the learning rate. To run that setting, use the following command:

./run_docker_hvd_01.sh python -m fewshot.experiments.oc_fewshot_hvd \
  --config {MODEL_CONFIG_PROTOTXT} \
  --data {EPISODE_CONFIG_PROTOTXT} \
  --env configs/environ/roaming-omniglot-docker.prototxt \
  --tag {TAG}
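
For example, the two-GPU counterpart of the CPM run above would look like the following (the tag below is just a placeholder):

./run_docker_hvd_01.sh python -m fewshot.experiments.oc_fewshot_hvd \
  --config configs/models/roaming-omniglot/cpm.prototxt \
  --data configs/episodes/roaming-omniglot/roaming-omniglot-150-ssl.prototxt \
  --env configs/environ/roaming-omniglot-docker.prototxt \
  --tag roaming-omniglot-ssl-cpm-2gpu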

RoamingRooms Experiments

Below we include the commands to run experiments on RoamingRooms. The original experiments in the paper were performed using four GPUs with a batch size of 8. To run that setting, use the following command:

./run_docker_hvd_0123.sh python -m fewshot.experiments.oc_fewshot_hvd \
  --config {MODEL_CONFIG_PROTOTXT} \
  --data {EPISODE_CONFIG_PROTOTXT} \
  --env configs/environ/roaming-rooms-docker.prototxt \
  --tag {TAG}

When evaluating, use --eval --usebest to pick the checkpoint with the highest validation performance.
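
As a sketch, evaluation reuses the training command with the two flags appended (placeholders as above; depending on your setup, evaluation may also be run through the single-GPU run_docker.sh entry point):

./run_docker_hvd_0123.sh python -m fewshot.experiments.oc_fewshot_hvd \
  --config {MODEL_CONFIG_PROTOTXT} \
  --data {EPISODE_CONFIG_PROTOTXT} \
  --env configs/environ/roaming-rooms-docker.prototxt \
  --tag {TAG} \
  --eval --usebest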

Results

Table 1: RoamingOmniglot Results (Supervised)

Method AP 1-shot Acc. 3-shot Acc. Checkpoint
LSTM 64.34 61.00 ± 0.22 81.85 ± 0.21 link
DNC 81.30 78.87 ± 0.19 91.01 ± 0.15 link
OML-U 77.38 70.98 ± 0.21 89.13 ± 0.16 link
OML-U++ 86.85 88.43 ± 0.14 92.07 ± 0.14 link
Online MatchingNet 88.69 84.82 ± 0.15 95.55 ± 0.11 link
Online IMP 90.15 85.74 ± 0.15 96.66 ± 0.09 link
Online ProtoNet 90.49 85.68 ± 0.15 96.95 ± 0.09 link
CPM (Ours) 94.17 91.99 ± 0.11 97.74 ± 0.08 link

Table 2: RoamingOmniglot Results (Semi-supervised)

Method AP 1-shot Acc. 3-shot Acc. Checkpoint
LSTM 54.34 68.30 ± 0.20 76.38 ± 0.49 link
DNC 81.37 88.56 ± 0.12 93.81 ± 0.26 link
OML-U 66.70 74.65 ± 0.19 90.81 ± 0.34 link
OML-U++ 81.39 89.07 ± 0.19 89.40 ± 0.18 link
Online MatchingNet 84.39 88.77 ± 0.13 97.28 ± 0.17 link
Online IMP 81.62 88.68 ± 0.13 97.09 ± 0.19 link
Online ProtoNet 84.61 88.71 ± 0.13 97.61 ± 0.17 link
CPM (Ours) 90.42 93.18 ± 0.16 97.89 ± 0.15 link

Table 3: RoamingRooms Results (Supervised)

Method AP 1-shot Acc. 3-shot Acc. Checkpoint
LSTM 45.67 59.90 ± 0.40 61.85 ± 0.45 link
DNC 80.86 82.15 ± 0.32 87.30 ± 0.30 link
OML-U 76.27 73.91 ± 0.37 83.99 ± 0.33 link
OML-U++ 88.03 88.32 ± 0.27 89.61 ± 0.29 link
Online MatchingNet 85.91 82.82 ± 0.32 89.99 ± 0.26 link
Online IMP 87.33 85.28 ± 0.31 90.83 ± 0.25 link
Online ProtoNet 86.01 84.89 ± 0.31 89.58 ± 0.28 link
CPM (Ours) 89.14 88.39 ± 0.27 91.31 ± 0.26 link

Table 4: RoamingRooms Results (Semi-supervised)

Method AP 1-shot Acc. 3-shot Acc. Checkpoint
LSTM 33.32 52.71 ± 0.38 55.83 ± 0.76 link
DNC 73.49 80.27 ± 0.33 87.87 ± 0.49 link
OML-U 63.40 70.67 ± 0.38 85.25 ± 0.56 link
OML-U++ 81.90 84.79 ± 0.31 89.80 ± 0.47 link
Online MatchingNet 78.99 80.08 ± 0.34 92.43 ± 0.41 link
Online IMP 75.36 84.57 ± 0.31 91.17 ± 0.43 link
Online ProtoNet 76.36 80.67 ± 0.34 88.83 ± 0.49 link
CPM (Ours) 84.12 86.17 ± 0.30 91.16 ± 0.44 link

Table 5: RoamingImageNet Results (Supervised)

Method AP 1-shot Acc. 3-shot Acc. Checkpoint
LSTM 7.73 11.60 ± 0.12 43.93 ± 0.27 link
LSTM* 22.54 28.14 ± 0.20 52.07 ± 0.27 link
DNC 7.20 10.55 ± 0.11 42.22 ± 0.27 link
DNC* 26.80 33.45 ± 0.19 55.78 ± 0.27 link
OML-U 21.89 15.06 ± 0.14 52.52 ± 0.27 link
OML-U Cos 10.87 24.45 ± 0.18 30.89 ± 0.24 link
Online MatchingNet 13.05 20.61 ± 0.15 38.73 ± 0.24 link
Online IMP 14.25 22.92 ± 0.16 41.01 ± 0.25 link
Online ProtoNet 15.51 22.95 ± 0.17 44.98 ± 0.25 link
Online ProtoNet* 23.10 32.82 ± 0.19 49.98 ± 0.25 link
CPM (Ours) 34.43 40.40 ± 0.21 60.29 ± 0.26 link

* denotes using a pretrained CNN.

Table 6: RoamingImageNet Results (Semi-supervised)

Method AP 1-shot Acc. 3-shot Acc. Checkpoint
LSTM 4.03 22.53 ± 0.18 41.34 ± 0.55 link
LSTM* 13.50 30.02 ± 0.20 46.95 ± 0.56 link
DNC 3.66 22.37 ± 0.18 37.83 ± 0.54 link
DNC* 16.50 39.53 ± 0.19 54.10 ± 0.54 link
OML-U 10.16 22.74 ± 0.17 55.81 ± 0.55 link
OML-U Cos 5.65 23.37 ± 0.16 32.79 ± 0.50 link
Online MatchingNet 9.32 25.96 ± 0.16 55.32 ± 0.51 link
Online IMP 4.55 20.70 ± 0.15 51.23 ± 0.53 link
Online ProtoNet 7.10 26.87 ± 0.16 42.40 ± 0.52 link
Online ProtoNet* 15.76 36.69 ± 0.18 55.47 ± 0.53 link
CPM (Ours) 24.75 44.58 ± 0.21 58.72 ± 0.53 link

* denotes using a pretrained CNN.

To-Do

  • Add a data iterator based on PyTorch (contribution welcome).

Citation

If you use our code, please consider citing the following:

  • Mengye Ren, Michael L. Iuzzolino, Michael C. Mozer and Richard S. Zemel. Wandering Within a World: Online Contextualized Few-Shot Learning. In ICLR, 2021.
@inproceedings{ren21ocfewshot,
  author    = {Mengye Ren and
               Michael L. Iuzzolino and
               Michael C. Mozer and
               Richard S. Zemel},
  title     = {Wandering Within a World: Online Contextualized Few-Shot Learning},
  booktitle = {9th International Conference on Learning Representations, {ICLR}},
  year      = {2021}
}

oc-fewshot-public's Issues

Trouble replicating results on Roaming-Imagenet for pretrained models

Hi!

I'm not able to replicate the performance of the pretrained Online ProtoNet on the RoamingImageNet dataset. With the provided checkpoint I get the reported metrics, but if I pretrain the model myself I get ~21 AP on testing and ~22 AP on validation.

To pretrain I use the following command:

python -m fewshot.experiments.pretrain --config configs/models/roaming-imagenet/pretrain.prototxt --env configs/environ/roaming-imagenet-docker.prototxt --seed 0 --tag pretrained

In the pretrain.py file I replaced the line dataset = get_data(env_config) with dataset = get_data_fs(env_config, load_train=True), because the function get_data does not exist in fewshot.experiments.utils.

Finally, for fine-tuning/evaluating the pretrained model I run:
python -m fewshot.experiments.oc_fewshot --config configs/models/roaming-imagenet/online-protonet.prototxt --data configs/episodes/roaming-imagenet/roaming-imagenet-150.prototxt --env configs/environ/roaming-imagenet-docker.prototxt --tag the_tag --pretrain results/oc-fewshot/tiered-imagenet/pretrained_model/weights-40000

I hope you can help me with this, and thanks for the code!

Bug on hierarchical episode sampler

Hi, I've been looking at the code for some time and I think I found a bug in the hierarchical episode sampler, specifically on

# Line 169 currently is 
for c in range(min(episode_classes.max(), len(hmap))):
# But should be
for c in range(min(stage.max() + 1, len(hmap))):
# Because each stage is mapped to one of the hierarchy classes

And

# On line 171 it currently is
for c, s in zip(episode_classes, stage):
    # Magic number is 
    results.append(self.hierarchy_dict[hmap[c % len(hmap)]][s])
# The code above chooses the environment based on the stage-relative class number and the
# class mapped to based on the stage

# But should be reversed to
for c, s in zip(episode_classes, stage):
    # Number of classes of previous stages that belong to the 
    # same class hierarchy as current stage
    mask = ((stage < s) & ((stage % len(hmap)) == (s % len(hmap))))
    prev_samples = np.stack((episode_classes, stage), axis=1)[mask]
    offset = np.unique(prev_samples, axis=0).shape[0]

    # Choose the hierarchy based on the stage and the class on the stage relative class number
    results.append(self.hierarchy_dict[hmap[s % len(hmap)]][c + offset])

Hope I'm not missing something. I'm currently training the models with these modifications; I'll post the results when I get them!
