
write-rnn-tensorflow's Introduction

Generative Handwriting Demo using TensorFlow


An attempt to implement the random handwriting generation portion of Alex Graves' paper, Generating Sequences With Recurrent Neural Networks.

See my blog post at blog.otoro.net for more information.

How to use

I tested the implementation on TensorFlow r0.11 and Python 3. I also used the following helper libraries:

svgwrite
IPython.display.SVG
IPython.display.display
xml.etree.ElementTree
argparse
pickle
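For reference, the corresponding imports would look roughly like this (a sketch; the actual scripts may organize them differently):

import argparse
import pickle
import xml.etree.ElementTree as ET

import svgwrite
from IPython.display import SVG, display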

Training

You will need permission from these wonderful people to get the IAM On-Line Handwriting data. Unzip lineStrokes-all.tar.gz into the data subdirectory, so that you end up with data/lineStrokes/a01, data/lineStrokes/a02, etc. Afterwards, running python train.py will start the training process.

A number of flags can be set for training if you wish to experiment with the parameters; the default values are in train.py.

--rnn_size RNN_SIZE             size of RNN hidden state
--num_layers NUM_LAYERS         number of layers in the RNN
--model MODEL                   rnn, gru, or lstm
--batch_size BATCH_SIZE         minibatch size
--seq_length SEQ_LENGTH         RNN sequence length
--num_epochs NUM_EPOCHS         number of epochs
--save_every SAVE_EVERY         save frequency
--grad_clip GRAD_CLIP           clip gradients at this value
--learning_rate LEARNING_RATE   learning rate
--decay_rate DECAY_RATE         decay rate for rmsprop
--num_mixture NUM_MIXTURE       number of gaussian mixtures
--data_scale DATA_SCALE         factor to scale raw data down by
--keep_prob KEEP_PROB           dropout keep probability
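For example, to train a larger model than the default (the flag values here are just an illustration, matching a configuration reported in the issues below):

python train.py --rnn_size 400 --num_layers 3 --num_epochs 30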

Generating a Handwriting Sample

I've included a pretrained model in /save, so it should work out of the box. Running python sample.py --filename example_name --sample_length 1000 will generate four .svg files for the example, with 1000 points.

IPython interactive session

If you wish to experiment with this code interactively, just run %run -i sample.py in an IPython console; the following code shows how to generate samples and display them inside IPython.

[strokes, params] = model.sample(sess, 800)  # sample 800 points and their mixture params
draw_strokes(strokes, factor=8, svg_filename='sample.normal.svg')  # plain rendering
draw_strokes_random_color(strokes, factor=8, svg_filename='sample.color.svg')  # random color per stroke
draw_strokes_random_color(strokes, factor=8, per_stroke_mode=False, svg_filename='sample.multi_color.svg')  # color varies within strokes
draw_strokes_eos_weighted(strokes, params, factor=8, svg_filename='sample.eos.svg')  # weighted by end-of-stroke probability
draw_strokes_pdf(strokes, params, factor=8, svg_filename='sample.pdf.svg')  # overlays the predicted density

(example renderings: normal, per-stroke color, multi-color, eos-weighted, and pdf)

Have fun!

License

MIT

write-rnn-tensorflow's People

Contributors

bskaggs, dribnet, edwin-de-jong, grisaitis, hardmaru, memo, rajshah4, sygi, yaylinda, zhaoyu611


write-rnn-tensorflow's Issues

Generated samples look bad, but training seems ok

I'm getting crappy-looking samples from networks that seem to train properly. Loss converges nicely (see last screenshot), but samples don't compare with @hardmaru's results.

Here are some examples... Any ideas for what could be causing this?

(four screenshots of generated samples)

This is after training for 10920 iterations (30 * 364) with python train.py --rnn_size 400 --num_layers 3. Default args produce similar results.

Training and val loss look fine (loss curve screenshot).

Has anyone been able to train and produce great results with recent TensorFlow? I wonder if some defaults have changed, or if interface changes are resulting in a bad setup for the loss function.

I think the loss function is the issue because loss values "look good" but actual results look bad. I think the loss function is optimizing for the wrong thing, basically.

Any ideas appreciated. My next steps are to review how the loss is defined and maybe compare it with other implementations (https://greydanus.github.io/2016/08/21/handwriting/, https://github.com/snowkylin/rnn-handwriting-generation)
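For reference while comparing loss functions, here is a minimal NumPy sketch of the point loss from the Graves paper (the bivariate mixture density of Eqs. 24-26); it is written from the paper, not lifted from this repo's model.py, so the names are illustrative:

import numpy as np

def bivariate_gaussian(x1, x2, mu1, mu2, s1, s2, rho):
    # density of a correlated 2-d Gaussian (Graves 2013, Eqs. 24-25)
    z = ((x1 - mu1) / s1) ** 2 + ((x2 - mu2) / s2) ** 2 \
        - 2.0 * rho * (x1 - mu1) * (x2 - mu2) / (s1 * s2)
    denom = 2.0 * np.pi * s1 * s2 * np.sqrt(1.0 - rho ** 2)
    return np.exp(-z / (2.0 * (1.0 - rho ** 2))) / denom

def point_loss(x1, x2, eos, pi, mu1, mu2, s1, s2, rho, e):
    # negative log-likelihood of one (dx, dy, eos) point under the mixture,
    # plus the Bernoulli term for the end-of-stroke bit (Graves 2013, Eq. 26)
    mixture = np.sum(pi * bivariate_gaussian(x1, x2, mu1, mu2, s1, s2, rho))
    return -np.log(mixture) - np.log(e if eos == 1 else 1.0 - e)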

Differences from Paper

Hey! I hope you're having a good day. My friend and I actually implemented this as well for a final project. We were vastly impressed by some of your results and your style (your code is really clean). However, I do have a few questions. These are a little higher-level than normal GitHub questions, but I think they're still pertinent.

  1. In the Graves paper, he increases the dimensionality of the representation of the input through a series of stacked LSTMs with skip connections. So, for example, if we have m = 3 as the depth of the stack, and our input dimension is n = 3 (for x, y, eos), then we would have an ending dimension in $$\mathbb{R}^{18}$$. We can see what this model looks like in Figure 1 of the Graves paper.

However, it seems like in your code you do not do this? It doesn't appear that you actually create this cascade... but your results look fantastic. You simply have a line that says:

outputs, state_out = tf.contrib.legacy_seq2seq.rnn_decoder(inputs, self.state_in, cell, loop_function=None, scope='rnnlm')

Just to be clear, there is no dimensionality increase here? But $$\mathbb{R}^3$$ is still a large enough space for your inputs to be represented and for your model to be so robustly trained?

It seems like you specify your hidden size for the RNN as 256 as the default. How is this possible? Doesn't it need to be 3 so that it corresponds with the input being in $$\mathbb{R}^3$$? (See the shape sketch after this list.)

  2. Do you penalize your model for starting a new stroke? For example, it does not appear that you reset the internal LSTM parameters at any point if you're starting a new stroke (which Graves notes we want to do). Do you ignore this and still get such high-quality results?

  3. You said you trained over your ENTIRE training dataset (all 11,035 strokes) for only HALF of a day on a MacBook WITHOUT a GPU, and you were able to generate such clear and realistic handwriting? I just want to be clear, because either the higher complexity of our model (due to the stacked LSTMs) is causing training to be slower, or we're not sampling correctly from our model.
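On the hidden-size question above: the hidden state does not need to match the input dimension, because the LSTM's input kernel projects the 3-d input into the 256-d state. A minimal shape sketch (hypothetical values, but consistent with the (259, 1024) kernel shape reported in a later issue, since 259 = 3 + 256 and 1024 = 4 × 256):

import numpy as np

x_t = np.zeros(3)                      # one timestep of input: (dx, dy, eos)
h_prev = np.zeros(256)                 # previous hidden state
kernel = np.zeros((3 + 256, 4 * 256))  # maps [input; hidden] to the 4 LSTM gates
gates = np.concatenate([x_t, h_prev]) @ kernel  # shape (1024,)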

I'd love to hear some responses back! Thank you so much for the write-up. It is lovely, and your knowledge of the SVG package is super impressive. I'd never even heard of it before your write-up.

Thanks for everything!

NotFoundError, Tensor name [...] not found in checkpoint files save/model.ckpt-11000

First, thanks a lot for this very cool code!

Running the pretrained model with the suggested command: python sample.py --filename example_name --sample_length 1000

Produces this error:

WARNING:tensorflow:<tensorflow.python.ops.rnn_cell_impl.BasicLSTMCell object at 0x7f7a3c53e6a0>: Using a concatenated state is slower and will soon be deprecated. Use state_is_tuple=True.
WARNING:tensorflow:<tensorflow.python.ops.rnn_cell_impl.BasicLSTMCell object at 0x7f7a3b5250f0>: Using a concatenated state is slower and will soon be deprecated. Use state_is_tuple=True.
WARNING:tensorflow:From /home/javier/repos/write-rnn-tensorflow/model.py:137: calling reduce_max (from tensorflow.python.ops.math_ops) with keep_dims is deprecated and will be removed in a future version.
Instructions for updating:
keep_dims is deprecated, use keepdims instead
WARNING:tensorflow:From /home/javier/repos/write-rnn-tensorflow/model.py:141: calling reduce_sum (from tensorflow.python.ops.math_ops) with keep_dims is deprecated and will be removed in a future version.
Instructions for updating:
keep_dims is deprecated, use keepdims instead
2018-03-15 03:59:51.217146: I tensorflow/core/platform/cpu_feature_guard.cc:137] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.1 SSE4.2
loading model: save/model.ckpt-11000
Traceback (most recent call last):
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1350, in _do_call
return fn(*args)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1329, in _run_fn
status, run_metadata)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/errors_impl.py", line 473, in exit
c_api.TF_GetCode(self.status.status))
tensorflow.python.framework.errors_impl.NotFoundError: Tensor name "rnnlm/multi_rnn_cell/cell_0/basic_lstm_cell/bias/Adam" not found in checkpoint files save/model.ckpt-11000
[[Node: save/RestoreV2_4 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2_4/tensor_names, save/RestoreV2_4/shape_and_slices)]]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "sample.py", line 42, in
saver.restore(sess, ckpt.model_checkpoint_path)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saver.py", line 1686, in restore
{self.saver_def.filename_tensor_name: save_path})
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 895, in run
run_metadata_ptr)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1128, in _run
feed_dict_tensor, options, run_metadata)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1344, in _do_run
options, run_metadata)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1363, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.NotFoundError: Tensor name "rnnlm/multi_rnn_cell/cell_0/basic_lstm_cell/bias/Adam" not found in checkpoint files save/model.ckpt-11000
[[Node: save/RestoreV2_4 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2_4/tensor_names, save/RestoreV2_4/shape_and_slices)]]

Caused by op 'save/RestoreV2_4', defined at:
File "sample.py", line 37, in
saver = tf.train.Saver()
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saver.py", line 1239, in init
self.build()
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saver.py", line 1248, in build
self._build(self._filename, build_save=True, build_restore=True)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saver.py", line 1284, in _build
build_save=build_save, build_restore=build_restore)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saver.py", line 765, in _build_internal
restore_sequentially, reshape)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saver.py", line 428, in _AddRestoreOps
tensors = self.restore_op(filename_tensor, saveable, preferred_shard)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saver.py", line 268, in restore_op
[spec.tensor.dtype])[0])
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/gen_io_ops.py", line 1031, in restore_v2
shape_and_slices=shape_and_slices, dtypes=dtypes, name=name)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
op_def=op_def)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 3160, in create_op
op_def=op_def)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 1625, in init
self._traceback = self._graph._extract_stack() # pylint: disable=protected-access

NotFoundError (see above for traceback): Tensor name "rnnlm/multi_rnn_cell/cell_0/basic_lstm_cell/bias/Adam" not found in checkpoint files save/model.ckpt-11000
[[Node: save/RestoreV2_4 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2_4/tensor_names, save/RestoreV2_4/shape_and_slices)]]

Please advise! Thank you in advance.
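One plausible cause, judging from the "/Adam" suffix in the missing tensor name: the graph built by sample.py declares Adam optimizer slot variables that the pretrained checkpoint does not contain. A hedged workaround (a sketch, not a tested fix) is to restore only the non-optimizer variables:

# hypothetical patch for sample.py: skip Adam slot variables when restoring
var_list = [v for v in tf.global_variables()
            if '/Adam' not in v.name and 'beta1_power' not in v.name
            and 'beta2_power' not in v.name]
saver = tf.train.Saver(var_list)
saver.restore(sess, ckpt.model_checkpoint_path)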

Which version of TF do I need?

I installed TensorFlow r0.11, but many errors appear while running; most of them are attribute errors. When I change the attribute names to the old ones, it runs well, but I have to fix them one by one.

Do I have to update to version 1.0 to address the problem? Thanks in advance.

Unsupported pickle protocol: 3

Hey,

I'm new to Python as well as machine learning and couldn't find a way to solve this problem. When I run sample.py as described in the Readme, the terminal outputs this error:

Traceback (most recent call last):
File "sample.py", line 36, in , saved_args = pickle.load(f)
...
ValueError: unsupported pickle protocol: 3

How can I solve this?
Thank you very much for your work and help!
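For context: pickle protocol 3 is Python 3-only, so this error usually means the pickle file was written under Python 3 but read under Python 2. Either run the code under Python 3 (which the Readme says it was tested on), or re-save the file with a Python 2-compatible protocol, e.g. (a sketch, run under Python 3; the filename is illustrative):

import pickle

with open('save/config.pkl', 'rb') as f:   # file written by Python 3
    data = pickle.load(f)
with open('save/config.pkl', 'wb') as f:   # rewrite it readable by Python 2
    pickle.dump(data, f, protocol=2)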

Better explanation of how to run

Following the Readme, I get the error "No such file or directory: './data/strokes_training_data.cpkl'". How can I generate this training data?
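(Likely cause, per the Training section above: train.py builds strokes_training_data.cpkl from the IAM lineStrokes data, so that archive must first be unzipped into the data subdirectory, e.g.:)

tar -xzf lineStrokes-all.tar.gz -C data
python train.py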

Can the log-likelihood cost go negative?

Hi,

I have modified the code to handle another type of data where the output Gaussians are univariate. However, while training the model I noticed that the log-likelihood cost goes negative and keeps decreasing. Is this negative likelihood sensible? While training your model too, I sometimes notice negative cost values reported for one or two batches. Or does it mean that I have potentially introduced some bugs while modifying the code?

Thanks,
Moustafa
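(For what it's worth: with continuous outputs, a negative log-likelihood cost can legitimately go negative, because a probability density can exceed 1. A quick check with a narrow univariate Gaussian:)

import numpy as np

sigma = 0.1
peak_pdf = 1.0 / (sigma * np.sqrt(2.0 * np.pi))  # density at the mean ≈ 3.99
print(-np.log(peak_pdf))                         # ≈ -1.38, a valid negative NLL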

Custom Dataset, train.py not creating 'model.ckpt-0'

Caveat: I'm very much a noob with all things machine learning, as well as Python.

I've got the sample data working just fine, e.g. clone the repo and run python sample.py --filename example_name --sample_length 1000.

Now I'm trying to use my own SVG dataset (placed within ./data), and when I run python train.py, after about 2–3 minutes it generates a strokes_training_data.cpkl in the ./data folder and a config.pkl in the ./save folder.

I've run other examples of neural network training, and typically it takes much longer to train a dataset; there is also a series of model checkpoint (?) files, e.g. model.ckpt-0, generated per epoch (?). Is there a reason why I'm not getting the same results?

Unable to experiment with sample.py

When I run %run -i sample.py in an IPython console, or just python sample.py, I get the following error:

In [1]: %run -i sample.py
WARNING:tensorflow:<tensorflow.python.ops.rnn_cell_impl.BasicLSTMCell object at 0x7f093d719390>: Using a concatenated state is slower and will soon be deprecated. Use state_is_tuple=True.
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
~/Downloads/write-rnn-tensorflow/sample.py in <module>()
34 saved_args = pickle.load(f)
35
---> 36 model = Model(saved_args, True)
37 sess = tf.InteractiveSession()
38 #saver = tf.train.Saver(tf.all_variables())

~/Downloads/write-rnn-tensorflow/model.py in __init__(self, args, infer)
46 # inputs = tf.split(axis=1, num_or_size_splits=args.seq_length, value=self.input_data)
47 # inputs = [tf.squeeze(input_, [1]) for input_ in inputs]
---> 48 inputs = tf.unpack(self.input_data, axis=1)
49
50 outputs, state_out = tf.contrib.legacy_seq2seq.rnn_decoder(inputs, self.state_in, cell, loop_function=None, scope='rnnlm')

AttributeError: module 'tensorflow' has no attribute 'unpack'

Does this have something to do with the latest version of TensorFlow? I am using v1.3.0, if that helps.
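For context: TensorFlow 1.0 renamed tf.unpack to tf.unstack (and tf.pack to tf.stack), so a likely one-line fix in model.py is:

inputs = tf.unstack(self.input_data, axis=1)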

ValueError: Trying to share variable rnnlm/multi_rnn_cell/cell_0/basic_lstm_cell/kernel, but specified shape (512, 1024) and found shape (259, 1024).

When I run train.py, I get these errors. I think it is some kind of version problem; my tf version is 1.3.

File "/home/lxt/tf_project/HyperNetwork/write-rnn-tensorflow/model.py", line 50, in init
outputs, state_out = tf.contrib.legacy_seq2seq.rnn_decoder(inputs, self.state_in, cell, loop_function=None, scope='rnnlm')
File "/home/lxt/anaconda2/lib/python2.7/site-packages/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py", line 152, in rnn_decoder
output, state = cell(inp, state)

EOS Data

Hello,

Just wanted to clarify something. The data is arranged as x, y, eos, but you have the following line of code:
z_eos = z[:, 0:1]

This would say grab the first column. However, shouldn't it be something like:
z_eos = z[:,2]

Do you somehow rearrange the data?
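A likely explanation, based on the Graves paper rather than this repo's data loader: z here is the network's output parameter vector, not the input stroke array. In the paper the output is ordered with the end-of-stroke probability first, followed by the mixture parameters, so the first column of z is indeed eos. A slicing sketch (hypothetical shapes):

# z has width 1 + 6*M for M mixtures: [eos, pi, mu1, mu2, sigma1, sigma2, rho]
z_eos = z[:, 0:1]                                  # end-of-stroke term
z_pi, z_mu1, z_mu2, z_s1, z_s2, z_rho = np.split(z[:, 1:], 6, axis=1)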

Training error

When I tried to run the train.py file, I got the following error:

"IOError: [Errno 2] No such file or directory: './data/strokes_training_data.cpkl'

Please provide some suggestions.

Got NotImplementedError: Negative indices are currently unsupported

write-rnn-tensorflow/model.py", line 50, in __init__
    self.final_state = states[-1]
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/array_ops.py", line 124, in _SliceHelper
    raise NotImplementedError("Negative indices are currently unsupported")
NotImplementedError: Negative indices are currently unsupported

It seems the TensorFlow library has changed!
