ariG23498 commented on July 16, 2024

Hey @danifranco

Thank you for taking interest in this work!

What I remember from our tutorial is that we tweaked the implementation so that the decoder outputs the entire image at once (hence the (image_size, image_size, 3) shape you pointed out). This makes it easy for us to reuse the patchification layer and gather the patches from the output image.

    # Method of the model class; assumes tensorflow is imported as tf at module level.
    def calculate_loss(self, images, test=False):
        # Augment the input images.
        if test:
            augmented_images = self.test_augmentation_model(images)
        else:
            augmented_images = self.train_augmentation_model(images)

        # Patch the augmented images.
        patches = self.patch_layer(augmented_images)

        # Encode the patches.
        (
            unmasked_embeddings,
            masked_embeddings,
            unmasked_positions,
            mask_indices,
            unmask_indices,
        ) = self.patch_encoder(patches)

        # Pass the unmasked patches to the encoder.
        encoder_outputs = self.encoder(unmasked_embeddings)

        # Create the decoder inputs.
        encoder_outputs = encoder_outputs + unmasked_positions
        decoder_inputs = tf.concat([encoder_outputs, masked_embeddings], axis=1)

        # Decode the inputs.
        decoder_outputs = self.decoder(decoder_inputs)  # <----- outputs a full image
        decoder_patches = self.patch_layer(decoder_outputs)  # <----- patchifies the decoded image

        # <----- gathers the masked patches for the loss (this can be optimized a lot!)
        loss_patch = tf.gather(patches, mask_indices, axis=1, batch_dims=1)
        loss_output = tf.gather(decoder_patches, mask_indices, axis=1, batch_dims=1)

        # Compute the total loss.
        total_loss = self.compiled_loss(loss_patch, loss_output)

        return total_loss, loss_patch, loss_output
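The gather-and-re-patchify step marked "this can be optimized a lot" could be avoided if the decoder head emitted one patch-sized prediction per token instead of a full image (a hypothetical change, not what the tutorial does). Because the decoder input is built as tf.concat([encoder_outputs, masked_embeddings], axis=1), the masked predictions would then simply be the last tokens, and only the targets need a gather. Here is a minimal NumPy sketch under those assumptions; masked_patch_mse and the demo sizes are hypothetical, and np.take_along_axis stands in for tf.gather(..., axis=1, batch_dims=1).

```python
import numpy as np


def masked_patch_mse(patches, decoder_tokens, mask_indices):
    """MSE on the masked patches only, computed per token (hypothetical helper).

    patches:        (batch, num_patches, patch_dim) patchified input images.
    decoder_tokens: (batch, num_patches, patch_dim) one prediction per decoder token,
                    in decoder-input order: unmasked tokens first, then masked tokens.
    mask_indices:   (batch, num_masked) patch indices of the masked tokens, in the
                    same order as the masked embeddings.
    """
    num_masked = mask_indices.shape[1]
    # The masked predictions are the last num_masked tokens: no image reassembly
    # and no second patchification are needed.
    predictions = decoder_tokens[:, -num_masked:, :]
    # Per-example gather of the targets, like tf.gather(patches, mask_indices, axis=1, batch_dims=1).
    targets = np.take_along_axis(patches, mask_indices[:, :, None], axis=1)
    return float(np.mean((targets - predictions) ** 2))


# Demo with hypothetical sizes: 2 images, 9 patches of dimension 12, 6 of them masked.
rng = np.random.default_rng(0)
patches = rng.random((2, 9, 12))
order = np.stack([rng.permutation(9) for _ in range(2)])
mask_indices, unmask_indices = order[:, :6], order[:, 6:]
# A "perfect" decoder output, laid out in [unmasked..., masked...] token order.
token_order = np.concatenate([unmask_indices, mask_indices], axis=1)
perfect_tokens = np.take_along_axis(patches, token_order[:, :, None], axis=1)
print(masked_patch_mse(patches, perfect_tokens, mask_indices))  # 0.0
```

The design choice here is to slice by position rather than gather twice: the token order is already fixed by the concatenation, so the predictions for the masked patches are a contiguous slice.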

You are right. Because the decoder's final Dense layer maps all the flattened tokens to image_size * image_size * 3 units, its parameter count grows very quickly with the image size, so we cannot scale this model up as it is. Please feel free to raise a PR if you want to solve this with better gathering and slicing techniques inside the training loop.
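To put the parameter-count concern in numbers, here is a rough count for the final decoder layer only, comparing a Flatten() + Dense(image_size * image_size * 3) head with a Dense(patch_size * patch_size * 3) head applied per token. The sizes (48x48 RGB images, 6x6 patches, 64-dimensional decoder tokens) are hypothetical values chosen for illustration, and dense_head_params is a hypothetical helper that counts a Dense layer's weights plus biases.

```python
def dense_head_params(in_features, out_units):
    """Weights plus biases of a Dense layer (hypothetical helper)."""
    return in_features * out_units + out_units


# Hypothetical sizes: 48x48 RGB images, 6x6 patches, 64-dim decoder tokens.
image_size, patch_size, dim, channels = 48, 6, 64, 3
num_patches = (image_size // patch_size) ** 2  # 64

# Flatten() + Dense(image_size * image_size * 3): every token feeds every output value.
full_image_head = dense_head_params(num_patches * dim, image_size * image_size * channels)

# Dense(patch_size * patch_size * 3) on (batch, num_patches, dim): weights shared across tokens.
per_patch_head = dense_head_params(dim, patch_size * patch_size * channels)

print(full_image_head)  # 28318464
print(per_patch_head)   # 7020
```

For a fixed patch size and token width, the full-image head grows roughly with the fourth power of image_size, while the per-token head does not depend on image_size at all.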

Please feel free to reach out if there are any more concerns.

from mae-scalable-vision-learners.

wolo-wolo commented on July 16, 2024

Thank you @ariG23498 for your great work. I have run into the same problem. Is the following code correct?

    # ----------------- original -----
    # x = layers.Flatten()(x)
    # pre_final = layers.Dense(units=image_size * image_size * 3, activation="sigmoid")(x)
    # outputs = layers.Reshape((image_size, image_size, 3))(pre_final)
    # -----------------modified-----
    pre_final = layers.Dense(units=patch_size * patch_size * 3, activation="sigmoid")(x)
    outputs = layers.Reshape((image_size, image_size, 3))(pre_final)

ariG23498 commented on July 16, 2024

It is correct, but as far as I remember (I might be wrong here, as I trained it quite some time ago), removing the Flatten() gave me poor results. I suspect the cause was the dataset size and the number of epochs.

You could probably give this a try and report the results in this issue. That would be a great contribution!
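One point that may be worth checking in the modified head: Dense(units=patch_size * patch_size * 3) gives one patch-sized prediction per token, with shape (batch, num_patches, patch_size * patch_size * 3), and Reshape((image_size, image_size, 3)) only agrees with it in element count. A plain reshape writes those values out in raster order, so each token's patch_size * patch_size pixels land in a run of consecutive pixels rather than in a square patch. The NumPy sketch below contrasts a plain reshape with an explicit un-patchify. patchify and unpatchify are hypothetical helpers, and the sketch assumes patches are flattened row-major as (row, column, channel), which is what a tf.image.extract_patches-based patch layer should produce.

```python
import numpy as np


def patchify(images, patch_size):
    """Split (batch, H, W, C) images into (batch, num_patches, patch_size*patch_size*C),
    patches in row-major grid order, each flattened as (row, col, channel)."""
    batch, height, width, channels = images.shape
    x = images.reshape(batch, height // patch_size, patch_size,
                       width // patch_size, patch_size, channels)
    # (b, grid_row, row_in_patch, grid_col, col_in_patch, c) -> group each patch's pixels
    x = x.transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(batch, -1, patch_size * patch_size * channels)


def unpatchify(patches, image_size, patch_size, channels=3):
    """Inverse of patchify: put each patch back at its spatial location."""
    batch = patches.shape[0]
    grid = image_size // patch_size
    x = patches.reshape(batch, grid, grid, patch_size, patch_size, channels)
    # (b, grid_row, grid_col, row_in_patch, col_in_patch, c)
    #   -> (b, grid_row, row_in_patch, grid_col, col_in_patch, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(batch, image_size, image_size, channels)


# Demo with hypothetical sizes: 8x8 RGB images split into 4x4 patches.
rng = np.random.default_rng(0)
images = rng.random((2, 8, 8, 3))
patches = patchify(images, patch_size=4)       # shape (2, 4, 48)
naive = patches.reshape(2, 8, 8, 3)            # what a plain Reshape does
restored = unpatchify(patches, image_size=8, patch_size=4)
print(np.array_equal(naive, images))     # False: patch pixels end up scrambled
print(np.array_equal(restored, images))  # True
```

In TensorFlow the same steps map onto tf.reshape and tf.transpose. For the MAE decoder, the tokens would also need to be put back into spatial order before un-patchifying, since the decoder input concatenates the encoder outputs and the masked embeddings; one way is to invert the permutation given by concatenating unmask_indices and mask_indices.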

ariG23498 commented on July 16, 2024

Closing this due to inactivity. Please feel free to reopen it if the query is not resolved.
