
prakhar21 / textaugmentation-gpt2


Fine-tuned pre-trained GPT2 for custom topic-specific text generation. Such a system can be used for Text Augmentation.

License: MIT License

Python 100.00%
gpt-2 nlp-machine-learning transformer-architecture text-augmentation natural-language-processing natural-language-generation textclassification

textaugmentation-gpt2's Introduction

TextAugmentation-GPT2

Fine-tuned pre-trained GPT2 for topic-specific text generation. Such a system can be used for Text Augmentation.

Getting Started

  1. git clone https://github.com/prakhar21/TextAugmentation-GPT2.git
  2. Move your data to data/ dir.

* Please refer to data/SMSSpamCollection for the expected file format.
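
As a rough illustration, the snippet below checks that a file follows the layout of the public SMS Spam Collection dataset: one example per line, with the label and the text separated by a tab. That layout is an assumption based on the public dataset, and `read_labeled_lines` is a hypothetical helper, not part of this repository.

```python
import csv
import io


def read_labeled_lines(handle):
    """Yield (label, text) pairs from a tab-separated file-like object.

    Assumes one example per line in the form <label><TAB><text>.
    """
    # QUOTE_NONE keeps quote characters inside messages from being interpreted.
    reader = csv.reader(handle, delimiter="\t", quoting=csv.QUOTE_NONE)
    for line_no, row in enumerate(reader, start=1):
        if not row:
            continue  # skip blank lines
        if len(row) != 2:
            raise ValueError(
                f"line {line_no}: expected 2 tab-separated fields, got {len(row)}"
            )
        label, text = row
        yield label.strip(), text.strip()


# In-memory stand-in for a file such as data/SMSSpamCollection.
sample = io.StringIO("ham\tDo you remember me?\nspam\tYou have 2 new messages.\n")
pairs = list(read_labeled_lines(sample))
print(pairs)
```

For a real file, pass an open handle instead, for example `open(path, encoding="utf-8", newline="")`.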

Tuning for own Corpus

  1. Make sure you have completed step 2 under Getting Started.
  2. Run python3 train.py --data_file <filename> --epoch <number_of_epochs> --warmup <warmup_steps> --model_name <model_name> --max_len <max_seq_length> --learning_rate <learning_rate> --batch <batch_size>

Generating Text

1. python3 generate.py --model_name <model_name> --sentences <number_of_sentences> --label <class_of_training_data>

* It is recommended that you tune the parameters for your task. Otherwise the default parameters are used, which may give sub-optimal performance.

Quick Testing

I have fine-tuned the model on the SPAM/HAM dataset. You can download it from here and follow the steps under the Generating Text section.

Sample Results

SPAM: You have 2 new messages. Please call 08719121161 now. £3.50. Limited time offer. Call 090516284580.<|endoftext|>
SPAM: Want to buy a car or just a drink? This week only 800p/text betta...<|endoftext|>
SPAM: FREE Call Todays top players, the No1 players and their opponents and get their opinions on www.todaysplay.co.uk Todays Top Club players are in the draw for a chance to be awarded the £1000 prize. TodaysClub.com<|endoftext|>
SPAM: you have been awarded a £2000 cash prize. call 090663644177 or call 090530663647<|endoftext|>

HAM: Do you remember me?<|endoftext|>
HAM: I don't think so. You got anything else?<|endoftext|>
HAM: Ugh I don't want to go to school.. Cuz I can't go to exam..<|endoftext|>
HAM: K.,k:)where is my laptop?<|endoftext|>
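
Since each generated line has the form `LABEL: text<|endoftext|>`, turning the output into labeled examples for augmentation is mostly string handling. The sketch below shows one possible way to do it. `parse_generated` is a hypothetical helper, and the assumption that every line starts with the label followed by a colon comes only from the samples shown above.

```python
EOS = "<|endoftext|>"


def parse_generated(line):
    """Split 'LABEL: text<|endoftext|>' into (label, text); return None if malformed."""
    line = line.strip()
    if line.endswith(EOS):
        line = line[: -len(EOS)]  # drop the end-of-text marker
    # Split on the first colon only, so colons inside the message are preserved.
    label, sep, text = line.partition(":")
    if not sep or not text.strip():
        return None
    return label.strip().lower(), text.strip()


generated = [
    "SPAM: Want to buy a car or just a drink? This week only 800p/text betta...<|endoftext|>",
    "HAM: Do you remember me?<|endoftext|>",
]
examples = [pair for pair in map(parse_generated, generated) if pair is not None]
print(examples)
```

Lines that do not match the pattern are dropped rather than guessed at, which keeps malformed generations out of the augmented training set.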

Important Points to Note

  • Combined Top-k and Top-p sampling (a variant of Nucleus Sampling) is used while decoding the sequence token by token. You can read more about it here
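
To make the decoding step concrete, here is a minimal pure-Python sketch of combined top-k and top-p filtering followed by sampling. It is not the repository's `choose_from_top_k_top_n` function. The function names, the defaults `k=50` and `p=0.9`, and the toy logits are all assumptions chosen for illustration.

```python
import math
import random


def softmax(logits):
    """Convert raw scores to probabilities (numerically stable)."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_top_p_filter(probs, k=50, p=0.9):
    """Return {token_id: prob} for tokens surviving both filters, renormalized."""
    # Rank token ids from most to least likely and keep at most k of them.
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    kept, cumulative = [], 0.0
    for token_id in ranked:
        kept.append(token_id)
        cumulative += probs[token_id]
        if cumulative >= p:  # smallest prefix whose mass reaches p
            break
    mass = sum(probs[i] for i in kept)
    return {i: probs[i] / mass for i in kept}


def sample_next_token(logits, k=50, p=0.9, rng=random):
    """Draw one token id from the filtered, renormalized distribution."""
    candidates = top_k_top_p_filter(softmax(logits), k=k, p=p)
    ids = list(candidates)
    weights = [candidates[i] for i in ids]
    return rng.choices(ids, weights=weights, k=1)[0]


# Toy vocabulary of five tokens; token 0 is the most likely.
logits = [4.0, 3.0, 1.0, 0.5, -2.0]
filtered = top_k_top_p_filter(softmax(logits), k=3, p=0.9)
print(filtered)
print(sample_next_token(logits, k=3, p=0.9, rng=random.Random(0)))
```

Lower values of `k` or `p` make the output more conservative, while higher values allow rarer tokens and more varied text.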

Note: The first run will take a considerable amount of time, because it:

  1. Downloads the pre-trained gpt2-medium model (depends on your network speed)
  2. Fine-tunes GPT-2 on your dataset (depends on the size of the data, number of epochs, hyperparameters, etc.)

All the experiments were done on IntelDevCloud Machines

textaugmentation-gpt2's People

Contributors

prakhar21, sayantanbasu05


textaugmentation-gpt2's Issues

Facing issue in generate.py


RuntimeError Traceback (most recent call last)
in
93 TOKENIZER, MODEL = load_models(MODEL_NAME)
94
---> 95 generate(TOKENIZER, MODEL, SENTENCES, LABEL, DEVICE)

in generate(tokenizer, model, sentences, label, device)
47
48 next_token_id = choose_from_top_k_top_n(softmax_logits.to('cpu').numpy()) #top-k-top-n sampling
---> 49 cur_ids = torch.cat([cur_ids, torch.ones((1,1)).long().to(device) * next_token_id], dim = 1)
50
51 if next_token_id in tokenizer.encode('<|endoftext|>'):

RuntimeError: All input tensors must be on the same device. Received cpu and cuda:0

Removed

Apologies. Wrong post and I've closed the issue.

I am facing issues in train.py and also in generate.py

Could you please reply so that I can discuss the problems I am facing?

train.py --> model = model.to(device) --> Error: name 'model' is not defined
generate.py --> for the pre-trained model I am getting issues with mismatches and checkpoints

Regards

n is not being used

I am afraid that the value n in 'generate.py' (lines 44 and 46) is not being used. Could you help? Thanks.

Dataset

Could the author please send me the dataset? Sorry for the trouble.

Multi label

Thank you for sharing. I've been studying your code.
Now I want to use it for multi-label data, but I found that your code and examples use a single label.
If I want to use it for multi-label data, how should I modify it?
