
This project is forked from lucidrains/block-recurrent-transformer-pytorch, an implementation of the Block Recurrent Transformer (https://arxiv.org/abs/2203.07852) in PyTorch.

Home Page: https://github.com/google-research/meliad

License: MIT License


Block Recurrent Transformer - Pytorch

Implementation of the Block Recurrent Transformer in PyTorch. The highlight of the paper is its reported ability to remember information from up to 60k tokens back.

As far as I can tell, this design is the state of the art for the recurrent-transformer line of research.

Appreciation

  • Stability.ai for the generous sponsorship to work on and open source cutting-edge artificial intelligence research

Install

$ pip install block-recurrent-transformer

Usage

import torch
from block_recurrent_transformer_pytorch import BlockRecurrentTransformer

model = BlockRecurrentTransformer(
    num_tokens = 20000,             # vocab size
    dim = 512,                      # model dimension
    depth = 6,                      # number of layers
    dim_head = 64,                  # dimension per attention head
    heads = 8,                      # number of attention heads
    max_seq_len = 1024,             # total receptive field of the transformer; in the paper this was 2 * block size
    block_width = 512,              # block size - the total receptive field is max_seq_len (2 * block size in the paper). the most recent block becomes the new cached xl memories, which span one block (please open an issue if this is wrong)
    xl_memories_layers = (5, 6),    # which layers to equip with xl memories. deepmind work (Rae & Razavi 2020, cited below) showed that caching key / values only in the last few layers gives the majority of the benefit
    num_state_vectors = 512,        # number of state vectors; in the paper this was a single block size, but it can be any amount
    recurrent_layers = (4,),        # where to place the recurrent layer(s) for states with fixed simple gating
    enhanced_recurrence = True      # enhanced recurrence from the ERNIE-Doc paper; it has worked well in local testing
)

seq = torch.randint(0, 2000, (1, 1024))

out, mems1, states1 = model(seq)
out, mems2, states2 = model(seq, xl_memories = mems1, states = states1)
out, mems3, states3 = model(seq, xl_memories = mems2, states = states2)
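
To run the model over a sequence longer than max_seq_len, you can split the tokens into segments and thread the returned xl memories and states through successive calls, exactly as above. The helper below is a minimal sketch based only on the call signature shown in this README; the function name and the assumption that the per-segment outputs can simply be concatenated are illustrative, not part of the library.

import torch
from block_recurrent_transformer_pytorch import BlockRecurrentTransformer

def process_long_sequence(model, tokens, segment_len = 1024):
    # split a (batch, length) tensor of token ids into fixed-width segments
    # segment_len should not exceed the model's max_seq_len (1024 in the example above)
    segments = tokens.split(segment_len, dim = -1)

    memories, states = None, None
    outputs = []

    for segment in segments:
        # the first call uses no memories or states; every later call is fed
        # the xl memories and recurrent states returned by the previous one
        if memories is None:
            out, memories, states = model(segment)
        else:
            out, memories, states = model(segment, xl_memories = memories, states = states)

        outputs.append(out)

    # concatenate the per-segment outputs back along the sequence dimension
    return torch.cat(outputs, dim = 1)

For example, process_long_sequence(model, torch.randint(0, 20000, (1, 4096))) would make four recurrent passes of 1024 tokens each, carrying the memories and states forward between them.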

Test on Enwik8

First pip install -r requirements.txt, then

$ python train.py
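
For orientation, below is a hypothetical, stripped-down language-modeling step using only the bare BlockRecurrentTransformer from the usage example. It assumes the first return value holds next-token logits of shape (batch, seq, num_tokens); verify this against train.py and the training wrapper, which are the authoritative reference.

import torch
import torch.nn.functional as F

# hypothetical single training step on one segment;
# assumes `logits` has shape (batch, seq, num_tokens)
optimizer = torch.optim.Adam(model.parameters(), lr = 3e-4)

seq = torch.randint(0, 20000, (1, 1025))
inputs, targets = seq[:, :-1], seq[:, 1:]    # shift by one position for a causal LM loss

logits, mems, states = model(inputs)

# cross_entropy expects the class dimension second, hence the transpose
loss = F.cross_entropy(logits.transpose(1, 2), targets)
loss.backward()
optimizer.step()
optimizer.zero_grad()

When training segment by segment, the returned memories and states would also be detached and passed into the next step, as in the usage example above; the RecurrentTrainWrapper mentioned in the Todo list below handles that bookkeeping.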

Todo

  • use dynamic positional bias

  • add enhanced recurrence

  • setup local attention blocks, as in the paper

  • wrapper transformer class for training

  • take care of generation with recurrence in RecurrentTrainWrapper

  • add ability to drop out entire memories and states at each segment step during training

  • test full system on enwik8 locally and ablate states and memories and see effects first hand

  • make sure attention allows for single-head key / values too

  • run a few experiments with fixed gating in regular transformers - does not work

  • revisit memformer

  • integrate flash attention

Citations

@article{Hutchins2022BlockRecurrentT,
    title   = {Block-Recurrent Transformers},
    author  = {DeLesley S. Hutchins and Imanol Schlag and Yuhuai Wu and Ethan Dyer and Behnam Neyshabur},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2203.07852}
}
@article{Ding2021ERNIEDocAR,
    title   = {ERNIE-Doc: A Retrospective Long-Document Modeling Transformer},
    author  = {Siyu Ding and Junyuan Shang and Shuohuan Wang and Yu Sun and Hao Tian and Hua Wu and Haifeng Wang},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2012.15688}
}
@article{Shazeer2019FastTD,
    title   = {Fast Transformer Decoding: One Write-Head is All You Need},
    author  = {Noam M. Shazeer},
    journal = {ArXiv},
    year    = {2019},
    volume  = {abs/1911.02150}
}
@inproceedings{rae-razavi-2020-transformers,
    title   = "Do Transformers Need Deep Long-Range Memory?",
    author  = "Rae, Jack  and Razavi, Ali",
    booktitle = "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics",
    month   = jul,
    year    = "2020",
    address = "Online",
    publisher = "Association for Computational Linguistics",
    url     = "https://www.aclweb.org/anthology/2020.acl-main.672"
}

Memory is Attention through Time - Alex Graves
