Comments (4)
Hey @danifranco
Thank you for taking interest in this work!
What I remember from our tutorial is that we tweaked the implementation so that the decoder emits the entire image at once (hence the `(image_size, image_size, 3)` output shape you have pointed out). This makes it easier for us to reuse the patchification layer and gather the patches from the output image.
```python
def calculate_loss(self, images, test=False):
    # Augment the input images.
    if test:
        augmented_images = self.test_augmentation_model(images)
    else:
        augmented_images = self.train_augmentation_model(images)

    # Patch the augmented images.
    patches = self.patch_layer(augmented_images)

    # Encode the patches.
    (
        unmasked_embeddings,
        masked_embeddings,
        unmasked_positions,
        mask_indices,
        unmask_indices,
    ) = self.patch_encoder(patches)

    # Pass the unmasked patches to the encoder.
    encoder_outputs = self.encoder(unmasked_embeddings)

    # Create the decoder inputs.
    encoder_outputs = encoder_outputs + unmasked_positions
    decoder_inputs = tf.concat([encoder_outputs, masked_embeddings], axis=1)

    # Decode the inputs.
    decoder_outputs = self.decoder(decoder_inputs)  # <----- outputs an image
    decoder_patches = self.patch_layer(decoder_outputs)  # <----- patchifies the image

    # Gather the masked patches for the loss (this can be optimized a lot!).
    loss_patch = tf.gather(patches, mask_indices, axis=1, batch_dims=1)
    loss_output = tf.gather(decoder_patches, mask_indices, axis=1, batch_dims=1)

    # Compute the total loss.
    total_loss = self.compiled_loss(loss_patch, loss_output)
    return total_loss, loss_patch, loss_output
```
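The two `tf.gather(..., batch_dims=1)` calls select each example's masked patches by index. Here is a minimal NumPy sketch of that per-example gather, using toy shapes (2 images, 4 patches each, patch dimension 3) that I picked for illustration only:

```python
import numpy as np

# Toy "patches" tensor: (batch, num_patches, patch_dim) = (2, 4, 3).
patches = np.arange(2 * 4 * 3).reshape(2, 4, 3)

# Per-example indices of the masked patches (2 masked patches per image).
mask_indices = np.array([[0, 2],
                         [1, 3]])

# NumPy equivalent of tf.gather(patches, mask_indices, axis=1, batch_dims=1):
# for each example b, pick patches[b, mask_indices[b]].
masked = np.take_along_axis(patches, mask_indices[:, :, None], axis=1)

print(masked.shape)  # (2, 2, 3)
```

The `batch_dims=1` argument is what makes the gather per-example: without it, the same index list would be applied to every image in the batch.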
You are right here. We cannot scale this model up, and the final dense layer indeed takes a hit on the parameter count. Please feel free to raise a PR if you want to solve this issue with better gathering and slicing techniques inside the training loop.
Please feel free to reach out if there are any more concerns.
from mae-scalable-vision-learners.
Thank you @ariG23498 for your great work. I have encountered the same problems, is the following code correct?
```python
# -----------------original-----
# x = layers.Flatten()(x)
# pre_final = layers.Dense(units=image_size * image_size * 3, activation="sigmoid")(x)
# outputs = layers.Reshape((image_size, image_size, 3))(pre_final)
# -----------------modified-----
pre_final = layers.Dense(units=patch_size * patch_size * 3, activation="sigmoid")(x)
outputs = layers.Reshape((image_size, image_size, 3))(pre_final)
```
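A quick shape check shows why the modified per-token `Dense` is size-compatible with the full-image `Reshape`. The sizes below (48x48 images, 6x6 patches) are hypothetical, chosen only for illustration; note this checks element counts, not whether each patch lands back in its spatial grid position:

```python
import numpy as np

# Hypothetical sizes for illustration.
image_size, patch_size = 48, 6
num_patches = (image_size // patch_size) ** 2  # 64 tokens in the decoder stream

# After the per-token Dense: (batch, num_patches, patch_size * patch_size * 3).
batch = 2
pre_final = np.zeros((batch, num_patches, patch_size * patch_size * 3))

# Total elements per example match the full image:
# num_patches * patch_size^2 * 3 == image_size^2 * 3.
assert num_patches * patch_size * patch_size * 3 == image_size * image_size * 3

# So the Reshape to (image_size, image_size, 3) is valid.
outputs = pre_final.reshape(batch, image_size, image_size, 3)
print(outputs.shape)  # (2, 48, 48, 3)
```

One caveat: a plain reshape interleaves rows rather than reassembling each patch into its grid cell, so the output image's spatial layout is a fixed permutation of the patches; since the loss re-patchifies the output with the same layer, the network can still learn under that permutation.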
It is correct, but as far as I remember (I might be wrong here, as I trained it quite some time ago), removing the Flatten()
gave me poor results. I suspect the problem was the dataset size and the number of epochs.
You could give this a try and report the results in this issue. That would be a great contribution!
Closing this due to inactivity. Please feel free to reopen it if the query is not resolved.