svg's Issues

CUDA out of memory

Hello, thanks for your work. However, when I start training your KTH model, I run into a CUDA out-of-memory error. Could you please explain what training settings you used? Did you train on multiple GPUs?

The results of using the FP model for MMNIST are really bad

My results using the fixed-prior model for MMNIST generation are really bad, as the gif below (after 25 epochs) shows. I have run the code for over 100 epochs, but the results don't improve much.

I can't find a comparison of FP and LP on MMNIST in the paper. What were your results in that setting? Has anybody gotten good results using FP on MMNIST? Thanks!

[sample_25.gif: generated frames after 25 epochs]

fp-SVG's plot() ignores the last past frame + a fix suggestion

Hi,

The plot() function in train_svg_fp.py disregards the last of the past frames before moving on to predicting future frames.

At the end of the if i < opt.n_past: block (inside for i in range(1, opt.n_eval):, which is inside for s in range(nsample):, which is inside def plot(x, epoch):), around line 189, I suggest adding the following lines:

h, skip = h_seq[i]
h = h.detach()

With this change the following will happen:
h will still be overwritten by h, skip = h_seq[i-1] in line 180 as long as i < opt.n_past. However, once the loop advances i from opt.n_past - 1 to opt.n_past, the h and skip taken from the last of the past frames will be the ones used in the else: block starting at line 190.
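For context, here is a sketch of where this lands in the conditioning loop of plot() (my paraphrase of the structure described above, not the verbatim file contents):

for i in range(1, opt.n_eval):
    if i < opt.n_past:
        h, skip = h_seq[i-1]   # line ~180: features of the previous ground-truth frame
        h = h.detach()
        # ... condition frame_predictor on the ground-truth frame ...
        # suggested addition, so the last past frame's features carry over:
        h, skip = h_seq[i]
        h = h.detach()
    else:
        # line ~190: h and skip now come from the last past frame
        # instead of the second-to-last one
        pass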

The lp-SVG code does not suffer from this issue, but its skip connection still comes from the second-to-last of the past frames, which I don't think is a big issue (although it may not be ideal).

Reproducibility

Hi

I am not able to reproduce the KTH results from your paper.
The settings I used:

python train_svg_fp.py --batch_size 16 --dataset kth --image_width 64 --model vgg --g_dim 128 --z_dim 24 --beta 0.000001 --n_past 10 --n_future 10 --channels 1 --lr 0.0008 --data_root /path/to/data/ --log_dir /logs/will/be/saved/here/
The only change I made was the batch_size, which defaults to 100 and did not fit into my GPU memory.

Is there anything I have missed to properly reproduce your results?

KL loss doesn't decrease

Hi,
Thanks for sharing the code, nice work!
When I train the learned-prior model on the SM-MNIST dataset, the reconstruction loss decreases;
however, the KL loss keeps increasing. Is this normal?
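For reference, the KL term I am referring to is (as far as I understand) the standard closed-form Gaussian KL between the posterior and the learned prior, something along these lines (the function name and batch averaging are my own notation):

import torch

def kl_criterion(mu1, logvar1, mu2, logvar2):
    # KL( N(mu1, sigma1^2) || N(mu2, sigma2^2) ), where (mu2, logvar2)
    # come from the learned prior; summed over dimensions, averaged over the batch
    sigma1 = logvar1.mul(0.5).exp()
    sigma2 = logvar2.mul(0.5).exp()
    kld = torch.log(sigma2 / sigma1) \
        + (torch.exp(logvar1) + (mu1 - mu2) ** 2) / (2 * torch.exp(logvar2)) \
        - 0.5
    return kld.sum() / mu1.size(0)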

Thanks.

-

Edit: Invalid comment, misread a variable name.

Unable to download complete BAIR robot push dataset

sh data/download_bair.sh data/BAIR
ends with the error: HTTP request sent, awaiting response... 416 Requested range not satisfiable

But it fails at different points in different trials: in the 1st trial only 4 MB was downloaded, in the 2nd trial 140 MB, and in the 3rd trial 12 MB before the error appeared. I got the same error every time.

Full console output:
$ sh data/download_bair.sh data/BAIR2
--2019-05-31 11:33:06--  http://rail.eecs.berkeley.edu/datasets/bair_robot_pushing_dataset_v0.tar
Resolving rail.eecs.berkeley.edu (rail.eecs.berkeley.edu)... 128.32.189.73
Connecting to rail.eecs.berkeley.edu (rail.eecs.berkeley.edu)|128.32.189.73|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 32274964480 (30G) [application/x-tar]
Saving to: ‘data/BAIR2/bair_robot_pushing_dataset_v0.tar’

bair_robot_pushing_   0%[                    ]   7.07M   150KB/s    in 97s

2019-05-31 11:34:44 (74.3 KB/s) - Connection closed at byte 7409747. Retrying.

--2019-05-31 11:34:45--  (try: 2)  http://rail.eecs.berkeley.edu/datasets/bair_robot_pushing_dataset_v0.tar
Connecting to rail.eecs.berkeley.edu (rail.eecs.berkeley.edu)|128.32.189.73|:80... connected.
HTTP request sent, awaiting response... 416 Requested range not satisfiable

The file is already fully retrieved; nothing to do.

softmotion30_44k/
softmotion30_44k/test/
softmotion30_44k/test/traj_0_to_255.tfrecords
tar: Unexpected EOF in archive
tar: rmtlseek not stopped at a record boundary
tar: Error is not recoverable: exiting now

Which version of torch to use?

Use PyTorch 0.3.1.
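A quick way to confirm which version is active in your environment (standard PyTorch, nothing repo-specific):

import torch
print(torch.__version__)   # expected: 0.3.1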

I didn't know how else to document this, so I'm creating this issue. For reference, below is the list of all packages I needed in order to run generate_svg_lp.py on the BAIR dataset on CPU:

name: svglp
channels:
  - conda-forge
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _tflow_select=2.3.0=mkl
  - absl-py=0.7.1=py36_0
  - astor=0.7.1=py36_0
  - atomicwrites=1.3.0=py36_1
  - attrs=19.1.0=py36_1
  - blas=1.0=mkl
  - bzip2=1.0.8=h7b6447c_0
  - c-ares=1.15.0=h7b6447c_1
  - ca-certificates=2019.5.15=1
  - cairo=1.14.12=h8948797_3
  - certifi=2019.6.16=py36_1
  - cffi=1.12.3=py36h2e261b9_0
  - cloudpickle=1.0.0=py_0
  - cudatoolkit=8.0=3
  - cudnn=7.1.3=cuda8.0_0
  - cupti=8.0.61=0
  - cycler=0.10.0=py36_0
  - cytoolz=0.9.0.1=py36h14c3975_1
  - dask-core=1.2.2=py_0
  - dbus=1.13.6=h746ee38_0
  - decorator=4.4.0=py36_1
  - expat=2.2.6=he6710b0_0
  - fontconfig=2.13.0=h9420a91_0
  - freeglut=3.0.0=hf484d3e_5
  - freetype=2.9.1=h8a8886c_1
  - gast=0.2.2=py36_0
  - glib=2.56.2=hd408876_0
  - graphite2=1.3.13=h23475e2_0
  - grpcio=1.16.1=py36hf8bcb03_1
  - gst-plugins-base=1.14.0=hbbd80ab_1
  - gstreamer=1.14.0=hb453b48_1
  - h5py=2.8.0=py36h989c5e5_3
  - harfbuzz=1.8.8=hffaf4a1_0
  - hdf5=1.10.2=hba1933b_1
  - icu=58.2=h9c2bf20_1
  - imageio=2.5.0=py36_0
  - intel-openmp=2019.3=199
  - jasper=2.0.14=h07fcdf6_1
  - joblib=0.13.2=py36_0
  - jpeg=9b=h024ee3a_2
  - keras-applications=1.0.7=py_0
  - keras-preprocessing=1.0.9=py_0
  - kiwisolver=1.1.0=py36he6710b0_0
  - libedit=3.1.20181209=hc058e9b_0
  - libffi=3.2.1=hd88cf55_4
  - libgcc-ng=8.2.0=hdf63c60_1
  - libgfortran-ng=7.3.0=hdf63c60_0
  - libglu=9.0.0=hf484d3e_1
  - libopus=1.3=h7b6447c_0
  - libpng=1.6.37=hbc83047_0
  - libprotobuf=3.7.1=hd408876_0
  - libstdcxx-ng=8.2.0=hdf63c60_1
  - libtiff=4.0.10=h2733197_2
  - libuuid=1.0.3=h1bed415_2
  - libvpx=1.7.0=h439df22_0
  - libxcb=1.13=h1bed415_1
  - libxml2=2.9.9=he19cac6_0
  - markdown=3.1=py36_0
  - matplotlib=3.1.0=py36h5429711_0
  - mkl=2018.0.3=1
  - mkl_fft=1.0.10=py36_0
  - mkl_random=1.0.2=py36_0
  - mock=3.0.5=py36_0
  - more-itertools=7.0.0=py36_0
  - nccl=1.3.4=cuda8.0_1
  - ncurses=6.1=he6710b0_1
  - networkx=2.3=py_0
  - ninja=1.9.0=py36hfd86e86_0
  - numpy=1.15.4=py36h1d66e8a_0
  - numpy-base=1.15.4=py36h81de0dd_0
  - olefile=0.46=py36_0
  - openssl=1.1.1c=h7b6447c_1
  - pcre=8.43=he6710b0_0
  - pillow=6.0.0=py36h34e0f95_0
  - pip=19.1.1=py36_0
  - pixman=0.38.0=h7b6447c_0
  - pluggy=0.11.0=py_0
  - progressbar2=3.37.1=py36_0
  - protobuf=3.7.1=py36he6710b0_0
  - py=1.8.0=py36_0
  - pycparser=2.19=py36_0
  - pyparsing=2.4.0=py_0
  - pyqt=5.9.2=py36h05f1152_2
  - pytest=4.5.0=py36_0
  - pytest-runner=4.4=py_0
  - python=3.6.8=h0371630_0
  - python-dateutil=2.8.0=py36_0
  - python-utils=2.3.0=py36_0
  - pytorch=0.3.1=py36hfbe7015_1
  - pytz=2019.1=py_0
  - pywavelets=1.0.3=py36hdd07704_1
  - qt=5.9.7=h5867ecd_1
  - readline=7.0=h7b6447c_5
  - scikit-image=0.15.0=py36he6710b0_0
  - scikit-learn=0.20.1=py36h4989274_0
  - scikit-video=1.1.11=pyh24bf2e0_0
  - scipy=1.1.0=py36hfa4b5c9_1
  - setuptools=41.0.1=py36_0
  - sip=4.19.8=py36hf484d3e_0
  - six=1.12.0=py36_0
  - sqlite=3.28.0=h7b6447c_0
  - tensorboard=1.13.1=py36hf484d3e_0
  - tensorflow=1.13.1=mkl_py36h27d456a_0
  - tensorflow-base=1.13.1=mkl_py36h7ce6ba3_0
  - tensorflow-estimator=1.13.0=py_0
  - termcolor=1.1.0=py36_1
  - tk=8.6.8=hbc83047_0
  - toolz=0.9.0=py36_0
  - torchvision=0.2.1=py36_1000
  - tornado=6.0.2=py36h7b6447c_0
  - wcwidth=0.1.7=py36_0
  - werkzeug=0.15.2=py_0
  - wheel=0.33.4=py36_0
  - xz=5.2.4=h14c3975_4
  - zlib=1.2.11=h7b6447c_3
  - zstd=1.3.7=h0b5b093_0
  - pip:
      - playsound==1.2.2
prefix: .../anaconda3/envs/svglp

Regarding when skip is computed

In line 311 of train_svg_lp.py, what is the reasoning behind setting the condition as:

if opt.last_frame_skip or i < opt.n_past:	
    h, skip = h

instead of (the only change being that < is swapped to <=):

if opt.last_frame_skip or i <= opt.n_past:	
    h, skip = h

Since h = encoder(x[i-1]), I believe the strict < will cause the skip features to be from the t = n_past - 1 frame instead of the t = n_past frame (where t is indexed from 1). Is this intended?
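To make the indexing concrete, here is a toy trace I used to convince myself (my own illustration, not the repository code), with n_past = 2 and frames indexed from 1 as above:

n_past = 2
last_skip_frame = None
for i in range(1, 5):
    # in the training loop: h = encoder(x[i-1]), i.e. features of frame t = i (1-indexed)
    if i < n_past:               # true only for i = 1
        last_skip_frame = i      # skip last refreshed from frame t = 1 = n_past - 1
print(last_skip_frame)           # prints 1; with "i <= n_past" it would be 2 = n_past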

SVG-LP has eval() commented out, SVG-FP does not? For image generation.

Hi Dr. Denton,

Thank you very much for releasing your code. I have had fun using SVG to generate images.

Just one quick question: I noticed that the FP code has this:

svg/train_svg_fp.py

Lines 319 to 326 in 3f19f0b

# plot some stuff
frame_predictor.eval()
encoder.eval()
decoder.eval()
posterior.eval()
x = next(testing_batch_generator)
plot(x, epoch)
plot_rec(x, epoch)

whereas LP has this:

svg/train_svg_lp.py

Lines 359 to 368 in 3f19f0b

# plot some stuff
frame_predictor.eval()
#encoder.eval()
#decoder.eval()
posterior.eval()
prior.eval()
x = next(testing_batch_generator)
plot(x, epoch)
plot_rec(x, epoch)

Given that the encoder and decoder have batch normalization layers and are used at inference time, I believe LP should also have their eval() calls enabled. Is my understanding correct?
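For what it's worth, this is what I would expect the SVG-LP plotting block to look like, i.e. mirroring the FP code by enabling the two commented-out lines (a sketch of the change, not a claim about the intended behaviour):

# plot some stuff
frame_predictor.eval()
encoder.eval()     # so BatchNorm uses running statistics at plot time
decoder.eval()
posterior.eval()
prior.eval()
x = next(testing_batch_generator)
plot(x, epoch)
plot_rec(x, epoch)

(presumably followed by switching the modules back to train(), since training continues afterwards).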

Difference between the `plot` and `plot_rec` functions

Hi, thank you for making your work available. I have a few questions regarding the plot functions.

  • What is the main difference between the plot and plot_rec functions?

    • plot_rec conditions on ground-truth frames throughout the whole sequence, while plot does so only during the initial context phase (prior to opt.n_past), after which it conditions on its own predictions. Is that correct? (See the sketch after this list.)
    • Is there any special reason for plot_rec generating frames in range(1, opt.n_past+opt.n_future) while plot uses range(1, opt.n_eval)?
  • The plot function in train_svg_fp.py does not have the following line:

    posterior.hidden = posterior.init_hidden() 

    while all other plotting functions (in both the svg_fp and svg_lp files) have it. Is this on purpose or a mistake?
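To make the first question concrete, this is my rough understanding of the two conditioning schemes in pseudo-code (names follow the training scripts, details simplified):

# plot_rec: condition on the ground-truth previous frame at every step
for i in range(1, opt.n_past + opt.n_future):
    h = encoder(x[i-1])                    # always a ground-truth frame
    # ... predict and store the reconstruction for step i ...

# plot: ground truth only during the context phase, own predictions afterwards
x_in = x[0]
for i in range(1, opt.n_eval):
    h = encoder(x_in)
    # ... predict x_pred for step i ...
    if i < opt.n_past:
        x_in = x[i]                        # keep conditioning on ground truth
    else:
        x_in = x_pred                      # condition on the model's own prediction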

Thanks again :)

How to make the code run on multiple GPUs

By default, the training code runs on only 1 GPU. How can I make it run on multiple GPUs?

I tried wrapping the frame_predictor, posterior, and prior modules with DataParallel. Then I got an error saying DataParallel has no attribute init_hidden, so I changed frame_predictor.init_hidden() to frame_predictor.module.init_hidden(). After that I got an error along the lines of "trying to do backward a second time; graph not saved."
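Roughly what I tried, for reference (simplified; frame_predictor, posterior, and prior are the modules created in the training script):

import torch.nn as nn

# wrap the recurrent modules
frame_predictor = nn.DataParallel(frame_predictor)
posterior = nn.DataParallel(posterior)
prior = nn.DataParallel(prior)

# DataParallel does not forward custom methods such as init_hidden(),
# so those calls have to go through .module, e.g.:
frame_predictor.module.init_hidden()   # instead of frame_predictor.init_hidden()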

Is it possible to make the code run on multiple GPUs?
