Comments (9)
@rwightman I did experiments replacing LayerNorm2d (or GroupNorm1) with LayerNorm (transformer or ConvNeXt style) and found that it makes training unstable and also hurts performance.
Attached are the training and validation loss curves (I did some hyper-parameter tweaking for new experiments, but training instability was kind of consistent across experiments). I used LayerNorm implementation from ane-transformers.
from ml-cvnets.
We use the relationship between GroupNorm and LayerNorm, as described in the GroupNorm paper. This is also consistent with PyTorch's documentation, which suggests that putting all channels in one group is equivalent to layer norm. We will clarify this in the documentation.
To be more specific, GroupNorm with groups=1 normalizes over C, H, W. LayerNorm as used in transformers normalizes over the channel dimension only. Since PyTorch's LayerNorm doesn't natively support rank-4 NCHW tensors, a 'LayerNorm2d' implementation (ConvNeXt, EdgeNeXt, CoaTNet, and many more) is often used that either manually computes mean/var over the C dim or permutes to NHWC and back. In either case the norm remains over just the channel dim.
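For reference, the permute-style 'LayerNorm2d' pattern described above can be sketched as follows (a minimal illustration of the idea, not the exact code from any of the cited repos):

```python
import torch
import torch.nn as nn

class LayerNorm2d(nn.Module):
    """Channel-only LayerNorm for NCHW tensors, implemented by permuting
    to NHWC, normalizing over the last (channel) dim, and permuting back."""
    def __init__(self, num_channels: int, eps: float = 1e-6):
        super().__init__()
        self.ln = nn.LayerNorm(num_channels, eps=eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # NCHW -> NHWC, normalize over C, then NHWC -> NCHW.
        return self.ln(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

x = torch.randn(2, 8, 4, 4)
y = LayerNorm2d(8)(x)
```

After this norm, each spatial position has zero mean and unit variance across its channels, rather than across the whole C, H, W volume.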
GroupNorm(C, groups=1, affine=False) == LayerNorm([C, H, W], elementwise_affine=False) NOT LayerNorm(C) w/ permute.
Additionally, if the affine scale/bias is enabled, there is no way to get equivalence: GroupNorm scales/shifts over the C dim, while LayerNorm applies its affine params over all of C, H, W in the case where LN == GN(groups=1).
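The equivalence (and non-equivalence) stated above can be checked numerically; a small sketch with illustrative shapes, affine params disabled as noted:

```python
import torch
import torch.nn as nn

x = torch.randn(2, 8, 5, 5)  # NCHW

gn = nn.GroupNorm(num_groups=1, num_channels=8, affine=False)
ln_full = nn.LayerNorm([8, 5, 5], elementwise_affine=False)  # norms over C, H, W
ln_chan = nn.LayerNorm(8, elementwise_affine=False)          # channel-only

# GroupNorm(groups=1) matches LayerNorm over all of C, H, W ...
assert torch.allclose(gn(x), ln_full(x), atol=1e-4)

# ... but not channel-only LN applied via permute to NHWC and back.
y_chan = ln_chan(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
assert not torch.allclose(gn(x), y_chan, atol=1e-2)
```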
Thanks for the suggestions. We will re-train MobileViTv2 with ConvNeXt-style LayerNorm and also rename LayerNorm2D as group norm (to be consistent with other works and implementations).
@sacmehta the equivalence between GN and LN as per the paper holds for NCHW tensors when LN is performed over all of C, H, W (minus the affine part, as mentioned). However, the LN in transformers, including when used with 2D NCHW tensors, is usually over just C.
There is nothing at all wrong with what you've implemented, and it may well be better, but calling it LN is a bit confusing given other uses and the difference in how the affine params are applied. PoolFormer uses the same thing you do, but theirs is just called GroupNorm (with groups forced to 1); I called it GroupNorm1 when I used it (not sure that makes it any clearer, heh).
There would be a few fewer FLOPs in the LN-over-C-only case, but unfortunately, with no efficient implementation in PyTorch, the permutes required can slow things down a bit. In either case I'd be curious to see whether the accuracy changes.
@rwightman Thanks for the feedback. I will change the name of files for clarity.
I will keep you posted about LN experiments.
@sacmehta thanks for the update, looks like the channels-only LN is definitely not stable in this architecture.
I tried to implement MobileViTv2 with TensorFlow 2.x, using layers.LayerNormalization(epsilon=1e-6) for the layernorm. Comparing the output of each layer, I found it is inconsistent with the output of the PyTorch version's layernorm.
https://www.tensorflow.org/api_docs/python/tf/keras/layers/LayerNormalization
Later, I switched to tfa.layers.GroupNormalization (TensorFlow Addons) with groups=1 and checked the output of each layer, which is consistent with the PyTorch version's layernorm.
https://www.tensorflow.org/addons/api_docs/python/tfa/layers/GroupNormalization
I checked the transformer code implemented by the Keras team; they use layers.LayerNormalization(epsilon=1e-6).
https://keras.io/examples/vision/image_classification_with_vision_transformer/
With my reimplemented TensorFlow 2.x version of MobileViTv2, the checked outputs are consistent with the PyTorch version, but during transfer learning the loss does not decrease (batch size = 64).
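The layer-output mismatch described above is expected: tf.keras LayerNormalization with its default axis=-1 on an NHWC tensor normalizes over channels only, while GroupNorm with groups=1 normalizes jointly over H, W, and C. A NumPy sketch of the differing statistics (shapes are illustrative, affine params omitted):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 5, 5, 8))  # NHWC, as in TensorFlow
eps = 1e-6

# tf.keras LayerNormalization default: stats over the last axis (channels).
mu_c = x.mean(axis=-1, keepdims=True)
var_c = x.var(axis=-1, keepdims=True)
ln_last = (x - mu_c) / np.sqrt(var_c + eps)

# GroupNorm with groups=1: stats over H, W, C jointly, per sample.
mu_g = x.mean(axis=(1, 2, 3), keepdims=True)
var_g = x.var(axis=(1, 2, 3), keepdims=True)
gn_one = (x - mu_g) / np.sqrt(var_g + eps)

# Different statistics, so the outputs differ.
assert not np.allclose(ln_last, gn_one, atol=1e-2)
```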