sunyu0410 / vanilla-transformer

This project is forked from arxyzan/vanilla-transformer.

A clean PyTorch implementation of the original Transformer model, plus a German -> English translation example.

Python 100.00%

Vanilla Transformer (PyTorch)

My PyTorch implementation of the original Transformer model from the paper "Attention Is All You Need", inspired by all the code and blog posts I've read on this topic. There's nothing really special going on here except that I tried to keep it as bare-bones as possible. There is also training code for a simple German -> English translator, written in pure PyTorch using the Torchtext library.
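The central operation of the model from the paper is scaled dot-product attention, softmax(QK^T / sqrt(d_k)) V, combined in the decoder with a causal mask so that each position can only attend to earlier positions. The following is a minimal, standalone PyTorch sketch of that operation, not the code from this repository. The function names `scaled_dot_product_attention` and `causal_mask` are hypothetical, and the mask convention (True = may attend) is an assumption.

```python
import math

import torch


def scaled_dot_product_attention(q, k, v, mask=None):
    """Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V.

    q: (..., len_q, d_k), k: (..., len_k, d_k), v: (..., len_k, d_v)
    mask: optional bool tensor broadcastable to (..., len_q, len_k);
          True marks positions that may be attended to (assumed convention).
    Returns the attended values and the attention weights.
    """
    d_k = q.size(-1)
    # Similarity of every query with every key, scaled by sqrt(d_k).
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)
    if mask is not None:
        # Blocked positions get -inf, so softmax gives them exactly zero weight.
        scores = scores.masked_fill(~mask, float("-inf"))
    weights = torch.softmax(scores, dim=-1)
    return weights @ v, weights


def causal_mask(size):
    """Lower-triangular mask: position i may attend only to positions <= i."""
    return torch.tril(torch.ones(size, size, dtype=torch.bool))


if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(2, 5, 8)  # (batch, seq_len, d_k); self-attention, so q = k = v
    out, w = scaled_dot_product_attention(x, x, x, causal_mask(5))
    print(out.shape)  # torch.Size([2, 5, 8])
```

Recent PyTorch releases (2.0+) also provide a fused `torch.nn.functional.scaled_dot_product_attention`; the hand-written version above only serves to make the formula explicit.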

My Inspirations

And probably a couple more which I don't remember ...

Prerequisites

  1. Install the required pip packages:
     pip install -r requirements.txt
  2. Install the spaCy models:
     python -m spacy download de_core_news_sm
     python -m spacy download en_core_web_sm
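After the two steps above, one quick way to confirm that everything is importable is to look each package up with the standard library's `importlib.util.find_spec`. This is a small hypothetical helper, not part of the repository; the spaCy model names come from the download commands, while the assumption that `requirements.txt` lists torch, torchtext and spacy is unverified.

```python
import importlib.util

# The two spaCy model names come from the download commands above; "torch",
# "torchtext" and "spacy" are ASSUMED to be the main entries in requirements.txt.
REQUIRED = ["torch", "torchtext", "spacy", "de_core_news_sm", "en_core_web_sm"]


def check_installed(names=REQUIRED):
    """Map each top-level package name to True if it can be imported, else False."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    for name, ok in check_installed().items():
        print(f"{name:20s} {'ok' if ok else 'MISSING'}")
```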

Note: This code uses Torchtext's new API (v0.10.0+). dataset.py contains a custom text dataset class that inherits from torch.utils.data.Dataset, which differs from the classic approach using Field and BucketIterator (now moved to torchtext.legacy). Since the torchtext library is still under heavy development, this code will probably break with upcoming versions.
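For readers used to Field and BucketIterator, here is a rough sketch of what a map-style dataset built on torch.utils.data.Dataset, plus a collate function that pads each batch, can look like. It is not the contents of dataset.py: `TranslationDataset`, `collate_fn` and the special-token ids are hypothetical, and tokenization and vocabulary lookup are assumed to have been done beforehand.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset

PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2  # hypothetical special-token ids


def _wrap(ids):
    """Surround a list of token ids with <bos> ... <eos> as a LongTensor."""
    return torch.tensor([BOS_IDX, *ids, EOS_IDX], dtype=torch.long)


class TranslationDataset(Dataset):
    """Hypothetical dataset of (source ids, target ids) sentence pairs."""

    def __init__(self, pairs):
        self.pairs = pairs  # list of (list[int], list[int])

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        src, tgt = self.pairs[idx]
        return _wrap(src), _wrap(tgt)


def collate_fn(batch):
    """Pad a list of (src, tgt) tensors into two (batch, max_len) tensors."""
    srcs, tgts = zip(*batch)
    src = pad_sequence(list(srcs), batch_first=True, padding_value=PAD_IDX)
    tgt = pad_sequence(list(tgts), batch_first=True, padding_value=PAD_IDX)
    return src, tgt


if __name__ == "__main__":
    pairs = [([5, 6, 7], [8, 9]), ([10], [11, 12, 13, 14])]
    loader = DataLoader(TranslationDataset(pairs), batch_size=2, collate_fn=collate_fn)
    src, tgt = next(iter(loader))
    print(src.shape, tgt.shape)  # torch.Size([2, 5]) torch.Size([2, 6])
```

Padding per batch inside `collate_fn` keeps each batch only as long as its longest sentence, which is roughly what BucketIterator used to handle.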

Train

In train.py, we train a simple German -> English translation model on the Multi30k dataset using the Transformer model. Make sure to configure the necessary paths for weights, logs, etc. in config.py. Then you can simply run the file as shown below:

python train.py
Epoch: 1/10     100%|######################################################################| 227/227 [00:10<00:00, 21.61batch/s, loss=4.33]
Evaluating...   100%|######################################################################| 8/8 [00:00<00:00, 45.25batch/s, loss=3.13]
Saved Model at weights/1.pt

Epoch: 2/10     100%|######################################################################| 227/227 [00:10<00:00, 22.64batch/s, loss=2.82]
Evaluating...   100%|######################################################################| 8/8 [00:00<00:00, 51.68batch/s, loss=2.55]
Saved Model at weights/2.pt

Epoch: 3/10     100%|######################################################################| 227/227 [00:10<00:00, 22.56batch/s, loss=2.22]
Evaluating...   100%|######################################################################| 8/8 [00:00<00:00, 51.98batch/s, loss=2.22]
Saved Model at weights/3.pt

Epoch: 4/10     100%|######################################################################| 227/227 [00:10<00:00, 22.64batch/s, loss=1.83]
Evaluating...   100%|######################################################################| 8/8 [00:00<00:00, 52.20batch/s, loss=2.07]
Saved Model at weights/4.pt

Epoch: 5/10     100%|######################################################################| 227/227 [00:10<00:00, 22.64batch/s, loss=1.55]
Evaluating...   100%|######################################################################| 8/8 [00:00<00:00, 52.12batch/s, loss=2]   
Saved Model at weights/5.pt

Epoch: 6/10     100%|######################################################################| 227/227 [00:10<00:00, 22.25batch/s, loss=1.34]
Evaluating...   100%|######################################################################| 8/8 [00:00<00:00, 51.45batch/s, loss=1.95]
Saved Model at weights/6.pt

Epoch: 7/10     100%|######################################################################| 227/227 [00:10<00:00, 22.55batch/s, loss=1.17]
Evaluating...   100%|######################################################################| 8/8 [00:00<00:00, 51.34batch/s, loss=1.95]
Saved Model at weights/7.pt

Epoch: 8/10     100%|######################################################################| 227/227 [00:10<00:00, 22.46batch/s, loss=1.03]
Evaluating...   100%|######################################################################| 8/8 [00:00<00:00, 51.43batch/s, loss=1.96]
Saved Model at weights/8.pt

Epoch: 9/10     100%|######################################################################| 227/227 [00:10<00:00, 22.45batch/s, loss=0.91] 
Evaluating...   100%|######################################################################| 8/8 [00:00<00:00, 52.84batch/s, loss=1.99]
Saved Model at weights/9.pt

Epoch: 10/10    100%|######################################################################| 227/227 [00:10<00:00, 22.50batch/s, loss=0.808]
Evaluating...   100%|######################################################################| 8/8 [00:00<00:00, 51.74batch/s, loss=2.01]
Saved Model at weights/10.pt
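Each epoch in the log reports a training loss and a validation loss. The sketch below illustrates one common way a teacher-forced training step for a sequence-to-sequence Transformer is written: the decoder receives the target shifted right and is trained to predict the next token, with padding positions excluded from the cross-entropy. It is a hypothetical illustration, not the code in train.py; `train_step`, `ToyModel` and `PAD_IDX` are made-up names, and `ToyModel` is only a stand-in that exercises the shapes.

```python
import torch
import torch.nn as nn

PAD_IDX = 0  # hypothetical padding id; must match the dataset's


def train_step(model, src, tgt, optimizer, criterion):
    """One teacher-forced update on a padded (batch, len) batch.

    The decoder sees tgt[:, :-1] and is trained to predict tgt[:, 1:], so the
    logits at position t are compared with the token at position t + 1.
    `model(src, tgt_in)` is assumed to return (batch, len - 1, vocab) logits.
    """
    model.train()
    tgt_in, tgt_out = tgt[:, :-1], tgt[:, 1:]
    logits = model(src, tgt_in)
    # Flatten to (batch * steps, vocab) vs (batch * steps,) for CrossEntropyLoss.
    loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


class ToyModel(nn.Module):
    """Stand-in for the Transformer: embeds the decoder input and projects it
    to vocabulary logits. It ignores `src`; it exists only to show shapes."""

    def __init__(self, vocab_size, d_model=16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model, padding_idx=PAD_IDX)
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, src, tgt_in):
        return self.proj(self.embed(tgt_in))


if __name__ == "__main__":
    torch.manual_seed(0)
    model = ToyModel(vocab_size=20)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    # ignore_index drops padding positions from the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
    src = torch.randint(3, 20, (4, 7))
    tgt = torch.randint(3, 20, (4, 6))
    for _ in range(3):
        print(round(train_step(model, src, tgt, optimizer, criterion), 3))
```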

Inference

Given the sentence "Eine Gruppe von Menschen steht vor einem Iglu" ("A group of people is standing in front of an igloo") as input in predict.py, we get the following output. It is pretty decent given how small and simple the dataset is, although "Iglu" (igloo) comes out as "warehouse".

python predict.py
"Translation:  A group of people standing in front of a warehouse ."
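At inference time there is no target sentence to feed the decoder, so the translation has to be generated one token at a time. Below is a minimal greedy-decoding sketch, assuming a model that takes `(src, tgt_in)` and returns per-position vocabulary logits (as in a teacher-forced setup). It is not the code in predict.py, and `greedy_decode`, `BOS_IDX` and `EOS_IDX` are hypothetical.

```python
import torch

BOS_IDX, EOS_IDX = 1, 2  # hypothetical special-token ids


@torch.no_grad()
def greedy_decode(model, src, max_len=50):
    """Greedily decode one source sentence `src` of shape (1, src_len).

    Starting from <bos>, the most likely next token is appended until <eos>
    is produced or the output reaches `max_len` tokens. `model(src, ys)` is
    assumed to return logits of shape (1, ys_len, vocab).
    """
    model.eval()
    ys = torch.tensor([[BOS_IDX]], dtype=torch.long)
    for _ in range(max_len - 1):
        logits = model(src, ys)
        # Only the prediction at the last position matters for the next token.
        next_token = logits[0, -1].argmax().item()
        ys = torch.cat([ys, torch.tensor([[next_token]])], dim=1)
        if next_token == EOS_IDX:
            break
    return ys[0].tolist()
```

The returned ids still have to be mapped back to words with the target vocabulary, which is omitted here. Re-running the decoder over the whole prefix at every step keeps the sketch simple at the cost of repeated computation.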

TODO

  • predict.py for inference
  • Add pretrained weights
  • Visualize attentions
  • An in-depth notebook
