arig23498 / mae-scalable-vision-learners

A TensorFlow 2.x implementation of Masked Autoencoders Are Scalable Vision Learners

Home Page: https://keras.io/examples/vision/masked_image_modeling/

License: MIT License

Language: Jupyter Notebook (100.00%)
Topics: tensorflow2, keras, self-supervised-learning, autoencoder, masked-image-modeling

mae-scalable-vision-learners's Introduction

Masked Autoencoders Are Scalable Vision Learners

Open In Colab

A TensorFlow implementation of Masked Autoencoders Are Scalable Vision Learners [1]. Our implementation of the proposed method is available in the mae-pretraining.ipynb notebook. It includes evaluation with linear probing as well. Furthermore, the notebook can be fully executed on Google Colab. Our main objective is to present the core idea of the proposed method in a minimal and readable manner. We have also prepared a blog post to help you get started with Masked Autoencoders easily.
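For a quick sense of the core idea: pre-training hides a large fraction of image patches and asks the model to reconstruct them, and only the visible patches go through the encoder. A minimal sketch of that masking step (constant names such as num_patches and mask_proportion are illustrative, not the notebook's exact API):

import tensorflow as tf

# Illustrative values; the notebook derives these from the image and patch sizes.
num_patches = 64        # e.g. a 48x48 input split into 6x6 patches
mask_proportion = 0.75  # fraction of patches hidden from the encoder

def random_mask_indices(batch_size):
    # One random permutation of patch positions per image.
    shuffled = tf.argsort(tf.random.uniform((batch_size, num_patches)), axis=-1)
    num_masked = int(mask_proportion * num_patches)
    mask_indices = shuffled[:, :num_masked]    # patches the encoder never sees
    unmask_indices = shuffled[:, num_masked:]  # patches fed to the encoder
    return mask_indices, unmask_indices

def select_unmasked(patches, unmask_indices):
    # patches: (batch, num_patches, patch_dim) -> keep only the visible ones.
    return tf.gather(patches, unmask_indices, axis=1, batch_dims=1)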


With just 100 epochs of pre-training and a fairly lightweight and asymmetric autoencoder architecture, we achieve 49.33% accuracy with linear probing on the CIFAR-10 dataset. Our training logs and encoder weights are released in Weights and Logs. For comparison, we took the encoder architecture and trained it from scratch (refer to regular-classification.ipynb) in a fully supervised manner. This gave us ~76% test top-1 accuracy.
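For readers unfamiliar with the protocol, linear probing means freezing the pre-trained encoder and training only a single linear classifier on top of its features. A rough sketch of such a probe (dimensions and layer names are assumptions, not the notebook's exact code):

import tensorflow as tf
from tensorflow.keras import layers

def build_linear_probe(num_tokens, token_dim, num_classes=10):
    # Inputs are the frozen encoder's token features, shape (batch, num_tokens, token_dim).
    inputs = tf.keras.Input(shape=(num_tokens, token_dim))
    pooled = layers.GlobalAveragePooling1D()(inputs)  # average over tokens
    outputs = layers.Dense(num_classes, activation="softmax")(pooled)  # the only trainable part
    return tf.keras.Model(inputs, outputs)

probe = build_linear_probe(num_tokens=64, token_dim=128)  # dims here are illustrative
probe.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])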

We note that with further hyperparameter tuning and more epochs of pre-training, we can achieve better performance with linear probing. Below we present some more results:

| Config | Masking proportion | LP performance | Encoder weights & logs |
| --- | --- | --- | --- |
| Encoder & decoder layers: 3 & 1; batch size: 256 | 0.6 | 44.25% | Link |
| Ditto | 0.75 | 46.84% | Link |
| Encoder & decoder layers: 6 & 2; batch size: 256 | 0.75 | 48.16% | Link |
| Encoder & decoder layers: 9 & 3; batch size: 256; weight decay: 1e-5 | 0.75 | 49.33% | Link |

LP denotes linear probing. Config is mostly based on what we define in the hyperparameters section of the mae-pretraining.ipynb notebook.
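As a hedged sketch (the constant names below are assumptions based on the table, not necessarily the notebook's exact identifiers), the best-performing row corresponds to settings along these lines:

# Best row from the table above; names are illustrative.
MASK_PROPORTION = 0.75   # masking proportion
ENC_LAYERS = 9           # encoder Transformer blocks
DEC_LAYERS = 3           # decoder Transformer blocks
BATCH_SIZE = 256
WEIGHT_DECAY = 1e-5
EPOCHS = 100             # pre-training epochs behind the 49.33% LP result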

Notes

Acknowledgements

References

[1] Masked Autoencoders Are Scalable Vision Learners; He et al.; arXiv 2021; https://arxiv.org/abs/2111.06377.

mae-scalable-vision-learners's People

Contributors

arig23498 · sayakpaul


mae-scalable-vision-learners's Issues

Issue with the plotting utility `show_masked_image`

Should be:

def show_masked_image(self, patches):
    # Utility function that helps visualize masked images.
    _, unmask_indices = self.get_random_indices()
    unmasked_patches = tf.gather(patches, unmask_indices, axis=1, batch_dims=1)

    # Necessary for plotting.
    ids = tf.argsort(unmask_indices)
    sorted_unmask_indices = tf.sort(unmask_indices)
    unmasked_patches = tf.gather(unmasked_patches, ids, batch_dims=1)

    # Select a random index for visualization.
    idx = np.random.choice(len(sorted_unmask_indices))
    print(f"Index selected: {idx}.")

    n = int(np.sqrt(NUM_PATCHES))
    unmask_index = sorted_unmask_indices[idx]
    unmasked_patch = unmasked_patches[idx]

    plt.figure(figsize=(4, 4))

    count = 0
    for i in range(NUM_PATCHES):
        ax = plt.subplot(n, n, i + 1)

        if count < unmask_index.shape[0] and unmask_index[count].numpy() == i:
            patch = unmasked_patch[count]
            patch_img = tf.reshape(patch, (PATCH_SIZE, PATCH_SIZE, 3))
            plt.imshow(patch_img)
            plt.axis("off")
            count = count + 1
        else:
            patch_img = tf.zeros((PATCH_SIZE, PATCH_SIZE, 3))
            plt.imshow(patch_img)
            plt.axis("off")
    plt.show()

    # Return the random index to validate the image outside the method.
    return idx
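A hedged usage sketch; patch_layer, patch_encoder, and image_batch are assumed to come from the notebook's earlier cells:

# Illustrative only: visualize the masking for one randomly chosen image.
patches = patch_layer(image_batch)                       # (batch, NUM_PATCHES, patch_dim)
random_index = patch_encoder.show_masked_image(patches)  # plots the masked grid
plt.imshow(image_batch[random_index])                    # compare against the original image
plt.axis("off")
plt.show()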

Final layer in decoder is incorrect

Excuse me if this is my mistake.

The final layer of the decoder, pre_final, flattens all of the patches into a 1D vector and then uses a Dense layer that maps the features to the entire image size.

This differs from the official PyTorch implementation (https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/models_mae.py#L58), where each token is mapped to its own patch (patch_size**2 * in_chans values).

This makes a huge difference in the efficiency of the network, because this pre_final Dense layer is enormous.
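For reference, a hedged Keras sketch of the per-token projection used in the official implementation, as opposed to flattening everything into one giant Dense layer (shapes and names here are assumptions):

from tensorflow.keras import layers

def per_token_decoder_head(x, patch_size, channels=3):
    # x: decoder token features with shape (batch, num_patches, decoder_dim).
    # Each token is projected to its own patch's pixels; the projection weights
    # are shared across tokens, so the layer stays small.
    # Output shape: (batch, num_patches, patch_size * patch_size * channels);
    # reassembling patches into an image (or computing a per-patch loss, as the
    # paper does) is a separate step.
    return layers.Dense(patch_size * patch_size * channels, activation="sigmoid")(x)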

Hyperparameters

The Hyperparameters:

  • Dataset: CIFAR10
  • Masking: 60%
  • ENC DIM - 128
  • DEC DIM - 64
  • ENC HEAD - 2
  • DEC HEAD - 2
  • ENC LAYER - 3
  • DEC LAYER - 4
  • Resize - 48

Decoder parameters become huge when scaling up the problem

Hello!

First of all, I want to thank you for this piece of code! Amazing work!

I have a question regarding the decoder implementation. In the last part the layers are defined as follows:

x = layers.Flatten()(x)
pre_final = layers.Dense(units=image_size * image_size * 3, activation="sigmoid")(x)
outputs = layers.Reshape((image_size, image_size, 3))(pre_final)

This does not correspond to the official implementation where they create a Linear layer as follows:

self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # decoder to patch

Here is the link to the exact line.

The problem is that when you try to scale the implementation up, for instance to 224x224 images as in ImageNet, and configure the same parameters they used (one of their models, the Base one), this Dense layer becomes huge in terms of parameters. It is even bigger than the encoder. You can see a model.summary() that I made based on your code:

_____________________________________________________________________
 Layer (type)                   Output Shape              Param #
=====================================================================
 patches (Patches)              multiple                  0
 patch_encoder (PatchEncoder)   multiple                  246784
 mae_encoder (Functional)       (None, None, 768)         85056000
 mae_decoder (Functional)       (None, 128, 128, 1)       554104320
=====================================================================
Total params: 639,407,104
Trainable params: 639,407,104

As an additional note, I tried the notebook they provide in the official implementation, and if you print the model parameters, the decoder does not match the ~10% of the encoder's parameters that they state in the manuscript. However, it still has a much more reasonable parameter balance between encoder and decoder. See the output of their model below so you can check what it should look like:

from torchsummary import summary
summary(model_mae, (3, 224, 224))

OUTPUT:
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1          [-1, 768, 14, 14]         590,592
        PatchEmbed-2             [-1, 196, 768]               0
         LayerNorm-3              [-1, 50, 768]           1,536
            Linear-4             [-1, 50, 2304]       1,771,776
           Dropout-5           [-1, 12, 50, 50]               0
            Linear-6              [-1, 50, 768]         590,592
           Dropout-7              [-1, 50, 768]               0
         Attention-8              [-1, 50, 768]               0
          Identity-9              [-1, 50, 768]               0
. . .
           Block-324             [-1, 197, 512]               0
       LayerNorm-325             [-1, 197, 512]           1,024
          Linear-326             [-1, 197, 768]         393,984
================================================================
Total params: 111,654,400
Trainable params: 111,654,400
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 296.82
Params size (MB): 425.93
Estimated Total Size (MB): 723.33
----------------------------------------------------------------
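A rough back-of-the-envelope comparison of the two heads, assuming a 224x224 RGB input, 16x16 patches, and a 512-dim decoder (so these numbers will not match the 554M figure above exactly, but the per-token count matches the Linear-326 row):

# Illustrative arithmetic only.
image_size, patch_size, channels = 224, 16, 3
num_patches = (image_size // patch_size) ** 2        # 196
decoder_dim = 512                                    # MAE-Base decoder width

# Flatten-then-Dense head: every token feature connects to every output pixel.
flat_in = num_patches * decoder_dim                  # 100,352
flat_out = image_size * image_size * channels        # 150,528
flat_params = flat_in * flat_out + flat_out          # ~15.1 billion parameters

# Per-token head (official MAE): one small projection shared across tokens.
patch_dim = patch_size ** 2 * channels               # 768
per_token_params = decoder_dim * patch_dim + patch_dim  # 393,984, as in Linear-326 above

print(f"{flat_params:,} vs {per_token_params:,}")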

Unshuffle the patches?

Your code helps me a lot! However, I still have some questions. In the paper, the authors say they unshuffle the full list before applying the decoder. In the MaskedAutoencoder class of your implementation,
decoder_inputs = tf.concat([encoder_outputs, masked_embeddings], axis=1)
no unshuffling is performed. Could you tell me the reasoning behind this choice? Thanks a lot!
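For what it is worth, here is a hedged sketch of what explicit unshuffling before the decoder could look like, assuming unmask_indices and mask_indices from the patch encoder (this is an illustration of the operation described in the paper, not the repository's code):

import tensorflow as tf

def unshuffle_for_decoder(encoder_outputs, masked_embeddings, unmask_indices, mask_indices):
    # encoder_outputs:   (batch, num_unmasked, dim) features of the visible patches
    # masked_embeddings: (batch, num_masked, dim) learnable mask tokens
    # unmask_indices / mask_indices: original positions of the two groups
    tokens = tf.concat([encoder_outputs, masked_embeddings], axis=1)
    all_indices = tf.concat([unmask_indices, mask_indices], axis=1)
    # argsort of a permutation gives its inverse: slot i receives the token
    # whose original patch index is i.
    restore = tf.argsort(all_indices, axis=-1)
    return tf.gather(tokens, restore, axis=1, batch_dims=1)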

Could you also share the weights of the pretrained decoder?

Hi,

Thanks for your excellent implementation! I found that you have shared the weights of the encoder, but if we want to replicate the reconstruction, the pretrained decoder is still needed. So, could you also share the weights of the pretrained decoder?

Best Regards,
Hongxin

Excellent work (`mae.ipynb`)!

@ariG23498 this is fantastic stuff. Super clean, readable, and coherent with the original implementation. A couple of suggestions that would likely make things even better:

  • Since you have already implemented the masking visualization utilities, how about making them part of the PatchEncoder itself? That way you could let it accept a test image, apply random masking, and plot it just as you do in the earlier cells. I believe this would make the notebook cleaner.

  • AdamW (tfa.optimizers.AdamW) is a better choice when it comes to training Transformer-based models (see the sketch below).
  • Are we taking the loss on the correct component? I remember you mentioning it being dealt with differently.

After these points are addressed, I will take a crack at porting the training loop to TPUs along with other performance-monitoring callbacks.
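On the AdamW point, a minimal sketch with tensorflow-addons (the learning rate and weight decay values are placeholders):

import tensorflow_addons as tfa

# Decoupled weight decay generally suits Transformer training better than plain Adam.
optimizer = tfa.optimizers.AdamW(learning_rate=1e-3, weight_decay=1e-4)
# mae_model.compile(optimizer=optimizer, ...)  # plug into the existing training setup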
