char-rnn-pytorch's Introduction

Char-RNN-PyTorch

Text generation with a character-level RNN, implemented in PyTorch. A Gluon implementation also exists.

Requirements

PyTorch 0.3

MxTorch

tensorboardX

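Install PyTorch by following the instructions on the official PyTorch website, download mxtorch and place it in the repository root, and install tensorboardX for TensorBoard visualization.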

\Char-RNN-PyTorch
	\mxtorch
	\data
	\dataset
	\models
	config.py
	main.py

Training the model

All configuration lives in config.py. Train the model with the following command

python main.py train
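
As a rough illustration of keeping every setting in one place, a config file can hold defaults as class attributes and accept `--key=value` overrides. This is a minimal sketch, not the repository's code: `Config`, `update`, and `parse_overrides` are hypothetical names, and the default values are copied from this README's example command rather than from the actual config.py.

```python
import ast


class Config:
    """Hypothetical stand-in for config.py: defaults live as class attributes."""
    txt = './dataset/poetry.txt'
    batch = 128
    max_epoch = 300
    len = 30
    max_vocab = 5000
    embed_dim = 512
    hidden_size = 512
    num_layers = 2
    dropout = 0.5

    def update(self, **overrides):
        """Apply overrides, rejecting option names that have no default."""
        for key, value in overrides.items():
            if not hasattr(self, key):
                raise AttributeError(f"unknown option: {key}")
            setattr(self, key, value)
        return self


def parse_overrides(argv):
    """Turn ['--batch=64', '--dropout=0.3'] into {'batch': 64, 'dropout': 0.3}."""
    overrides = {}
    for arg in argv:
        key, _, raw = arg.lstrip('-').partition('=')
        try:
            value = ast.literal_eval(raw)  # numbers, booleans, quoted strings
        except (ValueError, SyntaxError):
            value = raw  # keep plain strings such as file paths
        overrides[key] = value
    return overrides


cfg = Config().update(**parse_overrides(['--batch=64', '--txt=./dataset/lyrics.txt']))
print(cfg.batch, cfg.txt, cfg.len)
```

Rejecting unknown keys in `update` catches misspelled option names early instead of silently training with the defaults.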

You can also override the configuration from the terminal

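# --txt          text file used for training
# --batch        batch size
# --len          length of the sequences fed to the RNN
# --max_vocab    maximum number of distinct characters
# --embed_dim    embedding dimension
# --hidden_size  output dimension of the network
# --num_layers   number of RNN layers
python main.py train \
	--txt='./dataset/poetry.txt' \
	--batch=128 \
	--max_epoch=300 \
	--len=30 \
	--max_vocab=5000 \
	--embed_dim=512 \
	--hidden_size=512 \
	--num_layers=2 \
	--dropout=0.5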
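
To make the roles of `max_vocab` and `len` more concrete, the sketch below shows one common way a character-level pipeline turns raw text into training pairs. It is an assumption about the general technique, not the repository's preprocessing: the helper names, the choice to reserve index 0 for characters outside the vocabulary, and the non-overlapping windows are all hypothetical.

```python
from collections import Counter


def build_vocab(text, max_vocab):
    """Map the most frequent characters to ids 1..max_vocab-1; id 0 means 'unknown'."""
    most_common = [ch for ch, _ in Counter(text).most_common(max_vocab - 1)]
    return {ch: i + 1 for i, ch in enumerate(most_common)}


def encode(text, char_to_idx):
    """Replace each character by its id, falling back to 0 for rare characters."""
    return [char_to_idx.get(ch, 0) for ch in text]


def make_sequences(ids, seq_len):
    """Cut ids into non-overlapping inputs of seq_len; each target is the input shifted by one."""
    pairs = []
    for start in range(0, len(ids) - seq_len, seq_len):
        x = ids[start:start + seq_len]
        y = ids[start + 1:start + seq_len + 1]
        pairs.append((x, y))
    return pairs


text = "aaaabbbccd"
vocab = build_vocab(text, max_vocab=3)
pairs = make_sequences(encode(text, vocab), seq_len=3)
print(vocab)
print(pairs)
```

Each target sequence is the input shifted by one position, so at every step the network is trained to predict the next character.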

To generate text with a trained network, use the following command

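# --begin        start of the generated text; a single character or a longer passage
# --predict_len  length of the text to generate
# --load_model   path of the trained model to load
python main.py predict \
	--begin='天青色等烟雨' \
	--predict_len=100 \
	--load_model='./checkpoints/CharRNN_best_model.pth'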
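
The generation step can be pictured as a loop: feed the `begin` text through the network to warm up its hidden state, then repeatedly sample the next character and feed it back in. The sketch below is framework-free and makes several assumptions that are not taken from the repository: `step` is a hypothetical callable standing in for the trained model (it maps a character id and a hidden state to next-character probabilities and a new hidden state), `predict_len` counts only newly generated characters, and sampling is restricted to the `top_k` most likely ids.

```python
import random


def generate(step, begin_ids, predict_len, top_k=5, seed=None):
    """Warm up on begin_ids, then append predict_len sampled character ids."""
    if not begin_ids:
        raise ValueError("begin_ids must contain at least one character id")
    rng = random.Random(seed)
    hidden, probs = None, None
    for idx in begin_ids:  # warm-up: only the final prediction is used
        probs, hidden = step(idx, hidden)
    out = list(begin_ids)
    for _ in range(predict_len):
        # keep the top_k most likely ids and sample among them by probability
        top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_k]
        nxt = rng.choices(top, weights=[probs[i] for i in top], k=1)[0]
        out.append(nxt)
        probs, hidden = step(nxt, hidden)
    return out


def toy_step(idx, hidden):
    """Toy stand-in for a trained model: always predicts (idx + 1) % 3."""
    probs = [0.0, 0.0, 0.0]
    probs[(idx + 1) % 3] = 1.0
    return probs, hidden


print(generate(toy_step, [0, 1], predict_len=4, top_k=1))
```

Sampling among the top few candidates instead of always taking the single most likely character is a common way to avoid repetitive output, at the cost of occasional odd choices.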

Result

Training on the classical poetry dataset gives results like the following

天青色等烟雨翩 黄望堪魄弦夜 逐奏文明际天月辉 豪天明月天趣 天外何山重满 遥天明上天  心空游无拂天外空寂室叨

Training on Jay Chou's lyrics gives results like the following

这感觉得可能 我这玻童来 城堡药比生对这些年风天 脚剧飘逐在尘里里步的路 麦缘日下一经经 听觉得远回白择

char-rnn-pytorch's People

Contributors

l1aoxingyu

char-rnn-pytorch's Issues

Error in the GRU unit

Execution fails at score, _ = model(x) and I do not know why: TypeError: gru() received an invalid combination of arguments - got (Tensor, Tensor, list, float, int, float, bool, bool, bool), but expected one of:

  • (Tensor data, Tensor batch_sizes, Tensor hx, tuple of Tensors params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional)
    didn't match because some of the arguments have invalid types: (Tensor, Tensor, list, float, int, float, bool, bool, bool)
  • (Tensor input, Tensor hx, tuple of Tensors params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first)
    didn't match because some of the arguments have invalid types: (Tensor, Tensor, list, float, int, float, bool, bool, bool)

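Model quality

Hello, I ran the code you provided. I trained the model on 17 MB of Sogou news data for 1000 iterations, but the generated text is not very good. I am a beginner with neural networks. Is there anything I should pay attention to regarding the training data? Thanks!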
