adda's Introduction

Adversarial Discriminative Domain Adaptation

Getting started

This code requires Python 3 and is implemented in TensorFlow.

Things should be fairly easy to run out of the box:

pip install -r requirements.txt
mkdir data snapshot
export PYTHONPATH="$PWD:$PYTHONPATH"
scripts/svhn-mnist.sh

The provided script does the following things:

  • Train a base LeNet model on SVHN (downloading SVHN under data/svhn in the process)
  • Use ADDA to adapt the SVHN model to MNIST (downloading MNIST under data/mnist in the process)
  • Run an evaluation on MNIST using the source-only model (stored at snapshot/lenet_svhn)
  • Run an evaluation on MNIST using the ADDA model (stored at snapshot/adda_lenet_svhn_mnist)

Areas of interest

  • Check scripts/svhn-mnist.sh for hyperparameters.
  • The LeNet model definition is in adda/models/lenet.py.
  • The model is annotated with data preprocessing info, which is used in the preprocessing function in adda/models/model.py.
  • The main ADDA logic happens in tools/train_adda.py.
  • The adversarial discriminator model definition is in adda/adversary.py.
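
As a rough illustration of what "annotated with data preprocessing info" can look like, the sketch below attaches attributes to a model function and reads them in a preprocessing step. The attribute names (num_channels, mean) and the preprocess helper are hypothetical and chosen for this example; check adda/models/lenet.py and adda/models/model.py for the names the repository actually uses.

```python
def lenet(inputs):
    """Stand-in for a model function; the real network is in adda/models/lenet.py."""
    return inputs


# Hypothetical annotations describing the input this model expects.
lenet.num_channels = 1  # grayscale input
lenet.mean = None       # no mean subtraction


def preprocess(pixels, model_fn):
    """Adapt a list of (r, g, b) pixels to what model_fn declares it expects."""
    if model_fn.num_channels == 1:
        # Collapse RGB to a single channel with a plain average.
        pixels = [(sum(p) / 3.0,) for p in pixels]
    if model_fn.mean is not None:
        # Subtract the per-channel mean when the model declares one.
        pixels = [tuple(v - m for v, m in zip(p, model_fn.mean)) for p in pixels]
    return pixels


rgb = [(30, 60, 90), (0, 0, 0)]
gray = preprocess(rgb, lenet)
print(gray)
```

Keeping the preprocessing requirements next to the model definition means the same data loader can serve models with different input conventions.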

adda's People

Contributors

erictzeng, jhoffman

adda's Issues

Error

I am getting the following error:

name: GeForce GTX 1070
major: 6 minor: 1 memoryClockRate (GHz) 1.683
pciBusID 0000:01:00.0
Total memory: 7.92GiB
Free memory: 7.43GiB
I tensorflow/core/common_runtime/gpu/gpu_device.cc:906] DMA: 0
I tensorflow/core/common_runtime/gpu/gpu_device.cc:916] 0: Y
I tensorflow/core/common_runtime/gpu/gpu_device.cc:975] Creating TensorFlow device (/gpu:0) -> (device: 0, name: GeForce GTX 1070, pci bus id: 0000:01:00.0)
F tensorflow/stream_executor/cuda/cuda_dnn.cc:222] Check failed: s.ok() could not find cudnnCreate in cudnn DSO; dlerror: /home/mahfuj/.local/lib/python3.5/site-packages/tensorflow/python/_pywrap_tensorflow.so: undefined symbol: cudnnCreate
scripts/svhn-mnist.sh: line 13: 3379 Aborted (core dumped) python tools/train.py svhn train lenet lenet_svhn --iterations 10000 --batch_size 128 --display 10 --lr 0.001 --snapshot 5000 --solver adam
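
The log above says the cudnnCreate symbol could not be resolved, which often points to a missing or mismatched cuDNN shared library. The following is a minimal diagnostic sketch using only the Python standard library; the helper name cudnn_status is made up for this example, and find_library may not search every directory the dynamic loader uses, so a "not found" result is a hint rather than proof.

```python
import ctypes
import ctypes.util


def cudnn_status():
    """Report whether a cuDNN shared library can be located and loaded."""
    name = ctypes.util.find_library("cudnn")
    if name is None:
        return "not found"
    try:
        lib = ctypes.CDLL(name)
    except OSError as exc:
        return "found but failed to load: {}".format(exc)
    # cudnnCreate is the symbol the log above reports as undefined.
    if hasattr(lib, "cudnnCreate"):
        return "ok"
    return "loaded, but cudnnCreate is missing"


print(cudnn_status())
```

If the library is not found, installing a cuDNN version that matches the installed TensorFlow build and making it visible to the loader is the usual next step.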

Custom Dataset

How can I use a custom dataset made of .jpg or .png image files? Which .py file should I modify to supply my own inputs?

Thanks in advance,
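
One possible starting point, independent of this repository's dataset classes, is to enumerate image files and derive integer labels from the directory names. The sketch below assumes a hypothetical layout of root/<class_name>/<image file>; the function name list_labeled_images is made up for this example, and wiring the result into the training code would still require a dataset class written in the repository's own style.

```python
from pathlib import Path

IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def list_labeled_images(root):
    """Return ([(path, label_index), ...], class_names) for root/<class>/<image>."""
    root = Path(root)
    # Each sub-directory is treated as one class; sorting keeps indices stable.
    class_names = sorted(d.name for d in root.iterdir() if d.is_dir())
    class_to_index = {name: i for i, name in enumerate(class_names)}
    samples = []
    for name in class_names:
        for path in sorted((root / name).iterdir()):
            if path.suffix.lower() in IMAGE_SUFFIXES:
                samples.append((str(path), class_to_index[name]))
    return samples, class_names


# Example call (the directory name is a placeholder):
# samples, class_names = list_labeled_images("data/my_dataset")
```

The returned list of (path, label) pairs can then be fed to whatever image decoding and batching code the training script expects.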

vgg16 did not run

Traceback (most recent call last):
  File "tools/train.py", line 118, in <module>
    main()
  File "/home/work/anaconda3/envs/qinqinglin/lib/python3.5/site-packages/click/core.py", line 722, in __call__
    return self.main(*args, **kwargs)
  File "/home/work/anaconda3/envs/qinqinglin/lib/python3.5/site-packages/click/core.py", line 697, in main
    rv = self.invoke(ctx)
  File "/home/work/anaconda3/envs/qinqinglin/lib/python3.5/site-packages/click/core.py", line 895, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/home/work/anaconda3/envs/qinqinglin/lib/python3.5/site-packages/click/core.py", line 535, in invoke
    return callback(*args, **kwargs)
  File "tools/train.py", line 62, in main
    class_loss = tf.losses.sparse_softmax_cross_entropy(label_batch, net)
  File "/home/work/anaconda3/envs/qinqinglin/lib/python3.5/site-packages/tensorflow/python/ops/losses/losses_impl.py", line 915, in sparse_softmax_cross_entropy
    name="xentropy")
  File "/home/work/anaconda3/envs/qinqinglin/lib/python3.5/site-packages/tensorflow/python/ops/nn_ops.py", line 2039, in sparse_softmax_cross_entropy_with_logits
    (labels_static_shape.ndims, logits.get_shape().ndims))
ValueError: Rank mismatch: Rank of labels (received 1) should equal rank of logits minus 1 (received 4)
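
The error states a rank rule: labels of rank 1 need logits of rank 2, but the network output here has rank 4. The sketch below mirrors that rule on plain shape tuples and shows the shape after dropping two size-1 spatial axes. The concrete shape (128, 1, 1, 10) is an assumption for illustration; the real output shape of the vgg16 model should be inspected first, and in TensorFlow the corresponding operation would be something like tf.squeeze on the spatial axes or a flatten before the loss.

```python
def check_ranks(labels_shape, logits_shape):
    """Raise if the shapes violate the rank rule quoted in the ValueError."""
    if len(labels_shape) != len(logits_shape) - 1:
        raise ValueError(
            "Rank mismatch: labels rank {} vs logits rank {}".format(
                len(labels_shape), len(logits_shape)))


def squeeze_spatial(logits_shape):
    """Shape of an (N, 1, 1, C) tensor after removing the two spatial axes."""
    n, h, w, c = logits_shape
    if (h, w) != (1, 1):
        raise ValueError("expected 1x1 spatial dims, got {}x{}".format(h, w))
    return (n, c)


labels_shape = (128,)
logits_shape = (128, 1, 1, 10)  # assumed shape, not taken from the report

try:
    check_ranks(labels_shape, logits_shape)
except ValueError as err:
    print(err)

fixed_shape = squeeze_spatial(logits_shape)
check_ranks(labels_shape, fixed_shape)  # passes once the logits are rank 2
print(fixed_shape)
```

If the spatial dimensions are larger than 1x1, squeezing is not enough and the input size or a pooling step would need to change instead.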

Results not so stable in SVHN to MNIST

Hello, I am trying to rerun the SVHN to MNIST experiment, but the results are not robust across several runs with the default settings. The accuracy fluctuates between 0.68 and 0.79. Are there any suggestions for making it more stable?
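
When results vary this much between runs, it can help to report a mean and spread over several runs rather than a single number. The following is a small sketch using the standard library; the helper name summarize is made up here, and the only inputs used are the two extremes reported above (0.68 and 0.79), not a full set of measured runs.

```python
import statistics


def summarize(accuracies):
    """Return (mean, sample standard deviation, min, max) of run accuracies."""
    return (
        statistics.mean(accuracies),
        statistics.stdev(accuracies),
        min(accuracies),
        max(accuracies),
    )


# The two extremes mentioned in the report; replace with one value per run.
runs = [0.68, 0.79]
mean, std, lo, hi = summarize(runs)
print("mean={:.3f} std={:.3f} range=[{:.2f}, {:.2f}]".format(mean, std, lo, hi))
```

Fixing random seeds is a separate measure; it makes a single run repeatable but does not by itself reduce the sensitivity of adversarial training to initialization.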

Why does my script stop running without any error report?

Thanks for sharing the code!

I am running it on Windows from the command prompt. When I run train_adda.py, it stops after a few iterations, usually between 2 and 16, without reporting any error. It just hangs there and prints nothing further. I reduced the number of iterations to 100, but it still never completes. Does anyone know why?

(Attached image: a screenshot of my command window.)

Running on Office31

Thanks for your source code.

I am trying to run it on the Office dataset with AlexNet, following the setting in the paper, but the results only get worse during training. Could you please also release the code for the Office dataset?

Thank you very much.

Appropriate layer to adapt

Hi,
Please correct me if I am wrong. From the code, it seems that the layer being adapted is the very last fully connected layer (the 10-way classification output). However, this layer gives the distribution of class probabilities, so I think we cannot adapt it if we do not know the labels of the target dataset during ADDA.
For example, if we feed a digit 5 from the source domain and a digit 7 from the target domain, the outputs of the last layer will obviously differ. Aligning the last layer only makes sense if we are sure that we feed images of the same class, which in turn means that we have to know the class labels of the target domain during the ADDA adaptation phase.
Is my understanding wrong?
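
One way to frame the question: a domain discriminator compares the distribution of outputs across many samples from each domain, not paired samples. The sketch below is only an illustration of that distinction, not the authors' answer. The class predictions are hypothetical, and it assumes both domains have a similar class balance.

```python
from collections import Counter


def marginal(outputs):
    """Fraction of samples assigned to each class within one batch."""
    counts = Counter(outputs)
    total = len(outputs)
    return {cls: counts[cls] / total for cls in sorted(counts)}


# Hypothetical predicted classes for one batch from each domain.
source_batch = [5, 7, 5, 1, 7, 1]
target_batch = [7, 1, 1, 5, 7, 5]  # different order, not paired with the source

# Position by position the two batches disagree...
pairwise_equal = [s == t for s, t in zip(source_batch, target_batch)]
# ...but their batch-level distributions are identical.
same_marginal = marginal(source_batch) == marginal(target_batch)

print(pairwise_equal)
print(same_marginal)
```

Whether aligning the output layer is the best choice for a given task is a separate question; the sketch only shows that distribution-level alignment does not require knowing which target sample belongs to which class.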

Adversarial loss

Thanks for sharing the source code!

But I have some questions here:

  1. In adda/tools/train_adda.py, lines 88-103: it seems that you set the source domain label to 1 and the target domain label to 0? But here the source distribution is fixed, while the target distribution is updated. According to the original setting in the GAN paper, the changing one should have the zero label. Equation (9) in the ADDA paper also shows that the ground-truth label of the source should be 1.

  2. How do you separate the encoder and the classifier in the base model? In the paper, it seems that the encoder contains only the CNN. But in the code, it seems that you treat the whole network as an encoder, so where is the classifier?

Thanks for your time. Looking forward to your response.
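
On the label question, it may help to see that with a binary cross-entropy discriminator the choice of which domain is called 1 is a convention: swapping the labels while flipping the discriminator's output gives the same loss value. The sketch below checks this numerically and also shows an inverted-label loss for the target encoder. The probabilities are hypothetical, the function names are made up for this example, and this is not a transcription of tools/train_adda.py.

```python
import math


def bce(p, label):
    """Binary cross-entropy for one sample; p is the predicted P(label == 1)."""
    return -(label * math.log(p) + (1 - label) * math.log(1 - p))


def discriminator_loss(p_source, p_target, source_label):
    """Mean BCE with source samples labeled source_label and target the opposite."""
    target_label = 1 - source_label
    losses = [bce(p, source_label) for p in p_source]
    losses += [bce(p, target_label) for p in p_target]
    return sum(losses) / len(losses)


def mapping_loss(p_target, source_label):
    """Inverted-label loss: target samples are scored as if they were source."""
    return sum(bce(p, source_label) for p in p_target) / len(p_target)


# Hypothetical discriminator outputs, read as P(sample is from the source).
p_source = [0.9, 0.8]
p_target = [0.3, 0.1]

loss_a = discriminator_loss(p_source, p_target, source_label=1)
# Same discriminator, opposite convention: outputs now mean P(sample is target).
loss_b = discriminator_loss(
    [1 - p for p in p_source], [1 - p for p in p_target], source_label=0)

print(math.isclose(loss_a, loss_b))
print(mapping_loss(p_target, source_label=1))
```

What matters for training is that the encoder's loss uses the labels opposite to the ones the discriminator is trained with for target samples, whichever convention is chosen.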

AttributeError: 'MNIST' object has no attribute 'ignore_labels'

At line 122 of eval_segmentation.py:

im_intersection, im_union = count_intersection_and_union(
    predictions[0], gt[0], dataset.num_classes,
    ignore=dataset.ignore_labels)

the run fails with: AttributeError: 'MNIST' object has no attribute 'ignore_labels'
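
The error suggests that the segmentation evaluation expects a dataset attribute that the MNIST dataset class does not define, which may also mean a segmentation evaluation is not the intended tool for a classification dataset. If the script is still wanted for such a dataset, one hedged workaround is to fall back to an empty ignore list. The classes and the helper get_ignore_labels below are stand-ins written for this example, not the repository's own code.

```python
class MNIST:
    """Stand-in for a classification dataset without ignore_labels."""
    num_classes = 10


class SegDataset:
    """Hypothetical segmentation dataset that does define ignore_labels."""
    num_classes = 19
    ignore_labels = [255]


def get_ignore_labels(dataset):
    """Return the dataset's ignore_labels, or an empty list if it has none."""
    return getattr(dataset, "ignore_labels", None) or []


print(get_ignore_labels(MNIST()))
print(get_ignore_labels(SegDataset()))
```

The call in the evaluation script could then pass ignore=get_ignore_labels(dataset) instead of reading the attribute directly.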
