
mae_st's Introduction

Masked Autoencoders As Spatiotemporal Learners: A PyTorch Implementation

This is a PyTorch/GPU re-implementation of the paper Masked Autoencoders As Spatiotemporal Learners:

@Article{MaskedAutoencodersSpatiotemporal2022,
  author  = {Christoph Feichtenhofer and Haoqi Fan and Yanghao Li and Kaiming He},
  journal = {arXiv:2205.09113},
  title   = {Masked Autoencoders As Spatiotemporal Learners},
  year    = {2022},
}

Another implementation that supports AVA and SSv2 downstream evaluation is available in PySlowFast.

  • This repo is a modification of the MAE repo. Installation and preparation follow INSTALL.md.

  • This repo is based on timm==0.3.2, for which a fix is needed to work with PyTorch 1.8.1+.
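
The fix for timm==0.3.2 replaces its import of `container_abcs` from `torch._six`, which newer PyTorch versions removed. A minimal sketch of the commonly applied patch to `timm/models/layers/helpers.py` (verify against your installed versions):

```python
# Patch for timm==0.3.2 on PyTorch 1.8.1+: torch._six no longer exposes
# container_abcs, so fall back to the standard library equivalent.
import torch

TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split(".")[1])

if TORCH_MAJOR == 1 and TORCH_MINOR < 8:
    from torch._six import container_abcs
else:
    import collections.abc as container_abcs
```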

Catalog

  • Visualization demo
  • Pre-trained checkpoints + fine-tuning code + testing code
  • Pre-training code

Visualization demo

Visualization of MAE output with 95% (left) and 98% (right) mask rate on the same video.

Run our interactive visualization demo using the Colab notebook (no GPU needed).

Fine-tuning with pre-trained checkpoints

The following table provides the pre-trained checkpoints used in the paper, pre-trained with a 90% mask ratio for 1600 effective epochs and converted from the PySlowFast codebase:

|                                        | ViT-Large | ViT-Huge |
| -------------------------------------- | --------- | -------- |
| pre-trained checkpoint on Kinetics-400 | download  | download |
| md5                                    | edf3a5    | 3d7f64   |
| pre-trained checkpoint on Kinetics-600 | download  | download |
| md5                                    | 9a9645    | 27495e   |
| pre-trained checkpoint on Kinetics-700 | download  | download |
| md5                                    | cdbada    | 4c4e3c   |
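
The md5 values above appear to be 6-character prefixes of the full digest. A quick way to check a downloaded file (the filename here is hypothetical):

```python
# Compare a downloaded checkpoint against the md5 prefix in the table above.
# Hash large files in chunks to bound memory usage.
import hashlib

md5 = hashlib.md5()
with open("mae_pretrain_vit_large_k400.pth", "rb") as f:  # hypothetical filename
    for chunk in iter(lambda: f.read(1 << 20), b""):
        md5.update(chunk)

print(md5.hexdigest()[:6])  # should match the table entry, e.g. "edf3a5"
```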

The fine-tuning instruction is in FINETUNE.md.

Pre-training

The pre-training instruction is in PRETRAIN.md.

License

This project is under the CC-BY-NC 4.0 license. See LICENSE for details.

mae_st's People

Contributors

amyreese · facebook-github-bot · haooooooqi · r-barnes · weiwangmeta


mae_st's Issues

Encoder

Hi.
Thank you for providing this amazing work. Could you please share some ideas on how to extract features for an input video? I mean the features the encoder extracts from the visible patches, before any mask tokens are added. In addition, is run_test meant for running the MAE reconstruction with the pretrained weights? And how can run_test be used on datasets besides Kinetics-400, e.g. LIVE-VQC? Thank you very much.
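
A minimal sketch of visible-patch feature extraction, assuming mae_st mirrors the image-MAE API where the pre-training model exposes `forward_encoder(imgs, mask_ratio)` returning `(latent, mask, ids_restore)`; the factory name, keyword arguments, checkpoint filename, and input layout below are assumptions to verify against the repo:

```python
import torch
from models_mae import mae_vit_large_patch16  # assumed factory, mirroring image MAE

# kwargs and filename are assumptions; the released K400 weights use a
# temporal patch size of 2 (see the t_patch_size issue below).
model = mae_vit_large_patch16(num_frames=16, t_patch_size=2)
ckpt = torch.load("mae_pretrain_vit_large_k400.pth", map_location="cpu")
model.load_state_dict(ckpt.get("model", ckpt), strict=False)
model.eval()

video = torch.randn(1, 3, 16, 224, 224)  # assumed (B, C, T, H, W) layout
with torch.no_grad():
    # The MAE encoder embeds only the visible patches; mask tokens are
    # introduced later, in the decoder.
    latent, mask, ids_restore = model.forward_encoder(video, mask_ratio=0.9)
print(latent.shape)  # features of the visible patches only
```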

Finetuned checkpoints

Hi, I wonder if you could provide the fine-tuned checkpoints based on your pre-trained checkpoints, just for the sake of verification? Thanks.

test.csv format

In DATASET.md, it is mentioned that the format for Kinetics-400 is as follows:

```
path_to_video_1 label_1
path_to_video_2 label_2
path_to_video_3 label_3
...
path_to_video_N label_N
```

I am curious what label_1 stands for and how to generate the test.csv file in this form.
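
A hedged sketch for building such a csv, assuming (as in PySlowFast-style loaders) that each label is an integer class index and that videos are grouped into one directory per class; all paths and the mapping are illustrative:

```python
import os

root = "kinetics400/test"                        # hypothetical layout: one folder per class
classes = sorted(os.listdir(root))
label_map = {name: idx for idx, name in enumerate(classes)}

with open("test.csv", "w") as f:
    for name in classes:
        class_dir = os.path.join(root, name)
        for video in sorted(os.listdir(class_dir)):
            # Space-separated "path label" lines, matching the format above.
            f.write(f"{os.path.join(class_dir, video)} {label_map[name]}\n")
```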

Epoch number and warmup epoch number

Dear authors,

In PRETRAIN.md (https://github.com/facebookresearch/mae_st/blob/main/PRETRAIN.md), epochs and warmup_epochs are specified as 100 and 5, respectively, which seems inconsistent with the configurations in your paper. With repeat_aug=4, should I set epochs to 200 in order to train an 800-epoch model?

In Table 5, the warmup is set to 120 epochs, but PRETRAIN.md sets it to 5. Which one should I follow? And when epochs is set to 200, how should I set warmup_epochs?

Many thanks
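
A back-of-envelope check of the questioner's assumption, which is not confirmed by the repo docs, that effective epochs scale with the repeated-sampling factor:

```python
# Assumption: each clip is seen repeat_aug times per pass over the dataset,
# so effective epochs = epochs * repeat_aug.
epochs, repeat_aug = 200, 4
print(epochs * repeat_aug)  # 800 effective epochs
```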

'_OpNamespace' 'video_reader' object has no attribute 'probe_video_from_memory'

When I run `meta = io._probe_video_from_memory(video_tensor)` in util.decoder.decode and look into the function defined in torchvision, I find that the `torch.ops.video_reader.probe_video_from_memory(video_data)` call inside it is not reachable. When I run it, I get the error:

```
'_OpNamespace' 'video_reader' object has no attribute 'probe_video_from_memory'
```
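
This error typically means the installed torchvision was built without the optional video_reader extension. A quick diagnostic using standard torchvision API (not mae_st-specific):

```python
import torchvision

# Raises a RuntimeError if this torchvision build lacks the video_reader
# backend; in that case a source build of torchvision with ffmpeg is needed.
torchvision.set_video_backend("video_reader")
```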

Difference with VideoMAE

Hi All,

Thanks for the nice work! I would be interested to know the differences from the VideoMAE method. I am aware that these are concurrent related works, but it would be interesting to know some of the finer methodology differences, e.g., the masking strategy.

Thanks a lot!

what is torch.fb.rendezvous.zeus responsible for?

torch.fb.rendezvous.zeus is imported in util/misc.py, although there are no references to it, so I was able to run the code in non-distributed mode by simply commenting it out.

However, after trying to run the code in distributed mode for quite some time, I realized that zeus may be the culprit.

Since torch.fb.rendezvous.zeus doesn't seem to be open source, could you please clarify its functionality, or share a copy of the environment variables etc. that it may be configuring?

```
import torch.fb.rendezvous.zeus
ModuleNotFoundError: No module named 'torch.fb'
```
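
A sketch of the workaround described above, guarding the internal-only import in util/misc.py so the open-source code path still runs:

```python
# Guarding the Meta-internal import keeps the open-source code path
# runnable; this mirrors the commenting-out workaround described above.
try:
    import torch.fb.rendezvous.zeus  # noqa: F401  (internal-only module)
except ImportError:
    pass
```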

t_patch_size issue

Hi. Thank you for your work. When working with the model I encountered a problem: the default value of t_patch_size is 4 and the number of frames is 32 in the code, but t_patch_size is 2 in the pretrained weights. May I ask how to solve that? Thank you very much.
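
A hedged sketch of matching the model to the checkpoint; the factory name and keyword arguments are assumptions based on the repo's model definitions and should be checked against models_mae.py and the argparse flags:

```python
from models_mae import mae_vit_large_patch16  # assumed factory name

# The released weights were pre-trained with a temporal patch size of 2,
# so the model must be built with t_patch_size=2 (and a matching number
# of frames) rather than the code defaults of 4 and 32.
model = mae_vit_large_patch16(t_patch_size=2, num_frames=16)
```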

Dataset and actually running issues

Hi. I encounter some errors when running run_finetune.py, run_pretrain.py, and run_test.py.
May I ask how I could find the proper train, test, and val csv files? Should we use the ones downloaded from Kinetics-400 directly? If so, how should the data be modified? It seems that the train/test.csv should have only two columns, path and label, but I did not find a path column in the downloaded files.

In addition, could you share the download links for the dataset? Thank you very much.

Lastly, may I ask what the slowfast installation is used for in this project?

Question about the provided pre-trained weights.

Dear all,

Thanks for your excellent work. I downloaded the Kinetics-400 ViT-Large pre-trained checkpoint and can load it with my own ViT models. However, I noticed that some keys do not seem to map to anything in the model definition, and I would appreciate it if you could explain what "pred_head" refers to.

```
odict_keys(['pred_head.transforms.0.0.norm1.weight', 'pred_head.transforms.0.0.norm1.bias', 'pred_head.transforms.0.0.attn.q.weight', 'pred_head.transforms.0.0.attn.q.bias', 'pred_head.transforms.0.0.attn.k.weight', 'pred_head.transforms.0.0.attn.k.bias', 'pred_head.transforms.0.0.attn.v.weight', 'pred_head.transforms.0.0.attn.v.bias', 'pred_head.transforms.0.0.attn.proj.weight', 'pred_head.transforms.0.0.attn.proj.bias', 'pred_head.transforms.0.0.norm2.weight', 'pred_head.transforms.0.0.norm2.bias', 'pred_head.transforms.0.0.mlp.fc1.weight', 'pred_head.transforms.0.0.mlp.fc1.bias', 'pred_head.transforms.0.0.mlp.fc2.weight', 'pred_head.transforms.0.0.mlp.fc2.bias', 'pred_head.transforms.0.1.norm1.weight', 'pred_head.transforms.0.1.norm1.bias', 'pred_head.transforms.0.1.attn.q.weight', 'pred_head.transforms.0.1.attn.q.bias', 'pred_head.transforms.0.1.attn.k.weight', 'pred_head.transforms.0.1.attn.k.bias', 'pred_head.transforms.0.1.attn.v.weight', 'pred_head.transforms.0.1.attn.v.bias', 'pred_head.transforms.0.1.attn.proj.weight', 'pred_head.transforms.0.1.attn.proj.bias', 'pred_head.transforms.0.1.norm2.weight', 'pred_head.transforms.0.1.norm2.bias', 'pred_head.transforms.0.1.mlp.fc1.weight', 'pred_head.transforms.0.1.mlp.fc1.bias', 'pred_head.transforms.0.1.mlp.fc2.weight', 'pred_head.transforms.0.1.mlp.fc2.bias', 'pred_head.transforms.0.2.norm1.weight', 'pred_head.transforms.0.2.norm1.bias', 'pred_head.transforms.0.2.attn.q.weight', 'pred_head.transforms.0.2.attn.q.bias', 'pred_head.transforms.0.2.attn.k.weight', 'pred_head.transforms.0.2.attn.k.bias', 'pred_head.transforms.0.2.attn.v.weight', 'pred_head.transforms.0.2.attn.v.bias', 'pred_head.transforms.0.2.attn.proj.weight', 'pred_head.transforms.0.2.attn.proj.bias', 'pred_head.transforms.0.2.norm2.weight', 'pred_head.transforms.0.2.norm2.bias', 'pred_head.transforms.0.2.mlp.fc1.weight', 'pred_head.transforms.0.2.mlp.fc1.bias', 'pred_head.transforms.0.2.mlp.fc2.weight', 'pred_head.transforms.0.2.mlp.fc2.bias', 'pred_head.transforms.0.3.norm1.weight', 'pred_head.transforms.0.3.norm1.bias', 'pred_head.transforms.0.3.attn.q.weight', 'pred_head.transforms.0.3.attn.q.bias', 'pred_head.transforms.0.3.attn.k.weight', 'pred_head.transforms.0.3.attn.k.bias', 'pred_head.transforms.0.3.attn.v.weight', 'pred_head.transforms.0.3.attn.v.bias', 'pred_head.transforms.0.3.attn.proj.weight', 'pred_head.transforms.0.3.attn.proj.bias', 'pred_head.transforms.0.3.norm2.weight', 'pred_head.transforms.0.3.norm2.bias', 'pred_head.transforms.0.3.mlp.fc1.weight', 'pred_head.transforms.0.3.mlp.fc1.bias', 'pred_head.transforms.0.3.mlp.fc2.weight', 'pred_head.transforms.0.3.mlp.fc2.bias', 'pred_head.transforms.0.4.weight', 'pred_head.transforms.0.4.bias', 'pred_head.projections.0.weight', 'pred_head.projections.0.bias'])
```

load model issue

Thank you for your amazing work.
I encounter some issues when running run_test.py: it shows the following error message when loading the model checkpoint. I downloaded the weights and did not make any modifications. Could you please help me solve this problem? Thank you very much.

```
RuntimeError: Error(s) in loading state_dict for VisionTransformer:
size mismatch for patch_embed.proj.weight: copying a param with shape torch.Size([1024, 3, 2, 16, 16]) from checkpoint, the shape in current model is torch.Size([1024, 3, 4, 16, 16]).
```

The same happens for the huge model. Also, in pos_embed.py, "pos_embed" is not in checkpoint_model, so the code never enters the interpolate_pos_embed function. May I ask how to solve this size mismatch error?
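
The mismatch is in the temporal dimension of the patch embedding (2 in the checkpoint vs. 4 in the model), i.e. the same t_patch_size issue raised above. A small sketch for inspecting a checkpoint before building the model (the filename and state-dict nesting are assumptions):

```python
import torch

ckpt = torch.load("mae_pretrain_vit_large_k400.pth", map_location="cpu")  # hypothetical filename
state = ckpt.get("model", ckpt)  # MAE-style checkpoints usually nest under "model"
w = state["patch_embed.proj.weight"]
print(tuple(w.shape))  # (1024, 3, 2, 16, 16); index 2 is the temporal patch size
```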

Why does repeated sampling not affect the actual learning rate?

Hi! Thanks for your great work!

The repeated-sampling trick increases the actual batch size (although the repeats come from the same video); however, the actual learning rate is not scaled by the number of repeated samples.

mae_st/main_pretrain.py, lines 280 to 283 at commit d752324:

```python
eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()

if args.lr is None:  # only base_lr is specified
    args.lr = args.blr * eff_batch_size / 256
```

I wonder why?

How to start multi-gpu training?

Hi, Thank you authors for releasing the code.

I have been trying to run run_pretrain.py, which calls the main() function in main_pretrain.py. However, it spawns a single-GPU run, and the function launch_one_thread() is not called anywhere.
Is there a flag that needs to be passed to start multi-GPU training, or does the code need some changes?
Any help would be highly appreciated.

Thanks
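
For reference: the original MAE codebase's util/misc.py initializes distributed mode from the standard torch.distributed environment variables; if mae_st follows the same convention, launching `torchrun --nproc_per_node=<num_gpus> main_pretrain.py <args>` should spawn one process per GPU. This is an assumption to verify against misc.init_distributed_mode, not a documented entry point of this repo.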

403 error

Downloading the pretrained models returns a 403 Forbidden error.

Finetune

Hi,

Thank you so much for sharing the code of this amazing work! Is it possible to also share fine-tuning code for image classification? The current main_finetune.py seems to support Kinetics fine-tuning only. The parameter "data_path" points to an ImageNet directory but is not actually used. Thank you in advance!

Best
