
quanlin-wu / dmae


Denoising Masked Autoencoders Help Robust Classification.

Home Page: https://arxiv.org/abs/2210.06983

License: Other

Python 100.00%
pre-training robustness-certification self-supervised-learning transformer

dmae's Issues

certify error

Hi,

I tried to reproduce the certification results on CIFAR-10 and got the following error:

Traceback (most recent call last):
  File "certify_cifar10.py", line 204, in <module>
    main(args)
  File "certify_cifar10.py", line 185, in main
    test_stats = certify_evaluate_dist(data_loader_val, smoothed_classifier, device, threshold, args.num)
  File "/home/yuchongy1/.local/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "/data/gpfs/projects/punim1723/dmae/engine_finetune.py", line 298, in certify_evaluate_dist
    output, radius = model.certify(images, 100, num, 0.001, 1000, target.item())
  File "/data/gpfs/projects/punim1723/dmae/util/smooth.py", line 40, in certify
    counts_selection = self._sample_noise(x, n0, batch_size)
  File "/data/gpfs/projects/punim1723/dmae/util/smooth.py", line 96, in _sample_noise
    predictions = self.base_classifier(noisy).argmax(1)
RuntimeError: Given normalized_shape=[768], expected input with shape [*, 768], but got input of size[100]

I followed the instructions in FINETUNE.MD.

I printed the shape of "noisy", which is [100, 3, 224, 224].

Could you please help?
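The error message can be read directly from the shapes reported above. The layer norm expects a trailing dimension of 768, but the tensor it received has size [100], which matches the batch dimension of "noisy". So the feature axis appears to have been lost somewhere between the input and the norm layer. The sketch below tracks shapes only, in plain Python. The token shape (100, 197, 768) and the idea that pooling is applied twice are assumptions for illustration and are not confirmed in this thread; the helper names are hypothetical.

```python
def layer_norm_accepts(input_shape, normalized_shape):
    """Mirror LayerNorm's shape rule: the trailing dims must equal normalized_shape."""
    input_shape = tuple(input_shape)
    normalized_shape = tuple(normalized_shape)
    n = len(normalized_shape)
    return len(input_shape) >= n and input_shape[-n:] == normalized_shape


def mean_over_axis(shape, axis):
    """Shape that results from averaging a tensor of `shape` over `axis`."""
    shape = tuple(shape)
    axis = axis % len(shape)
    return shape[:axis] + shape[axis + 1:]


# Assumed token tensor for ViT-B/16 at 224x224: 196 patches + 1 class token.
tokens = (100, 197, 768)
pooled = mean_over_axis(tokens, 1)        # (100, 768): averaged over tokens
pooled_twice = mean_over_axis(pooled, 1)  # (100,): the feature axis is gone

print(layer_norm_accepts(pooled, (768,)))        # True
print(layer_norm_accepts(pooled_twice, (768,)))  # False
```

If this is what happens, printing the shape of the tensor right before the failing norm layer, in addition to the shape of "noisy", should show where the 768-dimensional axis disappears.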

cannot download checkpoint, and cannot load pre-trained DMAE model properly

I tried running the checkpoint download step:

# download checkpoint if not exist
!wget -nc https://tlgw4g.dm.files.1drv.com/y4meOrodqzrNG2JjHw4cfoF8nZoSvJM9g7hk8G3q58--mBgbWyCvuWP8x91Z6dJxmd4MZBpM4UoX7tNqPr0XYmHKMtXX37ctxAQObsVJ298ldzyY5wS5T3DiliR2T-gSr4XVbG6w76nGSCG6PAws_y6hYLLtaaZlx9QrezOQonTvR2RagiUYt5GyoCkq7JuGyF0T2e7X7HlkfU_47M8gNpCGw

but it failed:
Screenshot from 2024-02-08 16-38-10

Then I downloaded the pre-trained checkpoint manually and changed the path, but loading it raised an error. What should I do to fix this?
Screenshot from 2024-02-08 16-40-53

Checkpoint Loading

The CIFAR finetuning code (finetune_cifar10.py) does not line up with the released checkpoint (dmae_base_sigma_0.25_mask_0.75_1100e.pth).

Firstly, using the provided

checkpoint = torch.load(args.finetune, map_location='cpu')
print("Load pre-trained checkpoint from: %s" % args.resume)
checkpoint_model = checkpoint['model']

raises an error because the key 'model' is not found. The checkpoint is not wrapped under a 'model' key; it is itself a dictionary mapping layer parameter names to tensors.
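A loader that tolerates both layouts avoids editing the script for each checkpoint. The following is a minimal sketch, not the repository's code: unwrap_state_dict is a hypothetical helper, and it assumes the only wrapper key that matters is 'model', as in the snippet above.

```python
from collections.abc import Mapping


def unwrap_state_dict(checkpoint, wrapper_key='model'):
    """Return checkpoint[wrapper_key] if it holds a mapping, else the checkpoint itself."""
    if not isinstance(checkpoint, Mapping):
        raise TypeError("expected a dict-like checkpoint, got %s" % type(checkpoint).__name__)
    inner = checkpoint.get(wrapper_key)
    if isinstance(inner, Mapping):
        return inner   # layout: {'model': {name: tensor, ...}, ...}
    return checkpoint  # layout: {name: tensor, ...}


# Placeholder values stand in for tensors so the example runs without PyTorch.
wrapped = {'model': {'mask_token': 0, 'norm.weight': 0}}
bare = {'mask_token': 0, 'norm.weight': 0}

print(sorted(unwrap_state_dict(wrapped)))  # ['mask_token', 'norm.weight']
print(sorted(unwrap_state_dict(bare)))     # ['mask_token', 'norm.weight']
```

In the fine-tuning script this would be used as checkpoint_model = unwrap_state_dict(checkpoint), applied to the result of torch.load.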

Secondly, changing the code to:

checkpoint_model = checkpoint

results in a key mismatch when loading:

_IncompatibleKeys(missing_keys=['fc_norm.weight', 'fc_norm.bias', 'head.weight', 'head.bias'], unexpected_keys=['mask_token', 'decoder_pos_embed', 'norm.weight', 'norm.bias', 'decoder_embed.weight', 'decoder_embed.bias', 'decoder_blocks.0.norm1.weight', 'decoder_blocks.0.norm1.bias', 'decoder_blocks.0.attn.qkv.weight', 'decoder_blocks.0.attn.qkv.bias', 'decoder_blocks.0.attn.proj.weight', 'decoder_blocks.0.attn.proj.bias', 'decoder_blocks.0.norm2.weight', 'decoder_blocks.0.norm2.bias', 'decoder_blocks.0.mlp.fc1.weight', 'decoder_blocks.0.mlp.fc1.bias', 'decoder_blocks.0.mlp.fc2.weight', 'decoder_blocks.0.mlp.fc2.bias', 'decoder_blocks.1.norm1.weight', 'decoder_blocks.1.norm1.bias', 'decoder_blocks.1.attn.qkv.weight', 'decoder_blocks.1.attn.qkv.bias', 'decoder_blocks.1.attn.proj.weight', 'decoder_blocks.1.attn.proj.bias', 'decoder_blocks.1.norm2.weight', 'decoder_blocks.1.norm2.bias', 'decoder_blocks.1.mlp.fc1.weight', 'decoder_blocks.1.mlp.fc1.bias', 'decoder_blocks.1.mlp.fc2.weight', 'decoder_blocks.1.mlp.fc2.bias', 'decoder_blocks.2.norm1.weight', 'decoder_blocks.2.norm1.bias', 'decoder_blocks.2.attn.qkv.weight', 'decoder_blocks.2.attn.qkv.bias', 'decoder_blocks.2.attn.proj.weight', 'decoder_blocks.2.attn.proj.bias', 'decoder_blocks.2.norm2.weight', 'decoder_blocks.2.norm2.bias', 'decoder_blocks.2.mlp.fc1.weight', 'decoder_blocks.2.mlp.fc1.bias', 'decoder_blocks.2.mlp.fc2.weight', 'decoder_blocks.2.mlp.fc2.bias', 'decoder_blocks.3.norm1.weight', 'decoder_blocks.3.norm1.bias', 'decoder_blocks.3.attn.qkv.weight', 'decoder_blocks.3.attn.qkv.bias', 'decoder_blocks.3.attn.proj.weight', 'decoder_blocks.3.attn.proj.bias', 'decoder_blocks.3.norm2.weight', 'decoder_blocks.3.norm2.bias', 'decoder_blocks.3.mlp.fc1.weight', 'decoder_blocks.3.mlp.fc1.bias', 'decoder_blocks.3.mlp.fc2.weight', 'decoder_blocks.3.mlp.fc2.bias', 'decoder_blocks.4.norm1.weight', 'decoder_blocks.4.norm1.bias', 'decoder_blocks.4.attn.qkv.weight', 'decoder_blocks.4.attn.qkv.bias', 
'decoder_blocks.4.attn.proj.weight', 'decoder_blocks.4.attn.proj.bias', 'decoder_blocks.4.norm2.weight', 'decoder_blocks.4.norm2.bias', 'decoder_blocks.4.mlp.fc1.weight', 'decoder_blocks.4.mlp.fc1.bias', 'decoder_blocks.4.mlp.fc2.weight', 'decoder_blocks.4.mlp.fc2.bias', 'decoder_blocks.5.norm1.weight', 'decoder_blocks.5.norm1.bias', 'decoder_blocks.5.attn.qkv.weight', 'decoder_blocks.5.attn.qkv.bias', 'decoder_blocks.5.attn.proj.weight', 'decoder_blocks.5.attn.proj.bias', 'decoder_blocks.5.norm2.weight', 'decoder_blocks.5.norm2.bias', 'decoder_blocks.5.mlp.fc1.weight', 'decoder_blocks.5.mlp.fc1.bias', 'decoder_blocks.5.mlp.fc2.weight', 'decoder_blocks.5.mlp.fc2.bias', 'decoder_blocks.6.norm1.weight', 'decoder_blocks.6.norm1.bias', 'decoder_blocks.6.attn.qkv.weight', 'decoder_blocks.6.attn.qkv.bias', 'decoder_blocks.6.attn.proj.weight', 'decoder_blocks.6.attn.proj.bias', 'decoder_blocks.6.norm2.weight', 'decoder_blocks.6.norm2.bias', 'decoder_blocks.6.mlp.fc1.weight', 'decoder_blocks.6.mlp.fc1.bias', 'decoder_blocks.6.mlp.fc2.weight', 'decoder_blocks.6.mlp.fc2.bias', 'decoder_blocks.7.norm1.weight', 'decoder_blocks.7.norm1.bias', 'decoder_blocks.7.attn.qkv.weight', 'decoder_blocks.7.attn.qkv.bias', 'decoder_blocks.7.attn.proj.weight', 'decoder_blocks.7.attn.proj.bias', 'decoder_blocks.7.norm2.weight', 'decoder_blocks.7.norm2.bias', 'decoder_blocks.7.mlp.fc1.weight', 'decoder_blocks.7.mlp.fc1.bias', 'decoder_blocks.7.mlp.fc2.weight', 'decoder_blocks.7.mlp.fc2.bias', 'decoder_norm.weight', 'decoder_norm.bias', 'decoder_pred.weight', 'decoder_pred.bias'])

This is using --model 'vit_base_patch16' as instructed.
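Most of the unexpected keys listed above belong to the pre-training decoder (names starting with 'decoder_', plus 'mask_token'), which a classification model does not have. One way to make the mismatch easier to read is to drop those entries before loading. This is a sketch under assumptions, not the authors' method: split_encoder_decoder is a hypothetical helper, the prefixes are taken from the key list above, and 'patch_embed.proj.weight' is a made-up encoder key used only for the demo.

```python
# Prefixes taken from the unexpected_keys list in this issue.
DECODER_PREFIXES = ('decoder_', 'mask_token')


def split_encoder_decoder(state_dict):
    """Return (entries to keep, names dropped as decoder-only)."""
    encoder, dropped = {}, []
    for name, value in state_dict.items():
        if name.startswith(DECODER_PREFIXES):
            dropped.append(name)
        else:
            encoder[name] = value
    return encoder, dropped


# Placeholder values stand in for tensors so the example runs without PyTorch.
demo = {
    'patch_embed.proj.weight': 0,  # hypothetical encoder key
    'norm.weight': 0,
    'mask_token': 0,
    'decoder_pos_embed': 0,
    'decoder_blocks.0.attn.qkv.weight': 0,
    'decoder_pred.bias': 0,
}
encoder, dropped = split_encoder_decoder(demo)
print(sorted(encoder))  # ['norm.weight', 'patch_embed.proj.weight']
print(len(dropped))     # 4
```

After this filtering, only 'norm.weight' and 'norm.bias' would remain unexpected. The missing 'fc_norm' and 'head' entries are layers of the classifier that a pre-training checkpoint cannot contain; whether the script is meant to initialize them fresh is an assumption to confirm against the repository.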
