
lucidrains / megabyte-pytorch

Stars: 592 · Watchers: 10 · Forks: 49 · Size: 35.3 MB

Implementation of MEGABYTE, "Predicting Million-byte Sequences with Multiscale Transformers", in PyTorch

License: MIT License

Python 100.00%
artificial-intelligence deep-learning learned-tokenization attention-mechanisms long-context transformers


megabyte-pytorch's Issues

The patch embedder implementation differs from the original paper

Thank you so much for taking the time to share your code! I appreciate your generosity in helping me better understand the paper.

I noticed that your implementation of the patch embedder differs slightly from the original.

The paper uses separate global and local byte embeddings, and builds each global patch embedding by concatenating the per-byte global embeddings. Your implementation instead reuses a single byte embedding and applies a linear projection between the transformers at adjacent stages to construct the patch embedding.

The paper:

[figure: patch embedder definition from the MEGABYTE paper]

Your code:

self.patch_embedders = nn.ModuleList([nn.Sequential(
    Rearrange('... r d -> ... (r d)'),
    nn.LayerNorm(seq_len * dim_in),
    nn.Linear(seq_len * dim_in, dim_out),  # linear transformations here
    nn.LayerNorm(dim_out)
) for dim_in, dim_out, seq_len in zip(dim[1:], dim[:-1], max_seq_len[1:])])
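For reference, here is a minimal standalone sketch of what one of these patch embedders computes, with illustrative dimensions (`nn.Flatten(-2, -1)` stands in for the einops `Rearrange('... r d -> ... (r d)')`):

```python
import torch
import torch.nn as nn

# Illustrative dims: a patch of 4 positions at dim 256, projected to dim 768
dim_in, dim_out, patch_len = 256, 768, 4

embedder = nn.Sequential(
    nn.Flatten(-2, -1),                      # merge the last two axes: '... r d -> ... (r d)'
    nn.LayerNorm(patch_len * dim_in),
    nn.Linear(patch_len * dim_in, dim_out),  # the linear transformation in question
    nn.LayerNorm(dim_out),
)

x = torch.randn(1, 10, patch_len, dim_in)  # (batch, patches, positions per patch, dim)
out = embedder(x)
print(out.shape)  # torch.Size([1, 10, 768])
```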

Why does it expect tokens?

The input to this implementation is a token id (0–16000). Isn't the whole point of the original paper that the input is a byte (0–255)? Am I missing something about the patch embedding?

No available kernel error

Thank you for your code.

I tried to train MEGABYTE with the default settings, but hit the following error. How can I fix it?

.../MEGABYTE_pytorch/attend.py", line 111, in flash_attn
    out = F.scaled_dot_product_attention(
RuntimeError: No available kernel.  Aborting execution.
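This error usually means none of the fused attention kernels (flash / memory-efficient) support the current GPU, dtype, or mask. A minimal check outside the repo is to force the always-available math backend; `torch.backends.cuda.sdp_kernel` is the PyTorch 2.0–2.2 context manager, and newer versions prefer `torch.nn.attention.sdpa_kernel`:

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 16, 64)  # (batch, heads, seq, head_dim)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

# Disable the fused kernels and fall back to the math implementation,
# which works on any device and dtype.
with torch.backends.cuda.sdp_kernel(
    enable_flash=False, enable_math=True, enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(out.shape)  # torch.Size([1, 8, 16, 64])
```

Within the repo itself, the simpler workaround may be to construct the model without flash attention enabled (check whether the constructor exposes a `flash_attn` flag).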

Training Results and Scaling

Hi there.

I've run the training code in this repository for 25k of the 100k batches and reached a validation loss of around 1.28, i.e. a byte-level perplexity of about 3.59. After that, the training loss continues to drop but the validation loss either plateaus or slowly starts climbing again. I was curious whether you saw the same. (One caveat: I stopped at 25k and restarted training; I reloaded the model and optimiser checkpoints but didn't preserve the train/val shuffling, so that may confound things.)

I also tried training on an H100 (80 GB VRAM) with a batch size of 60 instead of 4 and found much slower convergence and an earlier validation-loss plateau (around 2.5). Do other hyperparameters need to be adjusted when scaling training to larger devices? I originally tested on an RTX 3060 Ti with 8 GB VRAM.

Thanks in advance.

Some questions about MEGABYTE

First of all, thank you for your contribution. Is MEGABYTE only suitable for ASCII-encoded text? If the input is Unicode, will it go wrong or easily blow up memory, and how would character-level segmentation be achieved?
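For what it's worth, byte-level models sidestep this: UTF-8 encodes any Unicode string into bytes in the range 0–255, so non-ASCII text just becomes a (somewhat longer) byte sequence and no character-level segmentation is needed. A quick illustration:

```python
text = "héllo, 世界"              # mixed ASCII, Latin-1, and CJK characters
data = text.encode("utf-8")       # any Unicode string becomes plain bytes

print(len(text), len(data))        # 9 characters -> 14 bytes
print(all(b < 256 for b in data))  # True: every value fits a 256-way byte vocabulary
```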

GPU used for original paper experiments

Quick question: what hardware did you use for training (and for controlling computation time) in the original paper?

Do you think the hardware (degree of sharding, other optimizations) would affect the relative wall time much?

Minor shape error

Flagging in case anyone else runs into this:
train.py initially errored for me on line 400 of megabyte.py:
start_tokens, logits = logits[:, 0, :1, :], logits[..., 1:, :]
I reshaped start_tokens so it's shaped (4, 1, 256) instead of (4, 1, 5, 256), and the code now runs fine.
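A quick illustration of how such shape mismatches arise: an integer index drops an axis while a length-1 slice keeps it (the shapes below are illustrative, not necessarily the repo's actual ones):

```python
import torch

logits = torch.randn(4, 5, 5, 256)  # illustrative (batch, patches, patch_len, vocab)

a = logits[:, 0, :1, :]   # integer index drops that axis -> (4, 1, 256)
b = logits[:, :1, :1, :]  # length-1 slice keeps it       -> (4, 1, 1, 256)

print(a.shape, b.shape)
```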

the string is still divided into pieces

I also saw this paper today. The main purpose of the paper is to get rid of the tokenizer, but in fact the string is still divided into pieces, with a pre-decoder added on top. Is there a problem with my understanding?
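As I understand it, that reading is essentially right, with one nuance: the "pieces" are fixed-size byte patches rather than a learned vocabulary, so no trained tokenizer or merge table is required. A sketch of the patching step (patch size 4 chosen for illustration):

```python
import torch

byte_seq = torch.randint(0, 256, (1, 16))  # 16 raw bytes
patch_size = 4
patches = byte_seq.reshape(1, 16 // patch_size, patch_size)  # fixed-size split, no tokenizer

print(patches.shape)  # torch.Size([1, 4, 4])
```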

Some implementations differ from the original paper

Hi, thanks for your code, which helped me understand the original paper!

Considering only the global_transformer and local_transformer, I found that your code differs at a few points.

Let's set D_G=768, D_L=256, P=4, seq=40, batch_size=1:

  1. h^global_out has shape (1, 1+10, 768) in your code, but it should be (1, 1+10, 4 * 768) per the paper (1+10 means the first pad token plus the following patch tokens). The original paper has two token_emb functions, one for h^global_in and one for h^local_in; your code only has a token_emb for h^local_in.
  2. The projection between the global and local transformers maps dim 768 to 256. Your code does this:
  • h^global_out.shape=(1, 11, 768); keep the top-10 tokens, then h^global_out.shape=(1, 10, 768)

  • h^global_out = proj(h^global_out), then h^global_out.shape=(1, 10, 256)

  • h^local_in.shape=(10, 4, 256); concatenated with h^global_out, the final h^local_in has shape (10, 5, 256). At the end, the first token is dropped

    The original paper instead does:

  • h^global_out.shape=(1, 11, 4 * 768); keep the top-10 tokens, then h^global_out.shape=(1, 10, 4 * 768)

  • h^global_out = proj(h^global_out), then h^global_out.shape=(1, 10, 4 * 256)

  • h^local_in = concat([pad, top-3 tokens in h^local_in]) + h^global_out.reshape(10, 4, 256), which drops the last token in h^local_in.

[figure: patch embedder equations from the MEGABYTE paper]

Please take a look and see if I've misunderstood something. Looking forward to your reply!
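For concreteness, here is a sketch of the paper-style global input described in point 1, where P consecutive byte embeddings are concatenated per patch (dims taken from the example above; this reflects the issue author's reading of the paper, not the repo's code):

```python
import torch
import torch.nn as nn

B, seq, P, D_G = 1, 40, 4, 768
global_tok_emb = nn.Embedding(256, D_G)  # separate global byte embedding, as in the paper

x = torch.randint(0, 256, (B, seq))
h = global_tok_emb(x)                           # (1, 40, 768): one embedding per byte
h_global_in = h.reshape(B, seq // P, P * D_G)   # (1, 10, 3072): P embeddings concatenated per patch

print(h_global_in.shape)
```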

Evaluation metric bits-per-byte

Hi there,

The MEGABYTE paper uses bits-per-byte as its evaluation metric in Table 2. This seems different from byte-level perplexity, since the numbers in the arXiv paper and code are < 1, so they cannot be perplexities. This repo uses cross-entropy loss, from which byte-level perplexity is easy to compute. May I ask how to compute the bits-per-byte metric?

Thanks a lot.
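For anyone else wondering: if the training loss is the mean cross-entropy per byte in nats (as with PyTorch's `F.cross_entropy`), bits-per-byte is that loss divided by ln 2, and byte-level perplexity is its exponential:

```python
import math

ce_nats = 1.28               # mean cross-entropy per byte, in nats (example value)
bpb = ce_nats / math.log(2)  # bits per byte        ~= 1.847
ppl = math.exp(ce_nats)      # byte-level perplexity ~= 3.597

print(round(bpb, 3), round(ppl, 3))
```

The sub-1 bits-per-byte numbers in the paper correspond to a cross-entropy below ln 2 ≈ 0.693 nats.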
