barlow-twins-hsic's Issues

Is the feature normalization necessary?

return F.normalize(feature, dim=-1), F.normalize(out, dim=-1)

Hi, your code is very helpful, and I first want to thank you for sharing it.

I have a question about whether this feature normalization is necessary (to bring the CIFAR-10 performance to about 92% accuracy).

The original Barlow Twins does not contain this step; instead, it defines all linear layers in the projector with no bias.
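
For reference, here is my understanding of the original projector (a minimal sketch with illustrative CIFAR-scale sizes; the upstream ImageNet code uses much wider 8192-dimensional layers):

import torch.nn as nn

# Sketch of a Barlow Twins-style projector: every linear layer is
# bias-free, and the BatchNorm1d layers take the place of an explicit
# output normalization. Layer sizes here are illustrative.
projector = nn.Sequential(
    nn.Linear(512, 2048, bias=False),
    nn.BatchNorm1d(2048),
    nn.ReLU(inplace=True),
    nn.Linear(2048, 2048, bias=False),
    nn.BatchNorm1d(2048),
    nn.ReLU(inplace=True),
    nn.Linear(2048, 128, bias=False),
)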

Question re: reproducing Fig 2 from the paper

Hello --

I'm interested in trying to reproduce the Barlow Twins curve from Fig 2 in the paper.

I'm running:

python main.py --lmbda 0.0078125 --corr_zero --batch_size 128 --feature_dim 128 --dataset cifar10

and getting:

Test Epoch: [5/1000] Acc@1:47.33% Acc@5:92.49%
Test Epoch: [10/1000] Acc@1:53.80% Acc@5:94.87%
Test Epoch: [15/1000] Acc@1:58.44% Acc@5:96.68%
Test Epoch: [20/1000] Acc@1:63.11% Acc@5:96.86%
Test Epoch: [25/1000] Acc@1:65.55% Acc@5:97.33%
Test Epoch: [30/1000] Acc@1:66.59% Acc@5:97.61%
Test Epoch: [35/1000] Acc@1:68.85% Acc@5:97.87%
Test Epoch: [40/1000] Acc@1:69.17% Acc@5:97.75%
Test Epoch: [45/1000] Acc@1:71.24% Acc@5:98.16%
Test Epoch: [50/1000] Acc@1:72.38% Acc@5:98.26%

In Fig 2, it looks like accuracy after 50 epochs should be ~79%, but I'm only getting to ~72%.

Any ideas why there might be a gap? Perhaps the accuracies reported in Fig 2 come from training a linear classifier (e.g., in linear.py) rather than from the weighted kNN in main.py:train?
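
If it helps locate the gap, my understanding of the weighted kNN monitor in main.py is sketched below (a reconstruction of the common SimCLR-style kNN evaluation; the function name and defaults are my assumptions, not necessarily this repo's exact code):

import torch

def knn_predict(feature, feature_bank, feature_labels, n_classes, k=200, t=0.5):
    # feature: [B, D] L2-normalized test features
    # feature_bank: [D, N] L2-normalized features of the training set
    # feature_labels: [N] integer labels of the training set
    sim = torch.mm(feature, feature_bank)        # [B, N] cosine similarities
    sim_weight, sim_idx = sim.topk(k=k, dim=-1)  # k nearest neighbors per sample
    sim_labels = torch.gather(
        feature_labels.expand(feature.size(0), -1), dim=-1, index=sim_idx
    )                                            # [B, k] neighbor labels
    sim_weight = (sim_weight / t).exp()          # temperature-scaled vote weights
    one_hot = torch.zeros(feature.size(0) * k, n_classes, device=feature.device)
    one_hot.scatter_(dim=-1, index=sim_labels.view(-1, 1), value=1.0)
    scores = (one_hot.view(feature.size(0), k, n_classes)
              * sim_weight.unsqueeze(dim=-1)).sum(dim=1)  # [B, n_classes]
    return scores.argsort(dim=-1, descending=True)  # classes ranked for Acc@1/Acc@5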

Thanks!

Speed Up Training by Using cudnn benchmark

Not really an issue, but adding the following made training significantly faster (+29% on a Titan X Pascal):

if torch.cuda.is_available():
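    # let cuDNN benchmark the available convolution algorithms and cache
    # the fastest one; this helps when input sizes are fixed across batches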
    torch.backends.cudnn.benchmark = True

I thought I would mention it as others might benefit from this as well.

Thanks for making your code available on GitHub!

Inaccuracy in the cross-correlation for small batch sizes (<32)

Hey,

first, thanks for sharing your research and code!

TL;DR

Your code uses torch.std, which applies Bessel's correction by default, preventing the values on the diagonal from reaching 1.

While working with it, I noticed a small inaccuracy in the calculation of the cross-correlation matrix.

Unlike the original implementation, which uses BatchNorm1d, you implemented the normalization as:

# normalize the representations along the batch dimension
out_1_norm = (out_1 - out_1.mean(dim=0)) / out_1.std(dim=0)
out_2_norm = (out_2 - out_2.mean(dim=0)) / out_2.std(dim=0)

I implemented a small test with two identical vectors coming from the projection head and was therefore expecting straight ones on the diagonal. But as you can see from my attached code, for a batch size < 32 (here 4), the values on the diagonal can't get bigger than 0.75. I found that torch.std uses Bessel's correction by default: each normalized column then has squared norm n-1 rather than n, so the diagonal saturates at (n-1)/n, i.e. 0.75 for n = 4. When the unbiased flag is set to False, the numbers match the original implementation.

I think there is no practical difference for batch sizes ≥ 32, which, if I remember correctly, is also the smallest batch size presented in the paper.

import torch
from torch import nn

batch_size = 4
size_z = 128

torch.manual_seed(1234)
z1 = torch.randn(batch_size, size_z)
z2 = z1.clone()

# your implementation
z1_norm = (z1 - torch.mean(z1, dim=0)) / torch.std(z1, dim=0)
z2_norm = (z2 - torch.mean(z2, dim=0)) / torch.std(z2, dim=0)
cross_corr = torch.matmul(z1_norm.T, z2_norm) / batch_size

print(cross_corr[:5, :5])

# tensor([[ 0.7500, -0.1065, -0.0837,  0.0630,  0.3664],
#        [-0.1065,  0.7500, -0.2283, -0.3708, -0.5607],
#        [-0.0837, -0.2283,  0.7500, -0.5013, -0.2554],
#        [ 0.0630, -0.3708, -0.5013,  0.7500,  0.6334],
#        [ 0.3664, -0.5607, -0.2554,  0.6334,  0.7500]])

# original implementation
bn = nn.BatchNorm1d(size_z, affine=False)
z1_norm = bn(z1)
z2_norm = bn(z2)
cross_corr = z1_norm.T @ z2_norm / batch_size

print(cross_corr[:5, :5])

# tensor([[ 1.0000, -0.1420, -0.1116,  0.0840,  0.4885],
#        [-0.1420,  1.0000, -0.3043, -0.4944, -0.7476],
#        [-0.1116, -0.3043,  1.0000, -0.6683, -0.3405],
#        [ 0.0840, -0.4944, -0.6683,  1.0000,  0.8445],
#        [ 0.4885, -0.7476, -0.3405,  0.8445,  1.0000]])

# corrected code (without Bessel’s correction for the calculation of the standard deviation)
z1_norm = (z1 - torch.mean(z1, dim=0)) / torch.std(z1, dim=0, unbiased=False)
z2_norm = (z2 - torch.mean(z2, dim=0)) / torch.std(z2, dim=0, unbiased=False)
cross_corr = torch.matmul(z1_norm.T, z2_norm) / batch_size

print(cross_corr[:5, :5])

# tensor([[ 1.0000, -0.1421, -0.1116,  0.0840,  0.4885],
#         [-0.1421,  1.0000, -0.3043, -0.4944, -0.7476],
#         [-0.1116, -0.3043,  1.0000, -0.6683, -0.3405],
#         [ 0.0840, -0.4944, -0.6683,  1.0000,  0.8446],
#         [ 0.4885, -0.7476, -0.3405,  0.8446,  1.0000]])

Results on Tiny ImageNet

Hi, have you run the model on Tiny ImageNet? If so, could you share the results? I have tried running the model on Tiny ImageNet myself, but the accuracy seems too low.

Question Regarding Transform

The implementation is simple and easy to use; thank you for that. I have one doubt.

Given a mini-batch with input x of size BxCxHxW, we apply transformations to get

y1 = self.transform(x)
y2 = self.transform(x)

Is this a batch-wise transformation or an image-wise transformation?

The paper says: "More specifically, it produces two distorted views for all images of a batch X sampled from a dataset." Since there are only two distorted views, I interpret this as: for each distorted view, the same transformation is applied to all images in the batch. See the sketch below for the contrasting image-wise interpretation.
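
For concreteness, here is the image-wise interpretation I would contrast with the batch-wise one (a hypothetical CIFAR10Pair-style dataset in the style common to these repos, not necessarily your exact code):

from PIL import Image
from torchvision.datasets import CIFAR10

# Hypothetical pair dataset: the transform is called once per image and
# per view, so the random parameters (crop, flip, color jitter) are
# re-sampled independently for every image in the batch and for each of
# the two views.
class CIFAR10Pair(CIFAR10):
    def __getitem__(self, index):
        img = Image.fromarray(self.data[index])
        y1 = self.transform(img)  # view 1: its own random parameters
        y2 = self.transform(img)  # view 2: its own random parameters
        return y1, y2, self.targets[index]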
