
MCD_DA's Introduction

Maximum Classifier Discrepancy for Domain Adaptation


This is the PyTorch implementation of Maximum Classifier Discrepancy for digit classification and semantic segmentation. The code is written by Kuniaki Saito. The work was accepted as an oral presentation at CVPR 2018.

Maximum Classifier Discrepancy for Domain Adaptation: [Project][Paper (arxiv)].
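For context, the method trains a feature generator G and two classifiers F1 and F2: the classifiers are trained to maximize their prediction discrepancy on target samples (while staying accurate on the source), and the generator is then trained to minimize that discrepancy. Below is a minimal sketch of the discrepancy term from the paper in PyTorch; function and variable names are illustrative and not taken from this code.

    import torch
    import torch.nn.functional as F

    def classifier_discrepancy(logits1, logits2):
        # L1 distance between the two classifiers' softmax outputs,
        # averaged over the batch (the discrepancy loss in the paper).
        p1 = F.softmax(logits1, dim=1)
        p2 = F.softmax(logits2, dim=1)
        return torch.mean(torch.abs(p1 - p2))

    # Step B: fix G, update F1/F2 to maximize this discrepancy on target data
    # (while minimizing the source classification loss).
    # Step C: fix F1/F2, update G to minimize the discrepancy.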


Getting Started

Go to classification or segmentation folder and see the instruction for each task.

Citation

If you use this code for your research, please cite our paper (the entry will be updated when the CVPR paper is published).

@article{saito2017maximum,
  title={Maximum Classifier Discrepancy for Unsupervised Domain Adaptation},
  author={Saito, Kuniaki and Watanabe, Kohei and Ushiku, Yoshitaka and Harada, Tatsuya},
  journal={arXiv preprint arXiv:1712.02560},
  year={2017}
}

MCD_DA's People

Contributors

ksaito-ut, kunisaito, littlewat


MCD_DA's Issues

sharing svhn2mnist result

Hello, can you share your SVHN→MNIST accuracy results?

I set SVHN as the source and MNIST as the target. Is it normal that the target accuracy is higher?

Reproduce the segmentation result

I tried to rerun the segmentation part of the code in Python 2.7,
and I was unable to reproduce the result in the paper.
I noticed several differences:

  1. I rewrote the label-mapping part of the GTA dataset, since the GTA labels need to be mapped to Cityscapes classes and the original code didn't do it.
  2. I was using torch 1.3.1.
  3. I updated the loss to NLLLoss rather than NLLLoss2d.
  4. I fixed a few syntax errors, which probably crept in from API changes over time.

Here is my result:

[screenshot of results]

I would love to know if anyone was able to reproduce the segmentation result, and I would really appreciate it if you could share that piece of code so I could see where I went wrong.

Thanks a lot!

About toy dataset

Where can I find the code for the toy dataset experiment shown in Figure 4 of the paper?

Replicating the results

Has anybody succeeded in getting the results the paper claims?
So far, by running the VisDA experiment with ResNet-101 using the default settings, I get this:
[screenshot of results]
Mean class accuracy: 62.54% (vs. 71.9% reported in the paper)

drn_d_105-12b40979.pth

(pytorch1.1) zgm@zgm-icv:~/Lufei/MCD_DA/segmentation$ python adapt_trainer.py gta city --net drn_d_105
Downloading: "https://tigress-web.princeton.edu/~fy/drn/models/drn_d_105-12b40979.pth" to /home/zgm/.cache/torch/checkpoints/drn_d_105-12b40979.pth
Traceback (most recent call last):
  File "adapt_trainer.py", line 69, in <module>
    is_data_parallel=args.is_data_parallel)
  File "/home/zgm/Lufei/MCD_DA/segmentation/models/model_util.py", line 51, in get_models
    model_list = get_MCD_model_list()
  File "/home/zgm/Lufei/MCD_DA/segmentation/models/model_util.py", line 41, in get_MCD_model_list
    model_g = DRNSegBase(model_name=net_name, n_class=n_class, input_ch=input_ch)
  File "/home/zgm/Lufei/MCD_DA/segmentation/models/dilated_fcn.py", line 109, in __init__
    pretrained=pretrained, num_classes=1000, input_ch=input_ch)
  File "/home/zgm/Lufei/MCD_DA/segmentation/models/drn.py", line 343, in drn_d_105
    model.load_state_dict(model_zoo.load_url(model_urls['drn-d-105']))
  File "/home/zgm/.conda/envs/pytorch1.1/lib/python3.6/site-packages/torch/hub.py", line 439, in load_state_dict_from_url
    _download_url_to_file(url, cached_file, hash_prefix, progress=progress)
  File "/home/zgm/.conda/envs/pytorch1.1/lib/python3.6/site-packages/torch/hub.py", line 354, in _download_url_to_file
    u = urllib.request.urlopen(req).read()
  File "/home/zgm/.conda/envs/pytorch1.1/lib/python3.6/urllib/request.py", line 223, in urlopen
    return opener.open(url, data, timeout)
  File "/home/zgm/.conda/envs/pytorch1.1/lib/python3.6/urllib/request.py", line 532, in open
    response = meth(req, response)
  File "/home/zgm/.conda/envs/pytorch1.1/lib/python3.6/urllib/request.py", line 642, in http_response
    'http', request, response, code, msg, hdrs)
  File "/home/zgm/.conda/envs/pytorch1.1/lib/python3.6/urllib/request.py", line 570, in error
    return self._call_chain(*args)
  File "/home/zgm/.conda/envs/pytorch1.1/lib/python3.6/urllib/request.py", line 504, in _call_chain
    result = func(*args)
  File "/home/zgm/.conda/envs/pytorch1.1/lib/python3.6/urllib/request.py", line 650, in http_error_default
    raise HTTPError(req.full_url, code, msg, hdrs, fp)
urllib.error.HTTPError: HTTP Error 403: Forbidden
Could you provide drn_d_105-12b40979.pth? Thanks so much!
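As a possible workaround (not from the authors): if you can obtain drn_d_105-12b40979.pth by other means, you can load it from a local path instead of letting model_zoo.load_url hit the blocked URL. The path below is a placeholder.

    import torch

    # Build DRN-D-105 without triggering the pretrained download
    # (e.g. by passing pretrained=False in drn.py), then:
    state_dict = torch.load('/path/to/drn_d_105-12b40979.pth', map_location='cpu')
    model.load_state_dict(state_dict)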

Segmentation labels

Thank you for sharing the code! I am trying to run the segmentation task, and I am confused about the segmentation labels.

What code do you run to convert the original colored ground-truth images of GTAV into the grayscale images with the 20 classes that you report on?
The same question for Cityscapes: the dataset has 33 classes. How do you reduce them to 20?
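For reference, the common GTA5-to-Cityscapes setup keeps the 19 official Cityscapes evaluation classes (plus a void/background class) and maps every other label ID to an ignore value; GTA5 annotations already use Cityscapes label IDs, so the same table applies to both datasets. Here is a sketch of that standard mapping, which may differ in detail from what this repository does:

    import numpy as np

    # Standard Cityscapes label ID -> train ID mapping (19 evaluation classes).
    ID_TO_TRAINID = {
        7: 0, 8: 1, 11: 2, 12: 3, 13: 4, 17: 5, 19: 6, 20: 7, 21: 8, 22: 9,
        23: 10, 24: 11, 25: 12, 26: 13, 27: 14, 28: 15, 31: 16, 32: 17, 33: 18,
    }

    def encode_labels(label_img, void_id=255):
        # Convert a raw label map (original IDs) to train IDs; unknown IDs become void.
        out = np.full(label_img.shape, void_id, dtype=np.uint8)
        for orig_id, train_id in ID_TO_TRAINID.items():
            out[label_img == orig_id] = train_id
        return out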

cityscapes/info.json

Thanks for your code! I am trying to run the segmentation task, but there is an error:
File "eval.py", line 223, in eval_city
with open(join(devkit_dir, 'data', dset, 'info.json'), 'r') as fp:
IOError: [Errno 2] No such file or directory: '/data/ugui0/dataset/adaptation/taskcv-2017-public/segmentation/data/cityscapes/info.json'
Can you help me?

Model Selection

The classification code seems to be running "test" on every epoch and is printing the test accuracy.
Which of these accuracies do you report?
How is the model selection done?

I understand each experiment is repeated 5 times. But for each run, is the last-epoch accuracy used for the mean, or the maximum accuracy over all epochs?

Loss is becoming negative

Very nice work! During training, I found that the loss can become negative:

Train Epoch: 198 [0/100 (0%)]	Loss1: 0.024868	 Loss2: 0.022132	  Discrepancy: 0.018226

Test set: Average loss: -0.0588, Accuracy C1: 9449/10000 (94%) Accuracy C2: 9509/10000 (95%) Accuracy Ensemble: 9554/10000 (96%) 

recording record/usps_mnist_k_4_alluse_no_onestep_False_1_test.txt
Train Epoch: 199 [0/100 (0%)]	Loss1: 0.012343	 Loss2: 0.020431	  Discrepancy: 0.030520

Test set: Average loss: -0.0581, Accuracy C1: 9419/10000 (94%) Accuracy C2: 9518/10000 (95%) Accuracy Ensemble: 9537/10000 (95%) 

recording record/usps_mnist_k_4_alluse_no_onestep_False_1_test.txt

Do you think this is normal?

Asking for the visualization code

Congratulations, you have done a great job.

I am wondering whether you could publish the code for the visualization in Figure 4 of the paper, as shown below:

[figure from the paper]

Thank you very much. Looking forward to your reply.

Question about the handwritten digit experiment

In Table 1, the paper shows very high scores (more than 90%) on the digit classification tasks. My question is whether the test set is only the target-domain test set or the combination of the source- and target-domain test sets. For example, in the unsupervised domain adaptation task SVHN→MNIST, the classification score reaches 96.2%. Is the test set behind that 96.2% ONLY the MNIST test set, or the combination of the SVHN and MNIST test sets? Thanks in advance.

synth traffic dataset

Hi, could you share the files of the synthetic traffic signs dataset, or tell me where I can download it? I can't find this dataset on the Internet. Thanks.

About the classifiers.

Thanks for the code!

I have a question about the classifiers: what is the difference between the F1 and F2 classifiers?
I can't find any difference in the code in the classification folder.
Can you tell me how they differ in the implementation?
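As far as the paper goes, F1 and F2 share the same architecture; they only differ through their independent random initializations and the discrepancy-based training, which pushes them to disagree on ambiguous target samples. A minimal illustration (the class definition and layer sizes are mine, not the repository's):

    import torch.nn as nn

    class Classifier(nn.Module):
        def __init__(self, n_feat=2048, n_class=10):
            super(Classifier, self).__init__()
            self.fc = nn.Sequential(
                nn.Linear(n_feat, 256),
                nn.ReLU(inplace=True),
                nn.Linear(256, n_class),
            )

        def forward(self, x):
            return self.fc(x)

    # Same architecture, but two independently initialized sets of weights.
    F1, F2 = Classifier(), Classifier()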

gta train val files

Just wondering if you could provide the additional train/val .txt files for GTA.
The code needs to read "train.txt" and "val.txt" here:

"gta": "", ## Fill the directory over images folder. put train.txt, val.txt in this folder

Are they the IDs stored in split.mat for GTA, saved externally as .txt files?
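If the split indeed comes from the official GTA5 split.mat, a rough sketch for exporting it as text files might look like the following; the field names (trainIds, valIds) and the zero-padded image naming are assumptions based on the public GTA5 release, not this repository:

    import scipy.io as sio

    split = sio.loadmat('split.mat')

    def write_ids(ids, path):
        # GTA5 frames are named 00001.png, 00002.png, ...
        with open(path, 'w') as f:
            for i in ids.ravel():
                f.write('%05d.png\n' % int(i))

    write_ids(split['trainIds'], 'train.txt')
    write_ids(split['valIds'], 'val.txt')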

Missing Loss_weight file

[screenshot of the error]
For the segmentation task: it seems there is a class-weight file, but I can't find it. Could you help me, or am I doing something wrong?
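This is only a guess at what the missing file contains, but if it is just a per-class weight vector for the segmentation loss, something similar can be recreated by counting pixels per train ID over the training labels and applying median-frequency balancing, then saving the tensor with torch.save:

    import numpy as np
    import torch

    def class_weights_from_histogram(hist):
        # hist[c] = number of pixels with train ID c over the training set.
        freq = hist.astype(np.float64) / hist.sum()
        med = np.median(freq[freq > 0])
        weights = np.zeros_like(freq)
        weights[freq > 0] = med / freq[freq > 0]  # median-frequency balancing
        return torch.tensor(weights, dtype=torch.float32)

    # torch.save(class_weights_from_histogram(hist), 'loss_weight.pth')  # file name is a guess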

USPS→MNIST source-only result can't reach 0.634

The paper reports a source-only accuracy of 0.634 for USPS→MNIST, but I can only reach roughly 0.2–0.3 on USPS→MNIST using the same network. Has anyone else gotten a similarly low source-only result? How can this be fixed?

[Request] fcnvgg

I found that there is no vgg_fcn file, even though it is referenced in models/model_util.py.
I would really appreciate it if you could share the network definition for FCN8s, since I'm not sure how to split the base FCN8s network into SegBase and SegPixelClassifier.
Thank you in advance.
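While waiting for the official file, here is one possible way to split a VGG-based segmentation network into a base feature extractor and a pixel-wise classifier head. This is purely my guess at the SegBase/SegPixelClassifier split, and it is a simplified FCN-style head without the FCN8s skip connections, not the authors' code:

    import torch.nn as nn
    import torch.nn.functional as F
    from torchvision.models import vgg16

    class VGGSegBase(nn.Module):
        # Shared generator: VGG16 convolutional backbone.
        def __init__(self, pretrained=True):
            super(VGGSegBase, self).__init__()
            self.features = vgg16(pretrained=pretrained).features

        def forward(self, x):
            return self.features(x)

    class VGGSegClassifier(nn.Module):
        # Pixel-wise classifier: 1x1 score layer + bilinear upsampling.
        def __init__(self, n_class=20, in_ch=512):
            super(VGGSegClassifier, self).__init__()
            self.score = nn.Conv2d(in_ch, n_class, kernel_size=1)

        def forward(self, feat, out_size):
            return F.interpolate(self.score(feat), size=out_size,
                                 mode='bilinear', align_corners=False)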

what does ReLabel mean?

ReLabel(255, args.n_class - 1),  # Last Class is "Void" or "Background" class

Actually, args.n_class - 1 corresponds to 'bicycle' in Cityscapes.

So do you map all unlabeled (255) pixels to the bicycle class?
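For what it's worth, ReLabel in this kind of pipeline is typically just a label transform that replaces one value with another, so ReLabel(255, args.n_class - 1) remaps the 255 ignore value to the last class index, which the code comment treats as a "Void"/"Background" class. A sketch of what such a transform usually looks like (not necessarily the exact implementation here):

    class ReLabel(object):
        # Replace every occurrence of `olabel` in a label tensor with `nlabel`.
        def __init__(self, olabel, nlabel):
            self.olabel = olabel
            self.nlabel = nlabel

        def __call__(self, target):
            target[target == self.olabel] = self.nlabel
            return target

    # With n_class = 20 the last index would be a dedicated void/background class
    # rather than 'bicycle'; whether that collision happens depends on how
    # n_class is set for the experiment.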

Some problems with building my own datasets on classificaion

Thank you very much for the code you provided. Recently, I have been studying the classification datasets you provided, with MNIST as the source domain and USPS as the target domain. However, when I substitute my own data for MNIST and USPS, I always get an error. I hope you can help. The error is shown below:

   IndexError: index 149 is out of bounds for axis 0 with size 1
