decoder-from-scratch

Implementing a decoder-only model like GPT2 from scratch using PyTorch and tiktoken.

Overview

Code to build a decoder-only transformer model akin to OpenAI's GPT2 using just PyTorch. My criterion for success was to be able to train the model on actual text and have it generate a set of phrases given an input sequence.

I used the Jules Verne novel 'The Mysterious Island' for my training corpus.

Tokenization

To focus solely on programming the transformer architecture, I opted to use a pre-trained tokenizer (tiktoken), an idea I stole from Andrej Karpathy.
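For example, tiktoken's pre-trained GPT-2 BPE tokenizer can be loaded in a couple of lines (assuming the "gpt2" encoding is the one used here):

```python
import tiktoken

# Load the pre-trained GPT-2 BPE tokenizer (assumed encoding name)
enc = tiktoken.get_encoding("gpt2")

token_ids = enc.encode("The Mysterious Island")  # text -> list of integer token ids
text = enc.decode(token_ids)                     # token ids -> original text
```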

Sequence structuring

Since I'm training on just one piece of text, I decided to break it up into sliding-window sequences of 40 tokens. Because I discarded the last incomplete sequence, I didn't actually have any stop tokens.
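A minimal sketch of that sliding-window split (the stride of one token and the helper name are my assumptions, not necessarily what the repo does):

```python
import torch

def make_sliding_windows(token_ids, seq_len=40, stride=1):
    """Split one long token sequence into overlapping (input, target) windows.

    Targets are the inputs shifted by one token (next-token prediction).
    Any window that would run past the end of the text is discarded.
    """
    inputs, targets = [], []
    for start in range(0, len(token_ids) - seq_len, stride):
        inputs.append(token_ids[start : start + seq_len])
        targets.append(token_ids[start + 1 : start + seq_len + 1])
    return torch.tensor(inputs), torch.tensor(targets)
```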

As for the EOS (end-of-sequence) token: when someone trains a model in a commercial setting, they will of course have more than one document, and EOS tokens would be relevant. In my case, however, there is just one large sequence, so ignoring them is fine.

The implication is that, at inference time, we do not expect the model to generate a token signaling where the sequence should end.

Model parameters

I used a batch size of 50, a feed-forward network dimension of 20, and 10 decoder blocks. I started with a small d_model (aka hidden size, aka embedding size) of around 10, but the loss was not improving across training epochs. Considering that GPT2 was trained with a d_model of 768, I increased it to 500 and used 10 attention heads.
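Collected into a single config, those hyperparameters look roughly like this (the dictionary keys and the vocab size of tiktoken's GPT-2 encoding are my assumptions, not the repo's actual variable names):

```python
config = {
    "vocab_size": 50257,  # GPT-2 BPE vocabulary size (assuming the tiktoken "gpt2" encoding)
    "seq_len": 40,        # sliding-window length
    "batch_size": 50,
    "d_model": 500,       # embedding / hidden size (raised from ~10)
    "n_heads": 10,        # attention heads; d_model must divide evenly (500 / 10 = 50 per head)
    "n_layers": 10,       # decoder blocks
    "d_ff": 20,           # feed-forward network dimension
}
```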

Infrastructure used

I attempted to train locally on my CPU, but it was taking a very long time (~2 hours / epoch). I switched to a Colab Pro A100 GPU, but quickly exceeded the allocated quotas, and Google blocked access to GPUs (except for what seemed to be 20 minutes of use per day). So I finally switched to an NVIDIA Tesla T4 on AzureML, which cost me about $4.80 to train for 10 epochs.

Inference

I used an autoregressive approach for inference. The decoder architecture is flexible with respect to the number of tokens in the input sequence, so we just specify how many tokens we want to predict, append the predicted next token to the input sequence, and predict again.
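A minimal sketch of that loop with greedy decoding; the model interface (logits of shape batch × seq_len × vocab_size) and the function names are assumptions:

```python
import torch

@torch.no_grad()
def generate(model, enc, prompt, max_new_tokens=20):
    """Autoregressively extend `prompt` by `max_new_tokens` tokens using greedy decoding."""
    model.eval()
    idx = torch.tensor([enc.encode(prompt)])                      # (1, prompt_len)
    for _ in range(max_new_tokens):
        logits = model(idx)                                       # (1, seq_len, vocab_size)
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)   # most likely next token
        idx = torch.cat([idx, next_id], dim=1)                    # append and feed back in
    return enc.decode(idx[0].tolist())
```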

The results on some starter sequences with greedy search were repetitive, so I also implemented a multinomial sampling strategy. This improved the output, but unfortunately it is still incoherent.
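Swapping greedy argmax for multinomial sampling only changes the token-selection step; a sketch (the temperature parameter is my addition):

```python
import torch

def sample_next(logits, temperature=1.0):
    """Sample the next token id from the softmax distribution instead of taking the argmax."""
    probs = torch.softmax(logits[:, -1, :] / temperature, dim=-1)  # (1, vocab_size)
    return torch.multinomial(probs, num_samples=1)                 # (1, 1) sampled token id
```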

Given that the loss kept decreasing all the way up to epoch 10, my conclusion is that training the model further would likely improve the predictions.
