Comments (9)
As far as I understand, this packing brings as many data samples into the batch as it can, up to bs * max_len tokens. The data collator should pad to the longest sample in the batch. It could be that I have a bug here somewhere. The multipack sampler comes from: https://github.com/imoneoi/multipack_sampler/blob/master/README.md#usage
from fine-tune-mistral.
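A minimal sketch of the idea (not the actual multipack_sampler code; `pack_batches` is a hypothetical helper): greedily fill each batch until adding another sample would exceed the bs * max_len token budget, so a collator that pads to the longest sample in the batch wastes fewer tokens.

```python
def pack_batches(lengths, batch_max_tokens):
    """Greedily group sample indices so the total token count per
    batch stays within batch_max_tokens (roughly bs * max_len).
    Sorting by length first groups similar-length samples together,
    which reduces padding when the collator pads to the longest
    sample in each batch."""
    order = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
    batches, current, current_tokens = [], [], 0
    for i in order:
        # Flush the current batch when this sample would not fit.
        if current and current_tokens + lengths[i] > batch_max_tokens:
            batches.append(current)
            current, current_tokens = [], 0
        current.append(i)
        current_tokens += lengths[i]
    if current:
        batches.append(current)
    return batches
```

The real sampler does smarter bin packing, but even this greedy version shows why the effective batch size varies from step to step.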
Some quick tests for 1 epoch on 1k samples: multipack results in faster training compared to the torch sampler.
multipack sampler:
batch_size = 2
completion: 3 minutes 57 seconds
---
batch_size = 3
completion: OOM
torch sampler:
batch_size = 6
completion: 6 minutes 10 seconds
---
batch_size = 7
completion: OOM
I guess the torch sampler here is just a random sampler. I can take another look at it then. My main concern was the high padded-token ratio; maybe I can use the data you have in the repo to compute it again. The ORCA paper actually mentions a different packing algorithm: https://github.com/graphcore/tutorials/tree/sdk-release-2.1/blogs_code/packedBERT. I found it to be slow, which is why I wanted to check this method.
https://github.com/graphcore/tutorials/blob/e9dbe4825f034a47871c4db0deb86d727cbd69b9/blogs_code/packedBERT/nnlshp.py#L51 is the main solver; if it can be sped up, it could be used as well.
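The padded-token ratio mentioned above is straightforward to compute once you know how samples are grouped into batches; here is a small sketch (`padded_token_ratio` is a hypothetical helper, not from either repo):

```python
def padded_token_ratio(batches_lengths):
    """Fraction of tokens that are padding when each batch is padded
    to its longest sample. batches_lengths is a list of batches,
    each a list of per-sample token counts."""
    padded = real = 0
    for batch in batches_lengths:
        longest = max(batch)
        # Every sample is padded up to the longest sample in its batch.
        padded += sum(longest - length for length in batch)
        real += sum(batch)
    total = padded + real
    return padded / total if total else 0.0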
Thanks for the prompt response; that was my understanding of the multipack sampler too. I will also check how it is leveraged in https://github.com/OpenAccess-AI-Collective/axolotl.
When performing multipacking, shouldn't the attention mask be adjusted as well? Otherwise there could be an information leak between two packed examples.
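For concreteness, an "adjusted" mask would be block-diagonal per packed document, combined with the usual causal constraint. A pure-Python sketch of what such a mask looks like (in practice it would be a tensor; `packed_causal_mask` is a hypothetical name):

```python
def packed_causal_mask(seq_lens):
    """Boolean attention mask for one packed sequence: position i may
    attend to position j only if j <= i (causal) and both positions
    belong to the same packed document. seq_lens holds the token
    counts of the documents packed together."""
    # Assign each token position the index of the document it came from.
    doc_id = [k for k, n in enumerate(seq_lens) for _ in range(n)]
    total = len(doc_id)
    return [[j <= i and doc_id[i] == doc_id[j] for j in range(total)]
            for i in range(total)]
```

With this mask, tokens of the second document cannot attend to tokens of the first, which is the leak being discussed.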
I would say it's currently naive packing with the EOS token as separator. It works in my runs, as far as I can tell. It is also mentioned in this paper:
"During training we always train on sequences of the full nctx = 2048 token context window, packing multiple documents into a single sequence when documents are shorter than 2048, in order to increase computational efficiency. Sequences with multiple documents are not masked in any special way but instead documents within a sequence are delimited with a special end of text token, giving the language model the information necessary to infer that context separated by the end of text token is unrelated."
paper: https://arxiv.org/abs/2005.14165
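The GPT-3-style packing quoted above can be sketched in a few lines (a minimal illustration, not the code used in this repo; `naive_pack` is a hypothetical name): concatenate tokenized documents with an EOS separator, then split the stream into fixed-length windows with no special masking.

```python
def naive_pack(token_docs, max_len, eos_id):
    """Concatenate tokenized documents, separating them with eos_id,
    then split the stream into fixed-length chunks of max_len tokens.
    No per-document attention masking is applied; the EOS separator
    is the only signal that adjacent contexts are unrelated."""
    stream = []
    for doc in token_docs:
        stream.extend(doc)
        stream.append(eos_id)
    return [stream[i:i + max_len] for i in range(0, len(stream), max_len)]
```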
You can also use the torch sampler, though when I ran comparisons between the two I did not notice any significant difference when evaluating the model. The difference was in the training time.
https://discord.com/channels/1104757954588196865/1104758010959634503/1159194895483941074
This is from the axolotl Discord. They concatenate all the batch inputs into a single 1 x (bs * seqlen) tensor and use flash-attn's varlen scaled dot product with the cumulative sequence lengths of each example in the tensor. In this case the naive approach might be better than random due to chance?
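The cumulative-sequence-length offsets that flash-attn's varlen kernels take are easy to build; a sketch assuming the usual `[0, len0, len0+len1, ...]` convention (`make_cu_seqlens` is a hypothetical helper; the actual axolotl code may differ):

```python
def make_cu_seqlens(seq_lens):
    """Cumulative sequence lengths for a batch concatenated into one
    1 x (bs * seqlen) tensor: cu[k] is the offset where example k
    starts, and cu[-1] is the total token count. These offsets tell
    the varlen attention kernel where each packed example begins and
    ends, so attention never crosses example boundaries."""
    cu = [0]
    for n in seq_lens:
        cu.append(cu[-1] + n)
    return cu
```

In the real pipeline these offsets would be converted to an int32 tensor and passed alongside the concatenated inputs; the key point is that the kernel, not a padding mask, enforces the example boundaries.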
Here I tested the idea: https://gist.github.com/KeremTurgutlu/847dd84519e28df85e68f8d88dc29905