lucidrains / MEGABYTE-pytorch
Implementation of MEGABYTE, Predicting Million-byte Sequences with Multiscale Transformers, in Pytorch
License: MIT License
Thank you so much for taking the time to share your code with me! I appreciate your generosity in helping me better understand the paper.
I noticed that your code takes a slightly different approach to implementing the Patch Embedder compared to the original.
The original paper's Patch Embedder uses separate global and local byte embeddings, followed by concatenation in the global model. Your implementation instead uses a fixed byte embedding followed by linear transformations between the different transformer stages to construct the patch embedding.
self.patch_embedders = nn.ModuleList([nn.Sequential(
Rearrange('... r d -> ... (r d)'),
nn.LayerNorm(seq_len * dim_in),
nn.Linear(seq_len * dim_in, dim_out), # linear transformations here
nn.LayerNorm(dim_out)
) for dim_in, dim_out, seq_len in zip(dim[1:], dim[:-1], max_seq_len[1:])])
The input to this implementation is a token (0-16000). Isn't the whole point of the original paper that the input is a byte (0-255)? Am I missing something about the patch embedding?
Thank you for your code.
I tried to train MEGABYTE using default settings, but I faced the following error.
How can I fix it?
.../MEGABYTE_pytorch/attend.py", line 111, in flash_attn
out = F.scaled_dot_product_attention(
RuntimeError: No available kernel. Aborting execution.
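This error usually means no flash/mem-efficient kernel is available for your GPU, dtype, or head-dimension combination. One workaround (a sketch, not the repo's own fix) is to force the math backend of `scaled_dot_product_attention`, which is always available; the `MEGABYTE` constructor also appears to accept a `flash_attn` flag you could set to `False` — treat that kwarg as an assumption and check the README.

```python
import torch
import torch.nn.functional as F

# Toy tensors just to demonstrate the fallback: (batch, heads, seq, dim_head)
q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

# Force the math backend so SDPA never requires a flash or
# memory-efficient kernel (which may be unavailable on some setups).
with torch.backends.cuda.sdp_kernel(
    enable_flash = False,
    enable_math = True,
    enable_mem_efficient = False
):
    out = F.scaled_dot_product_attention(q, k, v)

print(out.shape)  # torch.Size([1, 8, 16, 64])
```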
How do we translate the various model-size parameters given in the paper into the max_seq_len and depth tuple arguments when constructing the model?
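For what it's worth, the constructor takes one entry per stage, ordered from the coarsest (global) to the finest (local) model. A plausible mapping from the paper's notation is sketched below — the specific sizes are illustrative assumptions, not the paper's exact configuration, so check the README and the paper's tables before copying them:

```python
from MEGABYTE_pytorch import MEGABYTE

# Hypothetical mapping (sizes are illustrative, not the paper's exact config).
# Each tuple runs from the global stage down to the local stage.
model = MEGABYTE(
    num_tokens = 256,         # raw bytes
    dim = (768, 256),         # D_G (global width), D_L (local width)
    depth = (12, 6),          # global layers, local layers
    max_seq_len = (1024, 4)   # number of patches, patch size P
)
```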
Hi there.
I’ve run the training code in this repository for 25k out of the 100k batches and achieved a validation loss of around 1.28, or perplexity of 3.59. After this, the training loss continues to drop but the validation loss either plateaus, or slowly starts going back up. I was curious if you also found the same (however, I stopped at 25k and restarted training. I reloaded the model and optimiser checkpoints but didn’t preserve train/val shuffling. Not sure if this confounds it either). Also I tried running the training on a H100 80GB VRAM with a batch size of 60 instead of 4 and found very slow convergence and an earlier plateau of the val loss (~2.5 ish). Do other hyperparameters need to be adjusted to scale training on larger devices? I originally tested on an RTX 3060Ti with 8GB VRAM.
Thanks in advance.
First of all, thank you for your contribution. Is MEGABYTE only suitable for ASCII encoding? If you use Unicode, will it break or easily blow up memory? And how would one achieve character-level segmentation?
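As I understand it, MEGABYTE operates on raw bytes, so Unicode is handled by UTF-8 encoding the text first: every character becomes 1-4 bytes, each in the range 0-255, and the model never needs a larger vocabulary. A quick sketch:

```python
# Unicode text becomes a sequence of byte ids in range(256) via UTF-8.
text = "héllo 世界"
byte_ids = list(text.encode("utf-8"))

print(byte_ids)       # every id fits in 0..255
print(len(byte_ids))  # 13 bytes for 8 characters (é is 2 bytes, CJK are 3)
```

Memory grows with the byte length, not the character count, so CJK-heavy text is roughly 3x longer in bytes than in characters.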
Quick question–what hardware did you use for training (+ controlling for computation time) for your original paper?
Do you think the hardware (degree of sharding, other optimizations) would affect the relative wall time a lot?
Flagging in case anyone else ran into this:
train.py errored for me initially on line 400 of megabyte.py:
start_tokens, logits = logits[:, 0, :1, :], logits[..., 1:, :]
I reshaped the start_tokens so they're shaped (4, 1, 256) instead of (4, 1, 5, 256) and the code runs fine.
I also saw this paper today. The main purpose of the paper is to get rid of the tokenizer, but in fact the string is still divided into pieces; a pre-decoder is just added on top. Is there a problem with my understanding?
Is this model feasible and ready to start being implemented on a massive scale? Do you think it's useful? What do you think in general about Megabyte?
Hi, thanks for your code, which helped me understand the original paper!
Looking at just the global_transformer and local_transformer, I found that your code differs from the paper at a few points:
Let's set D_G = 768, D_L = 256, P = 4, seq = 40, batch_size = 1.

In this repo:
- h^global_out has shape (1, 11, 768); keeping the first 10 tokens gives (1, 10, 768)
- h^global_out = proj(h^global_out), so h^global_out has shape (1, 10, 256)
- h^local_in has shape (10, 4, 256); concatenating with h^global_out gives a final h^local_in of shape (10, 5, 256), and at the end the first token is dropped

In the original paper:
- h^global_out has shape (1, 11, 4 * 768); keeping the first 10 tokens gives (1, 10, 4 * 768)
- h^global_out = proj(h^global_out), so h^global_out has shape (1, 10, 4 * 256)
- h^local_in = concat([pad, first 3 tokens of h^local_in]) + h^global_out.reshape(10, 4, 256), which drops the last token of h^local_in
Please take a look and see if I misunderstood something. Looking forward to your reply!
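The paper-style flow described above can be shape-checked with a short sketch (this is my reading of the description, using the stated sizes; the pad/shift details are assumptions, not the repo's code):

```python
import torch

# Assumed sizes from the discussion: D_G=768, D_L=256, P=4, 10 patches, batch=1.
h_global_out = torch.randn(1, 10, 4 * 768)   # global output, last patch dropped
proj = torch.nn.Linear(4 * 768, 4 * 256)
g = proj(h_global_out).reshape(10, 4, 256)   # per-byte global context

h_local_in = torch.randn(10, 4, 256)         # byte embeddings per patch
pad = torch.zeros(10, 1, 256)
shifted = torch.cat([pad, h_local_in[:, :-1]], dim = 1)  # shift right, drop last byte

local_in = shifted + g                       # paper adds; the repo concatenates
print(local_in.shape)  # torch.Size([10, 4, 256])
```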
self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
# expected
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
Is this a trick, or a bug?
Hi there,
The MEGABYTE paper uses bits-per-byte in Table 2 as its evaluation metric. It seems to differ from byte-level perplexity, since the numbers in the arXiv paper and the code are < 1, so they cannot be perplexities. This repo uses the cross-entropy loss, from which the byte-level perplexity is easy to compute. May I ask how to compute the bits-per-byte metric?
Thanks a lot.
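Since cross-entropy in PyTorch is measured in nats per byte, the standard conversions are perplexity = exp(loss) and bits-per-byte = loss / ln(2) (i.e. the same quantity in base-2 log). A sketch using the 1.28 validation loss mentioned in another thread here:

```python
import math

# Cross-entropy loss in nats per predicted byte (example value).
loss_nats = 1.28

perplexity = math.exp(loss_nats)        # byte-level perplexity
bpb = loss_nats / math.log(2)           # bits-per-byte: nats -> bits

print(round(perplexity, 2))  # 3.6
print(round(bpb, 3))         # 1.847
```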