
kaiyangzhou / dassl.pytorch


A PyTorch toolbox for domain generalization, domain adaptation and semi-supervised learning.

License: MIT License

Python 99.76% Shell 0.24%
artificial-intelligence benchmark-datasets computer-vision deep-learning deep-neural-networks domain-adaptation domain-generalization machine-learning pytorch semi-supervised-learning

dassl.pytorch's People

Contributors

kaiyangzhou, sansyo, siaimes, wyf0912, ybzh


dassl.pytorch's Issues

ADDA on miniDomainNet

Hi,
I tried ADDA on miniDomainNet, but the accuracy on the target test set decreased from 0.29 to 0.01. Previously I tried it on office31 and it worked fine. Do you know what is wrong?
I used the following commands:

python tools/train.py \
    --root $DATA \
    --trainer SourceOnly \
    --source-domains real \
    --target-domains sketch \
    --dataset-config-file configs/datasets/da/mini_domainnet.yaml \
    --config-file configs/trainers/da/source_only/mini_domainnet.yaml \
    --output-dir output/source_only_minidomainnet

python train.py \
    --trainer ADDA \
    --source-domains real \
    --target-domains sketch \
    --dataset-config-file configs/datasets/da/mini_domainnet.yaml \
    --config-file configs/trainers/da/source_only/mini_domainnet.yaml \
    --output-dir output/adda_minidomainnet \
    --init-weights output/source_only_minidomainnet/model/model.pth.tar-60

about DAEL

Hi, thanks for sharing your code.

I want to use your DAEL model, but can DAEL support the setting where multiple source domains have different categories (such as the category-shift problem addressed in "Deep Cocktail Network: Multi-source Unsupervised Domain Adaptation with Category Shift")?

Thank you very much!

How to run on multiple GPUs?

I used 4 GPUs to train DDAIG, but got a "CUDA out of memory" error.
It seems the code only supports single-GPU training.

Set different learning rates

Thank you for your great work!

I have a question I would like to ask.
Is it possible to give two different parts of one model different learning rates?
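
For reference, a minimal PyTorch sketch of per-part learning rates via optimizer parameter groups (the module names backbone and classifier below are placeholders, not Dassl's):

    import torch
    import torch.nn as nn

    # Hypothetical two-part model; swap in your own modules.
    model = nn.Sequential()
    model.add_module("backbone", nn.Linear(128, 64))
    model.add_module("classifier", nn.Linear(64, 10))

    # Each dict is a parameter group with its own learning rate.
    optimizer = torch.optim.SGD(
        [
            {"params": model.backbone.parameters(), "lr": 1e-4},    # slower part
            {"params": model.classifier.parameters(), "lr": 1e-3},  # faster part
        ],
        momentum=0.9,
    )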

The implementation of the cross-entropy loss in domain adaptive ensemble learning

Hi,

Thanks for your code; this is great work.

I have read the paper, and as described in Section 3, the loss function for domain-specific expert learning is a cross-entropy loss.

I guess the implementation of the loss of domain adaptive ensemble learning is the following:
https://github.com/KaiyangZhou/Dassl.pytorch/blob/master/dassl/engine/dg/daeldg.py#L109

Is this a standard cross-entropy loss function? Why not use nn.CrossEntropyLoss()?

And why take the mean when computing the cross entropy of each data sample?
(-label_i * torch.log(pred_i + 1e-5)).sum(1).mean()
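
For context, a small sanity-check sketch (not the repo's code): .sum(1) gives the per-sample cross entropy over classes and .mean() averages it over the batch; for one-hot labels this matches nn.CrossEntropyLoss up to the 1e-5 epsilon.

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    logits = torch.randn(4, 5)             # 4 samples, 5 classes
    labels = torch.randint(0, 5, (4,))     # hard labels
    one_hot = F.one_hot(labels, 5).float()

    # Manual soft-label cross entropy, as in the quoted expression
    # (probabilities instead of logits, small epsilon for stability).
    pred = F.softmax(logits, dim=1)
    manual = (-one_hot * torch.log(pred + 1e-5)).sum(1).mean()

    # Built-in version; operates on raw logits.
    builtin = F.cross_entropy(logits, labels)

    print(manual.item(), builtin.item())   # nearly identical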

confusion matrix

def reset(self):
    self._correct = 0
    self._total = 0
    if self._per_class_res is not None:
        self._per_class_res = defaultdict(list)

I think we should also reset self._y_true and self._y_pred in the reset function.
Otherwise we'll get a wrong confusion matrix.
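
A minimal sketch of the suggested fix (assuming the evaluator stores the labels and predictions in plain Python lists):

    from collections import defaultdict  # at module level, as in the original file

    def reset(self):
        self._correct = 0
        self._total = 0
        self._y_true = []   # suggested addition: clear stored labels
        self._y_pred = []   # suggested addition: clear stored predictions
        if self._per_class_res is not None:
            self._per_class_res = defaultdict(list)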

How to run benchmarks

Hello Kaiyang,
thank you for sharing the code.

Is there any guide on running benchmarks such as MME, MCD, etc.?
Many thanks for your reply.
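
For what it's worth, a hypothetical invocation for MCD following the same flag pattern as the other commands in this tracker; the exact YAML paths under configs/trainers/da/ are assumptions and may differ in the repo:

    python tools/train.py \
        --root $DATA \
        --trainer MCD \
        --source-domains amazon \
        --target-domains webcam \
        --dataset-config-file configs/datasets/da/office31.yaml \
        --config-file configs/trainers/da/mcd/office31.yaml \
        --output-dir output/mcd_office31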

Cannot reproduce Vanilla on Office-Home

Thanks for your code.

I have tried to reproduce the Vanilla model (ResNet-18) on the Office-Home dataset. I got 47.2% on Clipart, which is far from the result (49.4%) reported in your paper.

Can you share your training parameters for this task, including how you set the optimizer, network, and data preprocessing?

Low accuracy with SelfEnsembling (mnist -> mnist_m)

Hi, thanks for sharing your code.

I tried trainer=SelfEnsembling with source_domain=mnist and target_domain=mnist_m, expecting around 95% accuracy on the target-domain test subset.
However, I wasn't able to get more than 65% accuracy.

Can you please check whether I am missing any important parameters? I ran it like this:

        python tools/train.py \
            --backbone resnet18 \
            --root "datasets" \
            --trainer SelfEnsembling \
            --source-domains "mnist" \
            --target-domains "mnist_m" \
            --output-dir "$job_dir" \
            --dataset-config-file "configs/datasets/da/digit5.yaml" \
            DATALOADER.K_TRANSFORMS 2 \
            DATALOADER.TRAIN_X.BATCH_SIZE 128 \
            DATALOADER.NUM_WORKERS 10 \
            TRAINER.SE.EMA_ALPHA 0.999 \
            OPTIM.LR 0.0003 \
            OPTIM.MAX_EPOCH 200

I also did a small hyper-parameter sweep and tried different LR=(3e-3 3e-4 3e-5) and EMA_ALPHA=(0.99 0.999 0.9999), but I didn't find a combination with a better score.

Installing this repository with pip

Hi there, when installing this repository with pip install -r requirements.txt (where the requirements.txt file contains git+https://github.com/KaiyangZhou/Dassl.pytorch.git), the import numpy as np in setup.py throws a ModuleNotFoundError, because numpy is not yet installed at that moment. Is it possible to remove the import numpy as np statement and def numpy_include(): ... from setup.py so that this repository can be installed automatically with pip?
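
One common workaround (a sketch, not necessarily how the maintainers will solve it) is to defer the numpy import until the include directory is actually needed, so that pip can evaluate setup.py before numpy is installed:

    # setup.py (excerpt)
    def numpy_include():
        import numpy as np  # imported lazily, only at build time
        return np.get_include()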

Files not found for DAEL Digits5

Hello,

I replicated the dataset folder structure and tried running DAEL on digit5. However, it cannot find the mnist files and seems to be searching for the wrong files. Upon closer inspection, the load_mnist function in /data/datasets/da/digit5.py seems to expect files different from those provided via the links in DATASETS.md. Is it possible that this file is outdated with respect to the README?

Edit: I had not run the dataset creation script. My apologies.

cannot reproduce DDAIG on PACS

Thanks for sharing the code.

However, even if I use your latest config file for PACS, I still cannot reproduce the results in the paper.

I repeated the experiments 5-6 times for each domain. The results are shown in the following table:

art      sketch   photo    cartoon
--       76.91    --       72.23
79.35    74.39    93.05    73.17
79.64    74.13    94.91    74.32
82.71    71.61    92.81    73.12
80.03    75.41    94.01    72.14
82.37    75.10    94.01    72.10

Hope to get your reply.

Reproduce DDAIG

I received an email saying the current code cannot reproduce the results of DDAIG on PACS. I haven't run DDAIG using Dassl so I'm not sure if there is an issue.

I've attached the original log files, which contain the library versions, the environment settings, and the exact parameters used in the paper. Hope this helps. Please check this Google Drive link. As DDAIG was done in early 2019, at that time I was using torch=0.4.1 and numpy=1.14.5. Not sure if this will cause an issue. If there really is an issue with reproduction, it's also possible that something went wrong when I transferred DDAIG's code to this public Dassl repo (I'll double-check this).

Please note that DDAIG was named ddap in the log files. Some parameter names are different from Dassl's because the original code was an early (baby) version of Dassl, but they should be easy to understand.

I'll find time and resources to run DDAIG using this code (please bear with me).

warnings.warn('No file found at "{}"'.format(fpath))

Following the steps in the README, there are still some errors when running the demo.
The dataset path may have errors, but I have checked it carefully, and the "file not found" errors still exist.
You can see that the files do exist in "/home/wyk/dataset/office31/amazon/images/ruler".
How can I run a demo with dassl.pytorch? Looking forward to your reply!

the loss of daeldg

loss /= self.n_domain
loss_cr /= self.n_domain
acc /= self.n_domain
loss = 0
loss += loss
loss += loss_cr
self.model_backward_and_update(loss)

Hi, I think there is something wrong with the loss in daeldg.py.

Why is 'loss' assigned 0? This makes the supervised loss invalid.

The parameters of "Evaluation on Heterogeneous DG"

Hi there!

In the "Evaluation on Heterogeneous DG" part of your paper, you evaluate the approach on the cross-dataset person re-identification (re-ID) task.
Can you share your training parameters for this task, including how you set the optimizer, network, and data preprocessing?

Thank you very much!

I can't run mixmatch.py

There is only a fixmatch folder in 'Dassl.pytorch/configs/trainers/ssl/'.
When I choose 'fixmatch/cifar10.yaml' as the config file and run this:

CUDA_VISIBLE_DEVICES=2 python tools/train.py --root 'ssl/data/' --trainer MixMatch --dataset-config-file configs/datasets/ssl/cifar10.yaml --config-file configs/trainers/ssl/fixmatch/cifar10.yaml --output-dir output/mixmatch

I get this: assert cfg.DATALOADER.K_TRANSFORMS > 1
(because cfg.DATALOADER.K_TRANSFORMS == 1 by default)

So could you provide a MixMatch config file? Thanks a lot :)
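
As a stopgap (this is an assumption, not an official MixMatch config), the offending option can be overridden from the command line, the same way other commands in this tracker pass config options after the flags:

    CUDA_VISIBLE_DEVICES=2 python tools/train.py \
        --root 'ssl/data/' \
        --trainer MixMatch \
        --dataset-config-file configs/datasets/ssl/cifar10.yaml \
        --config-file configs/trainers/ssl/fixmatch/cifar10.yaml \
        --output-dir output/mixmatch \
        DATALOADER.K_TRANSFORMS 2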

use adabn.py

First of all, thanks for your work; it's a nice project in the DA/DG field. I'm using your adabn.py to apply the AdaBN method to my model. I read the code and found it pretty simple, and I think maybe something is missing?
From my understanding, the AdaBN method replaces the BN layers' training-stage mean and variance with those of the target domain, but I only found a reset-running-stats function in adabn.py.

I'm looking forward to your reply, thanks!
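
For reference, a generic PyTorch sketch of the AdaBN idea described above (reset the BN running statistics, then re-estimate them by forwarding target-domain batches in train mode); this is not the repo's adabn.py:

    import torch
    import torch.nn as nn

    @torch.no_grad()
    def adapt_bn_stats(model, target_loader, device="cuda"):
        """Re-estimate BatchNorm running mean/var on target-domain data."""
        # 1. Reset running statistics on every BN layer.
        for m in model.modules():
            if isinstance(m, nn.modules.batchnorm._BatchNorm):
                m.reset_running_stats()
        # 2. Forward target data in train mode so BN updates its running stats;
        #    no labels, gradients, or parameter updates are involved.
        model.train()
        for images, _ in target_loader:
            model(images.to(device))
        model.eval()
        return model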

Can I use Dassl with the Horovod framework?

Hello, Kaiyang. I'm trying to use Dassl for some medical tasks. Due to hardware limits, I have to train the network on two distributed GPUs... I'm using Horovod for this, but I noticed the statement "A drawback of Dassl is that it doesn't (yet? hmm) support distributed multi-GPU training". Is that still true when using Horovod? I'm looking forward to your answer, because it takes time to prepare the environment.

Terra Incognita

Someone suggested we should add the dataset of Terra Incognita, which is a wildlife animal classification dataset used in the DomainBed paper. I had gone through the images of Terra Incognita (the four locations chosen by DomainBed) and found that the objects of interest, i.e., animals, are often small in scale in comparison to the whole image, partially visible, and not centered. I feel using this dataset for evaluating image classifiers won't help track the progress as the quality isn't good enough.

RuntimeError Occur When Using DDAIG Model

RuntimeError: CUDA out of memory. Tried to allocate 50.00 MiB (GPU 0; 11.91 GiB total capacity; 11.04 GiB already allocated; 43.62 MiB free; 11.10 GiB reserved in total by PyTorch)
I used two 12GB Titan XP GPUs to run the code. I wonder whether there's a problem with my own code or with my devices.
Thanks.

DDAIG: different hyperparameters for different domains violate the DG setting

I noticed that in #12 (comment)
the DDAIG (AAAI 2020) method uses different hyperparameters for different domains, which violates the original setting of domain generalization (the test domain should be unseen).
Fine-tuning the hyperparameters for each domain indirectly uses the unseen target to improve performance.

The real improvement of the DDAIG method is unclear, and it is an unfair comparison with other SOTA results!

A new implementation of MixStyle

We have improved the implementation of MixStyle to make it more flexible.

Recall that MixStyle has two versions: random mixing and cross-domain mixing. The former randomly shuffles the batch dimension while the latter mixes the 1st half in a batch with the 2nd half.

After merging MixStyle2 into MixStyle, the two versions are now managed by a new variable called self.mix, which takes either random or crossdomain, corresponding to the two versions respectively. This variable can be set during initialization, e.g., self.mixstyle = MixStyle(mix='random'). It can also be changed on the fly: say you want to apply random mixing at the current step, simply do model.apply(random_mixstyle), or model.apply(crossdomain_mixstyle) if you prefer cross-domain mixing.
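
For example, a minimal sketch (the import path is an assumption; adjust it to where MixStyle lives in your install):

    import torch
    import torch.nn as nn
    from dassl.modeling.ops import MixStyle, random_mixstyle, crossdomain_mixstyle

    class TinyNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 16, 3, padding=1)
            self.mixstyle = MixStyle(mix='random')  # random mixing by default
            self.head = nn.Linear(16, 7)

        def forward(self, x):
            x = self.conv(x)
            x = self.mixstyle(x)  # mixes feature statistics across the batch
            return self.head(x.mean(dim=(2, 3)))

    model = TinyNet()
    model.apply(crossdomain_mixstyle)   # switch to cross-domain mixing on the fly
    # model.apply(random_mixstyle)      # ...and back to random mixing
    out = model(torch.randn(8, 3, 32, 32))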

We have also added new context managers to manage MixStyle in the forward pass. Say your model has MixStyle layers that were initially activated and you would like to deactivate them at a certain point; you can do

# print(MixStyle._activated): True
with run_without_mixstyle(model):
    # print(MixStyle._activated): False
    output = model(input)
# print(MixStyle._activated): True

Conversely, if you want to use MixStyle layers that were initially deactivated, you can do

# print(MixStyle._activated): False
with run_with_mixstyle(model):
    # print(MixStyle._activated): True
    output = model(input)
# print(MixStyle._activated): False

You can also change self.mix while using run_with_mixstyle, e.g.

# print(MixStyle._activated): False
# print(MixStyle.mix): random
with run_with_mixstyle(model, mix='crossdomain'):
    # print(MixStyle._activated): True
    # print(MixStyle.mix): crossdomain
    output = model(input)
# print(MixStyle._activated): False
# print(MixStyle.mix): crossdomain

But note that the change in self.mix during run_with_mixstyle is permanent unless you manually use model.apply(random_mixstyle) or model.apply(crossdomain_mixstyle) to modify the variable.

Replication Parameters and How to use MixStyle and EFDM

Dear Kaiyang

Really appreciate the open-source domain generalization framework. It is really amazing.
I'm currently working on replicating the results and extending my work on the current framework. Do you mind sharing the parameters for each baseline to replicate the results mentioned in your AAAI-2020 (DDAIG)? When I tried to replicate the baseline results of CrossGrad and DomainMix, the results were much worse than the paper. I guess it may be caused by the parameter tunning because I'm currently using the default settings of the framework (I didn't change any single line of the framework). In the config files, there are only configs about DDAIG, DAELDG, and Vanilla. (All these three are very good).

On the other hand, could you please post instructions on how to run MixStyle and EFDM with the framework?

Thank you very much!

Feature Request: Implementation of SSDA like data loaders

Hi,
Is there any plan to introduce semi-supervised domain adaptation (SSDA) data loaders in the code? There are UDA and SSL loaders, but for SSDA we might need different target data loaders during training, which, as far as I know, cannot be used directly from the codebase.

If an implementation of SSDA loaders would help, I can make a PR, as I have already worked on it.

Accuracy of M3SDA on DomainNet

Hi, thanks for sharing the nice codes.

I have some trouble getting the reported accuracy of M3SDA on DomainNet.
With the command below, I got 0.55% accuracy (error: 99.45%).
What's wrong with this?

python tools/train.py --root /database --trainer M3SDA --source-domains clipart --target-domains infograph --dataset-config-file configs/datasets/da/domainnet.yaml --config-file configs/trainers/da/m3sda/domainnet.yaml --output-dir output/M3SDA_CI_DOMAINNET

Custom Data?

Hi Kaiyang,

Thanks for this great repo! I am interested in using dassl to build several baselines for my own dataset, which is a set of tabular (vector) data from various domains. Do you have suggestions/tips/recommendations for adding our own dataset?

Thanks a lot!
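
In case it helps, a rough sketch of how a new dataset is typically registered in Dassl (the class name and folder layout here are made up, and the import path and Datum/DatasetBase signatures are assumptions based on the existing DA datasets; tabular data would also need a custom loader, since impath normally points to an image file):

    # my_dataset.py -- a rough sketch, not an official template
    import os

    from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase


    @DATASET_REGISTRY.register()
    class MyDataset(DatasetBase):
        """Toy layout: one folder per domain, one subfolder per class."""

        dataset_dir = "my_dataset"

        def __init__(self, cfg):
            root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
            self.dataset_dir = os.path.join(root, self.dataset_dir)

            train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS)
            test = self._read_data(cfg.DATASET.TARGET_DOMAINS)

            super().__init__(train_x=train_x, test=test)

        def _read_data(self, domains):
            items = []
            for domain_idx, dname in enumerate(domains):
                domain_dir = os.path.join(self.dataset_dir, dname)
                for label, cname in enumerate(sorted(os.listdir(domain_dir))):
                    class_dir = os.path.join(domain_dir, cname)
                    for fname in sorted(os.listdir(class_dir)):
                        items.append(
                            Datum(
                                impath=os.path.join(class_dir, fname),
                                label=label,
                                domain=domain_idx,
                                classname=cname,
                            )
                        )
            return items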

Issue with FixMatch training on custom datasets

Hello, I am having trouble training FixMatch with custom datasets.
I got this error:
input_x2 = batch_x["img2"]
KeyError: 'img2'

In the config/trainer files, I set K_TRANSFORMS to 2. When DatasetWrapper.__getitem__() is called, it returns only one image.

TypeError: unsupported format string passed to NoneType.__format__

Loading evaluator: Classification
Traceback (most recent call last):
  File "tools/train.py", line 191, in <module>
    main(args)
  File "tools/train.py", line 110, in main
    trainer.load_model(args.model_dir, epoch=args.load_epoch)
  File "d:\coop-main\dassl.pytorch-master\dassl\engine\trainer.py", line 199, in load_model
    f"Load {checkpoint} to {name} (epoch={epoch}, val_result={val_result:.1f})"
TypeError: unsupported format string passed to NoneType.__format__

Terminal input: python tools/train.py --root datasets/da/ --trainer SourceOnly --dataset-config-file configs/datasets/da/visda17.yaml --config-file configs/trainers/da/source_only/visda17.yaml --output-dir output/office31_test --source-domains real --target-domains real --eval-only --model-dir output/office31 --load-epoch 2

This happened when I used a .pth checkpoint to test on the dataset. Please help!

question about 'file not found'

Thank you for sharing your code! I am a beginner. I ran the office31 dataset and it showed "file not found", but I do have this file.

And my directory structure:

Inconsistent results after saving and loading the model

Hi there,

Thanks for your code. There is an issue when testing with a saved model.

I ran the test right after training and the result was good. However, when I loaded the saved model and ran the test again, the result was very low.

How can I get consistent testing results?

Run SimCLR using Dassl

Hi Kaiyang

Can you please point me to how I can implement a two-view dataloader for training with the SimCLR loss? Basically, I want the train dataloader to return two views (augmentations) of the same image each time it is called.

Thanks!
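
As a generic starting point (not Dassl-specific; the DATALOADER.K_TRANSFORMS option mentioned elsewhere in this tracker serves a similar purpose), a two-view dataset can simply apply the same stochastic transform twice:

    from torch.utils.data import Dataset

    class TwoViewDataset(Dataset):
        """Wrap a dataset so every item yields two independent augmentations."""

        def __init__(self, base_dataset, transform):
            self.base = base_dataset
            self.transform = transform

        def __len__(self):
            return len(self.base)

        def __getitem__(self, idx):
            img, label = self.base[idx]
            # Applying the same stochastic transform twice gives two views.
            return self.transform(img), self.transform(img), label

    # usage sketch (my_dataset and my_augmentation are placeholders):
    # from torch.utils.data import DataLoader
    # loader = DataLoader(TwoViewDataset(my_dataset, my_augmentation),
    #                     batch_size=128, shuffle=True)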

ImportError: dlopen: cannot load any more object with static TLS

Hi Kaiyang, when I was installing Dassl using the installation steps you provided, after completing all the steps and running the clipadapter code, the following error was thrown:

Traceback (most recent call last):
  File "/home1/pan-internship-6/.conda/envs/dassl_coop/lib/python3.8/site-packages/sklearn/__check_build/__init__.py", line 48, in <module>
    from ._check_build import check_build  # noqa
ImportError: dlopen: cannot load any more object with static TLS

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "train.py", line 10, in <module>
    from dassl.engine import build_trainer
  File "/home1/pan-internship-6/Dassl.pytorch/dassl/engine/__init__.py", line 2, in <module>
    from .trainer import TrainerX, TrainerXU, TrainerBase, SimpleTrainer, SimpleNet  # isort:skip
  File "/home1/pan-internship-6/Dassl.pytorch/dassl/engine/trainer.py", line 19, in <module>
    from dassl.evaluation import build_evaluator
  File "/home1/pan-internship-6/Dassl.pytorch/dassl/evaluation/__init__.py", line 3, in <module>
    from .evaluator import EvaluatorBase, Classification
  File "/home1/pan-internship-6/Dassl.pytorch/dassl/evaluation/evaluator.py", line 1, in <module>
    from sklearn.metrics import f1_score, confusion_matrix
  File "/home1/pan-internship-6/.conda/envs/dassl_coop/lib/python3.8/site-packages/sklearn/__init__.py", line 81, in <module>
    from . import __check_build  # noqa: F401
  File "/home1/pan-internship-6/.conda/envs/dassl_coop/lib/python3.8/site-packages/sklearn/__check_build/__init__.py", line 50, in <module>
    raise_build_error(e)
  File "/home1/pan-internship-6/.conda/envs/dassl_coop/lib/python3.8/site-packages/sklearn/__check_build/__init__.py", line 31, in raise_build_error
    raise ImportError(
ImportError: dlopen: cannot load any more object with static TLS

How can I fix this error?
