cirl's Introduction

CIRL

This repo provides a demo for the CVPR 2022 paper "Causality Inspired Representation Learning for Domain Generalization" on the PACS dataset.

Requirements

  • Python 3.6
  • PyTorch 1.1.0

Training from scratch

Please first download the PACS dataset from http://www.eecs.qmul.ac.uk/~dl307/project_iccv2017 or from https://pan.baidu.com/s/1KxMA6SiQX1jdRxwkeKMqOw (password: pacs). Then update the files with the suffixes _train.txt and _val.txt in data/datalists for each domain so that each line holds an image path on your machine followed by its class label, in the format below:

/home/user/data/images/PACS/kfold/art_painting/dog/pic_001.jpg 0
/home/user/data/images/PACS/kfold/art_painting/dog/pic_002.jpg 0
/home/user/data/images/PACS/kfold/art_painting/dog/pic_003.jpg 0
...
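
If the datalist files were written for another machine, the path prefix can be rewritten with a script instead of by hand, which leaves the split itself untouched. The following is a minimal sketch, not part of this repo: the function names are hypothetical, and it assumes every line contains the marker PACS/kfold/ as in the example above.

```python
def rewrite_datalist_line(line, new_root, marker="PACS/kfold/"):
    """Point one '<image path> <label>' line at a new dataset root.

    Everything before `marker` is replaced by `new_root`; the rest of the
    path and the label are kept unchanged. `marker` is an assumption taken
    from the example paths in this README.
    """
    path, label = line.strip().rsplit(" ", 1)
    idx = path.find(marker)
    if idx == -1:
        raise ValueError("marker %r not found in %r" % (marker, path))
    return "%s/%s %s" % (new_root.rstrip("/"), path[idx:], label)


def rewrite_datalist_file(list_path, new_root, marker="PACS/kfold/"):
    """Rewrite every non-empty line of a datalist file in place."""
    with open(list_path) as f:
        lines = [ln for ln in f if ln.strip()]
    with open(list_path, "w") as f:
        for ln in lines:
            f.write(rewrite_datalist_line(ln, new_root, marker) + "\n")
```

For example, calling rewrite_datalist_file on one of the _train.txt files with new_root="/mnt/datasets" (a made-up location) would turn the first line above into /mnt/datasets/PACS/kfold/art_painting/dog/pic_001.jpg 0.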

Please make sure you are using the official train/val split. Once the data is prepared, update the paths to the train/val files and the output logs in shell_train.py:

input_dir = 'path/to/train/files'
output_dir = 'path/to/output/logs'

Then run the code:

python shell_train.py -d=art_painting

Use the argument -d to specify the held-out target domain.

Evaluation

After training the model, create a directory ckpt/ and place your trained model in it. To run the evaluation code, update the files with the suffix _test.txt in data/datalists for each domain, following the same format as the train/val files above.

Then update the path of test files and output logs in shell_test.py:

input_dir = 'path/to/test/files'
output_dir = 'path/to/output/logs'

Then simply run:

python shell_test.py -d=art_painting

You can use the argument -d to specify the held-out target domain.

Acknowledgements

Some of the code is adapted from FACT. We thank the authors for their excellent project.

Contact

If you have any problems with our code, feel free to contact [email protected] or describe your problem in Issues.


cirl's Issues

On the issue of model selection

CIRL/train.py

Line 192 in 1b50cd6

if self.results['test'][self.current_epoch] >= self.best_acc:

Your training code selects the best model based on accuracy on the test set rather than on the best accuracy on the validation set.
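
For comparison, here is a minimal sketch of validation-based selection. It assumes the per-epoch validation and test accuracies are available as two lists; the function name and arguments are hypothetical and not taken from train.py. Ties go to the later epoch, which mirrors the `>=` in the line quoted above.

```python
def select_best_epoch(val_acc, test_acc):
    """Pick the epoch with the highest validation accuracy.

    Returns (best_epoch, test accuracy at that epoch). The test accuracy
    is only reported, never used for the choice. Ties on validation
    accuracy are broken in favour of the later epoch.
    """
    if len(val_acc) != len(test_acc) or not val_acc:
        raise ValueError("need two non-empty lists of equal length")
    best_epoch = max(range(len(val_acc)), key=lambda e: (val_acc[e], e))
    return best_epoch, test_acc[best_epoch]


if __name__ == "__main__":
    val = [0.90, 0.95, 0.95, 0.93]   # made-up numbers for illustration
    test = [0.80, 0.82, 0.81, 0.85]
    print(select_best_epoch(val, test))
```

With these made-up numbers the selected epoch is 2 and the reported test accuracy is 0.81, even though epoch 3 has a higher test accuracy.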

What kind of error is this?

res = func(*args, **kwrds)
File "mkl_fft_pydfti.pyx", line 1105, in mkl_fft._pydfti.fftn
File "mkl_fft_pydfti.pyx", line 1083, in mkl_fft._pydfti._fftnd_impl
File "mkl_fft_pydfti.pyx", line 952, in mkl_fft._pydfti.iter_complementary
File "mkl_fft_pydfti.pyx", line 985, in mkl_fft._pydfti._direct_fftnd
AssertionError

Reproduced experiment results

Is there anyone who has reproduced the reported results? Does the code need to be modified?

I only get 48% accuracy with leave-one-domain-out (photo as the target domain), the official PACS split, and the default ResNet50 parameters.

Test result: [image]

Part of the training log: [image]

About the Fourier transform of causality intervention module

Hi author, thank you for your excellent work.
I have a question about the paper and the code. The Causal Intervention Module in the paper says that $x^a$ is obtained using a Fourier transform, but in your code the data transformation in the function below does not seem to be a Fourier transform. I am not an expert in the image field, so I have some doubts about this. In addition, I could not find the sampling λ ~ U(0, η) from the paper in this data transformation either.

import numpy as np
from torchvision import transforms

def get_pre_transform(image_size=224, crop=False, jitter=0):
    # Spatial step: random resized crop, or a plain resize to a square image.
    if crop:
        img_transform = [transforms.RandomResizedCrop(image_size, scale=[0.8, 1.0])]
    else:
        img_transform = [transforms.Resize((image_size, image_size))]
    # Optional colour jitter; hue is capped at 0.5.
    if jitter > 0:
        img_transform.append(transforms.ColorJitter(brightness=jitter,
                                                    contrast=jitter,
                                                    saturation=jitter,
                                                    hue=min(0.5, jitter)))
    # Random flip, then convert the PIL image to a NumPy array.
    img_transform += [transforms.RandomHorizontalFlip(), lambda x: np.asarray(x)]
    img_transform = transforms.Compose(img_transform)
    return img_transform

Looking forward to your reply.
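
The function quoted above only resizes or crops, jitters, flips, and converts the image to an array, so any Fourier step would have to run on its output. As a point of reference, here is a minimal sketch of amplitude mixing with λ ~ U(0, η) as I understand it from the paper's description. It is not the repository's code: the function name, the arguments, and the choice to mix the full spectrum are all assumptions.

```python
import numpy as np


def amplitude_mix(img_a, img_b, eta=1.0, rng=None):
    """Mix the Fourier amplitude of img_b into img_a, keeping img_a's phase.

    img_a, img_b: float arrays of identical shape (H, W) or (H, W, C).
    eta: upper bound of the mixing weight, lam ~ U(0, eta).
    Returns (augmented image, lam). Clipping to a valid pixel range is
    left to the caller.
    """
    rng = np.random.default_rng() if rng is None else rng
    lam = rng.uniform(0.0, eta)

    # 2-D FFT over the two spatial axes (per channel if a third axis exists).
    fft_a = np.fft.fft2(img_a, axes=(0, 1))
    fft_b = np.fft.fft2(img_b, axes=(0, 1))

    amp_a, pha_a = np.abs(fft_a), np.angle(fft_a)
    amp_b = np.abs(fft_b)

    # Linear interpolation of the amplitudes; the phase is left untouched.
    amp_mix = (1.0 - lam) * amp_a + lam * amp_b

    # Recombine the mixed amplitude with the original phase and invert.
    mixed = np.fft.ifft2(amp_mix * np.exp(1j * pha_a), axes=(0, 1))
    return np.real(mixed), lam
```

With eta=0 the weight is always 0 and the input image comes back unchanged, which is a quick sanity check for the sketch.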

Unresolved reference 'get_copy_dataset'

Thanks for your excellent work. I found a problem: the get_copy_dataset() function seems to be undefined in DGDataLoader.py (Unresolved reference 'get_copy_dataset'). Looking forward to your response.
