stanford-crfm / levanter

Legible, Scalable, Reproducible Foundation Models with Named Tensors and Jax

Home Page: https://levanter.readthedocs.io/en/latest/

License: Apache License 2.0

Python 97.84% Shell 2.10% Dockerfile 0.06%

levanter's Introduction

Levanter


You could not prevent a thunderstorm, but you could use the electricity; you could not direct the wind, but you could trim your sail so as to propel your vessel as you pleased, no matter which way the wind blew.
— Cora L. V. Hatch

Levanter is a framework for training large language models (LLMs) and other foundation models that strives for legibility, scalability, and reproducibility:

  1. Legible: Levanter uses our named tensor library Haliax to write easy-to-follow, composable deep learning code, while still being high performance.
  2. Scalable: Levanter scales to large models and can train on a variety of hardware, including GPUs and TPUs.
  3. Reproducible: Levanter is bitwise deterministic, meaning that the same configuration will always produce the same results, even in the face of preemption and resumption.

We built Levanter with JAX, Equinox, and Haliax.

Documentation

Levanter's documentation is available at levanter.readthedocs.io. Haliax's documentation is available at haliax.readthedocs.io.

Features

  • Distributed Training: We support distributed training on TPUs (and soon, GPUs), including FSDP and tensor parallelism.
  • Compatibility: Levanter supports importing and exporting models to/from the Hugging Face ecosystem, including tokenizers, datasets, and models via SafeTensors.
  • Performance: Levanter's performance rivals commercially-backed frameworks like MosaicML's Composer or Google's MaxText.
  • Cached On-Demand Data Preprocessing: We preprocess corpora online, but we cache the results of preprocessing so that resumes are much faster and so that subsequent runs are even faster. As soon as the first part of the cache is complete, Levanter will start training.
  • Optimization: Levanter supports the new Sophia optimizer, which can be 2x as fast as Adam. We also use Optax for optimization with AdamW, etc.
  • Logging: Levanter supports a few different logging backends, including WandB and TensorBoard. (Adding a new logging backend is easy!) Levanter even exposes the ability to log inside of JAX jit-ted functions.
  • Reproducibility: On TPU, Levanter is bitwise deterministic, meaning that the same configuration will always produce the same results, even in the face of preemption and resumption.
  • Distributed Checkpointing: Distributed checkpointing is supported via Google's TensorStore library. Training can even be resumed on a different number of hosts, though this breaks reproducibility for now.

Levanter was created by the research engineering team at Stanford's Center for Research on Foundation Models (CRFM). You can also find us in the #levanter channel on the unofficial Jax LLM Discord.

Getting Started

Here is a small set of examples to get you started. For more information about the various configuration options, please see the Getting Started guide or the In-Depth Configuration Guide. You can also use --help or poke around other configs to see all the options available to you.

Installing Levanter

After installing JAX with the appropriate configuration for your platform, you can install Levanter with:

pip install levanter

or using the latest version from GitHub:

git clone https://github.com/stanford-crfm/levanter.git
cd levanter
pip install -e .
wandb login  # optional, we use wandb for logging

If you're developing Haliax and Levanter at the same time, you can do something like this:

git clone https://github.com/stanford-crfm/levanter.git
cd levanter
pip install -e .
cd ..
git clone https://github.com/stanford-crfm/haliax.git
cd haliax
pip install -e .
cd ../levanter

Please refer to the Installation Guide for more information on how to install Levanter.

If you're using a TPU, more complete documentation for setting that up is available here. GPU support is still in-progress; documentation is available here.

Training a GPT2-nano

As a kind of hello world, here's how you can train a GPT-2 "nano"-sized model on a small dataset.

python -m levanter.main.train_lm --config_path config/gpt2_nano.yaml

# alternatively, if you didn't use -e and are in a different directory
python -m levanter.main.train_lm --config_path gpt2_nano

This will train a GPT2-nano model on the WikiText-103 dataset.

Training a GPT2-small on your own data

You can also change the dataset by changing the dataset field in the config file. If your dataset is a Hugging Face dataset, you can use the data.id field to specify it:

python -m levanter.main.train_lm --config_path config/gpt2_small.yaml --data.id openwebtext

# optionally, you may specify a tokenizer and/or a cache directory, which may be local or on gcs
python -m levanter.main.train_lm --config_path config/gpt2_small.yaml --data.id openwebtext --data.tokenizer "EleutherAI/gpt-neox-20b" --data.cache_dir "gs://path/to/cache/dir"

If instead your data is a list of URLs, you can use the data.train_urls and data.validation_urls fields to specify them. Data URLs can be local files, GCS paths, http(s) URLs, or anything else that fsspec supports. Levanter (really, fsspec) will automatically uncompress .gz and .zstd files, and probably other formats too.

python -m levanter.main.train_lm --config_path config/gpt2_small.yaml --data.train_urls ["https://path/to/train/data_*.jsonl.gz"] --data.validation_urls ["https://path/to/val/data_*.jsonl.gz"]

Customizing a Config File

You can modify the config file to change the model, the dataset, the training parameters, and more. Here's the gpt2_small.yaml file:

data:
  train_urls:
      - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz"
  validation_urls:
      - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz"
  cache_dir: "gs://pubmed-mosaic/tokenized/openwebtext/"
model:
  gpt2:
    hidden_dim: 768
    num_heads: 12
    num_layers: 12
    seq_len: 1024
    gradient_checkpointing: true
    scale_attn_by_inverse_layer_idx: true
trainer:
  tracker:
    type: wandb
    project: "levanter"
    tags: [ "openwebtext", "gpt2"]

  mp: p=f32,c=bfloat16
  model_axis_size: 1
  per_device_parallelism: 4

  train_batch_size: 512
optimizer:
  learning_rate: 6E-4
  weight_decay: 0.1
  min_lr_ratio: 0.1

Other Architectures

Currently, we support several architectures, including GPT-2 and Llama 1/2.

We plan to add more in the future.

Continued Pretraining with Llama 1 or Llama 2

Here's an example of how to continue pretraining a Llama 1 or Llama 2 model on the OpenWebText dataset:

python -m levanter.main.train_lm --config_path config/llama2_7b_continued.yaml

Distributed and Cloud Training

Training on a TPU Cloud VM

Please see the TPU Getting Started guide for more information on how to set up a TPU Cloud VM and run Levanter there.

Training with CUDA

Please see the CUDA Getting Started guide for more information on how to set up a CUDA environment and run Levanter there.

Contributing

We welcome contributions! Please see CONTRIBUTING.md for more information.

License

Levanter is licensed under the Apache License, Version 2.0. See LICENSE for the full license text.

levanter's People

Contributors

anh-tong, aphoh, blahblahhhj, dependabot[bot], dlwh, helw150, ivan-zhou, mkly, patrick-kidger, raisin, rjpower, rtaori, vadam5, versae


levanter's Issues

Task-Aware Training

Most of the models that people want to train are ultimately going to be used for instructions and few shot. It seems like it would be a good idea to do old school multitask training rather than pretraining followed by instruction tuning. We would like to experiment with this and maybe write a paper or something.

(Note: these are @dlwh's notes to himself; feel free to follow them or not.)
Tasks:

  • #218
  • #219
  • DoReMi on the Pile (reproduce the paper, maybe with continued pretraining)
  • DoReMi on domain data + pile
  • #217
  • Finish #48 / #212
  • Task Data Loading
    • cache encoder-decoder-style datasets (inputs and targets)
    • write a dataset that converts to DecoderOnly on the fly
    • add task tokens/think about task tokens
  • Eval:
    • add non-log-loss metrics/tasks
    • #186
  • UL2R+DoReMi (DoReMi 1.5)
    • Paper? Alpha over [task,domain,language,...]. Alternatively, UL2R continuously
  • DoReMi 2
    • modify domain weights continuously during training (including resume from preemption --> serialization of data iterator states) #311
    • cluster documents to identify domains automatically, have 1000 domains, use the cluster centroid as features for the domains rather than domain index/identifier
  • Think about model export with prefixlm?

Framework for regression tests across multiple processes/hosts

I've had a few significant (and some minor) problems where something was fine on one machine, but ended up being subtly wrong on more than one machine. It'd be good to have some process in place for testing distributed code.

One option would be to manually spin up a jax.distributed instance locally in CPU-only mode via fork (or maybe GPU?), though that may not work. I think you could write a fancy Python decorator.

Otherwise, we'd need something to spin up two machines, run the tests in parallel, then shut down.
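To make the first option concrete, here is a rough sketch (not an actual Levanter test harness; the port and helper names are made up) of spawning two local processes that each initialize jax.distributed on the CPU backend:

# Hypothetical sketch: run a test body on N local "hosts" via multiprocessing.
# Assumes the CPU backend and that port 12355 is free; names here are illustrative.
import multiprocessing as mp

def _run(process_id: int, num_processes: int, test_fn):
    import jax  # import inside the child so each process gets its own runtime

    jax.distributed.initialize(
        coordinator_address="127.0.0.1:12355",
        num_processes=num_processes,
        process_id=process_id,
    )
    test_fn()

def run_distributed(test_fn, num_processes: int = 2):
    """Spawn num_processes local processes and run test_fn in each of them."""
    ctx = mp.get_context("spawn")
    procs = [ctx.Process(target=_run, args=(i, num_processes, test_fn)) for i in range(num_processes)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    assert all(p.exitcode == 0 for p in procs), "a worker process failed"

def _check_process_count():
    import jax
    assert jax.process_count() == 2

if __name__ == "__main__":
    run_distributed(_check_process_count)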

Things to test:

  • Checkpointing saving/loading (make sure the model is correct)
  • HF export/import
  • Sharded data loading (make sure we process the right number of examples, e.g.)
  • checks to ensure check_sharded_consistency (still) works, in service of the above

Make setup script work with private forks/local repos

At the moment, the TPU setup scripts assume we're cloning the main GitHub repo of levanter. To facilitate development of other models etc., we should do something different.

Possibilities:

  1. scp the current levanter checkout
  2. git clone the current levanter checkout on the remotes and check out HEAD (presumably a shallow clone?)
  3. git clone the upstream

For reproducibility, I would lean towards (2), maybe with a warning if HEAD is dirty.

cc @Ivan-Zhou

set up CI

just simple github actions for unit tests...

Finish/improve the overview doc

in docs/Overview.md there's a bunch of stuff explaining how Levanter thinks about FSDP and Model Parallelism (via Haliax). It's not quite done, but would be good to finish it up.

In particular:

  • make a basic training loop for overview
  • add pjit to tutorial training loop
  • add activation sharding to training tutorial
  • add jmp trainer example (mixed precision)
  • add haliax fsdp example

auto-discovery of Ray cluster on Slurm/TPU

In the case where there is no Ray cluster already, we should auto-discover the cluster the same way that JAX does for Slurm and TPU. We should do this in cache_dataset and in the on-the-fly path.

Part of #99
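For Slurm, the discovery could look roughly like the sketch below; the environment variables, port, and CLI invocation are assumptions for illustration rather than a finished design:

# Rough sketch of Slurm-based Ray auto-discovery; the port and the head/worker
# split below are assumptions, not Levanter's actual implementation.
import os
import subprocess
import ray

def auto_ray_init(port: int = 6379):
    if "SLURM_JOB_NODELIST" in os.environ:
        # Treat the first host in the allocation as the Ray head node.
        head = subprocess.run(
            ["scontrol", "show", "hostnames", os.environ["SLURM_JOB_NODELIST"]],
            capture_output=True, text=True, check=True,
        ).stdout.splitlines()[0]
        if os.environ.get("SLURM_PROCID", "0") == "0":
            subprocess.run(["ray", "start", "--head", f"--port={port}"], check=True)
        ray.init(address=f"{head}:{port}")
    else:
        ray.init()  # fall back to whatever Ray can discover locally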

Make a ScanLayers Module

vmap/scan/fold scares people when it's used for anything other than the batch dim. It's unfortunately necessary for compilation times etc.

I found a vmap in the Gpt2Transformer code and it freaked me out – shouldn’t there be one big vmap around the use of the model in the code elsewhere or something? Now I’m worried I won’t put the vmaps where they should be.

We use vmap in GPT-2 initialization for so-called "scan layers" (q.v. https://docs.kidger.site/equinox/tricks/#improve-compilation-speed-with-scan-over-layers)

It'd be nice to hide this in a module: homogeneous stacks of layers are pretty common.

I have a 90% done branch for this. The big thing missing is just fixing GPT-2 torch export.

https://github.com/stanford-crfm/levanter/compare/scan_module?expand=1
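For reference, here is a minimal sketch of the generic scan-over-layers pattern from the linked Equinox docs (not the branch above): build a stacked set of layers with filter_vmap, then apply them with jax.lax.scan.

# Minimal sketch of "scan over layers" with Equinox; layer sizes are illustrative.
import equinox as eqx
import jax
import jax.numpy as jnp

num_layers, hidden = 12, 64

def make_layer(key):
    return eqx.nn.Linear(hidden, hidden, key=key)

keys = jax.random.split(jax.random.PRNGKey(0), num_layers)
stacked = eqx.filter_vmap(make_layer)(keys)  # each parameter leaf gains a leading num_layers axis

def forward(x):
    dynamic, static = eqx.partition(stacked, eqx.is_array)

    def step(h, layer_leaves):
        layer = eqx.combine(layer_leaves, static)
        return layer(h), None

    h, _ = jax.lax.scan(step, x, dynamic)  # one compiled layer body, reused num_layers times
    return h

print(forward(jnp.ones(hidden)).shape)  # (64,)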

Replace/improve babysit script to use some kind of cloud stuff

The TPU babysitting script is nice, but it doesn't work well with our compute infrastructure because AFS tokens expire and the script stops working. It would be nice to replace this with something a bit more robust, e.g. some kind of Cloud Run function or cloud cron job.

It would also be nice if this script would automatically set run ids and resume paths.

Actually implement ZeRO

I think it's not too hard. There are two options:

  • If it's allowed to have more than one mesh active at a time, have a mesh that is (num_replicas, num_shards) that is (potentially) different from the (data_axis, model_axis) split
  • if it's not, then just do fully sharded on the current mesh: what is currently the "model" axis would be ("data", "model").

Then, your resource map should be

{
  "embed": (ResourceAxis.DATA, ResourceAxis.MODEL),
  ...
}

and then make the gradient and/or model states shard that way
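As a rough illustration of the fully-sharded option, here is a sketch using a 2D mesh and jax.sharding directly; the mesh shape and parameter sizes are made up.

# Hedged sketch of sharding a parameter over the combined ("data", "model") axes.
# Assumes the sharded dimension is divisible by the total number of devices.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices()).reshape(-1, 1)  # (num_data, num_model); shape is illustrative
mesh = Mesh(devices, axis_names=("data", "model"))

params = jnp.zeros((1024, 4096))  # a stand-in "embed x mlp" weight

# Shard the "embed" dimension over both mesh axes, as in the resource map above.
sharding = NamedSharding(mesh, P(("data", "model"), None))
sharded_params = jax.device_put(params, sharding)
print(sharded_params.sharding)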

Log metrics during preprocessing

Should log when using cache_dataset and when using on-the-fly.

In cache_dataset:

  • Ideally it should log to some kind of pbar. Should log number of docs processed, number of tokens created, number of shards finished

In on_the_fly:

  • should probably log to wandb and occasionally to stdout. Wandb logs should include docs processed, tokens created, shards finished
  • Can do pbar if we figure out interaction with the training pbar

Clarify/standardize naming convention for axes, and where to put them

I have not been clear about what the right naming convention is for axes in my code. I use SeqLen, Vocab, KeySeqLen, Embed, etc.

e.g.

It’s unclear what attributes (e.g., Vocab) are in fact, axes, not the objects the axes describe.

Relatedly:

Was unclear that if I wanted to make a new dimensionality/axis, I should change the @properties in the config that specify all the unique Axes.

Was unclear why I should sometimes pass around the config and why I should sometimes pass around specific Axes/config elements.

Sidi had a thing that turned ints into Axes automatically in configs. That was neat.
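For context, the @property pattern mentioned above looks roughly like this (the names and sizes are illustrative, not a proposed convention):

# Illustrative sketch of exposing axes as @property on a config dataclass.
from dataclasses import dataclass
from haliax import Axis

@dataclass(frozen=True)
class Gpt2LikeConfig:
    seq_len: int = 1024
    hidden_dim: int = 768
    vocab_size: int = 50257

    @property
    def SeqLen(self) -> Axis:
        return Axis("seqlen", self.seq_len)

    @property
    def Embed(self) -> Axis:
        return Axis("embed", self.hidden_dim)

    @property
    def Vocab(self) -> Axis:
        return Axis("vocab", self.vocab_size)

cfg = Gpt2LikeConfig()
print(cfg.Embed)  # e.g. Axis(name='embed', size=768)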

Jax CUDA install is a bit mysterious for people

From JohnH:

After the hang
Error’d out with CuDNN error:

2023-03-12 21:47:12.489017: E external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:417] Loaded runtime CuDNN library: 8.4.1 but source was compiled with: 8.6.0.  CuDNN library needs to have matching major version and equal or higher minor version. If using a binary install, upgrade your CuDNN library.  If building from sources, make sure the library loaded at runtime is compatible with the version specified during compile configuration.


Turns out I had installed JAX wrong for the CuDNN install on Jag35.
I only had CuDNN 8.4.1. (on CUDA 11.7, but also every other CUDA install)
I had installed JAX that expected 8.6
Why didn’t the JAX install tell me?
FWIW, JAX install instructions said it would fail with mismatched CuDNN, had the fix
pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

GPUs

Do whatever work we need to do to get Levanter working on GPUs

  • Single gpu
  • multi gpu
  • multi node
  • #249
  • compare performance with PyTorch (Mistral or Composer or something?)
  • run profiler, see if there's anything obvious
  • there's a memory leak somewhere?

Should also investigate flash attention again as well #90

Decide on haliax.dot argument order

From JohnH:

Looking up hax.dot right now to get the right semantics/syntax for choosing named axes to multiply; wish it were linked in the overview doc.

I made it this way to be consistent with einsum, and because dot actually sneakily supports an arbitrary number of things to dot over (also like einsum)

On-The-Fly Caching (Tracking Issue)

Currently, our data preprocessing story is something like:

  • Get a corpus (either hf datasets or, better, jsonl files)
  • ideally shuffle the corpus (yourself)
  • Run a single machine process to pretokenize the data into shards

The last step can take many hours to days, and is a pretty annoying bottleneck that hits users right at the start of using levanter. No good.

We need an alternative. In this issue, I focus on the last step. Dealing with the second is covered in #34.

Roadmap

  • Ray Cluster for tokenizing across multiple machines, without support for simultaneous writing and reading #112
  • introduce a coordinator for simultaneous reads and writes #112
  • Add resumption for writing #112
  • auto-discovery of cluster on Slurm/TPU (can maybe hijack Jax's stuff) #118
  • log metrics during preprocessing #117
  • wire up on-the-fly tokenization during training #120
  • data loading from the global order with perfect reproducibility #119

Proposed Approach

(Design as of 2023-04-18)

Goals

We want to support the following:

  1. Deterministic batches, even for a changing number of readers (or writers). That is, for any cluster size
    during training, we want the same batches to be generated in the same order.
  2. Sharded reading and writing. We want to be able to read and write from multiple shards in parallel.
  3. Simultaneous reading and writing of shards. We want to be able to start training while we are still building the cache.
  4. Fast resumption without losing too much progress. This applies to both writing and reading the cache. That is, when we resume a training run, we want to finish producing the cache and also jump to the right place in the cache for reads.
  5. (eventually) shuffling/random access
  6. Takes advantage of the fact that we typically have idle, beefy CPUs on the machines where we're doing training
  7. We want to be able to build the cache offline too.
  8. We want to support batches that are composed of fragments of documents. In particular, we take a moving window of tokens from documents. This implies that the mapping from "documents" to "batches" is not 1:1, or easy to compute.
  9. (eventually) ≈random access to tokens and not docs
  10. can handle a variable number of examples being generated per input doc

We want to support the following use cases:

  1. We have a larger training dataset, and we want to draw samples from it more or less independently on a large number of machines.
    We don't really care about "epochs"/"passes", but we do want to be able to handle resumes and be deterministic. Ideally, each
    machine only reads from the chunks that it needs to read from.
  2. We have a smaller validation dataset, and we want to do a single pass over it. We don't care about resuming, and it's ok if
    we have to read the whole dataset on each machine.
  3. Like (1) but we want to jump around the dataset. We still care about resuming and determinism, but don't care about epochs.

We focus on (1) and (2) for now.

Some terminology

  • Shard: A shard is a list of raw documents that have not been tokenized/preprocessed.
  • Chunk: A chunk is a list of processed documents that have been tokenized/preprocessed.
  • Reader: A reader is a process that reads from the cache. Typically there is one reader per machine.
  • Writer: A writer is a process that writes to the cache. Typically there is one writer per machine.
  • Global ordering: The global ordering is the ordering of chunks in the cache. This is the order in which
    documents are read by readers. The global ordering is defined with respect to an "idealized" number of readers R*. (See below.)
  • Processor or Tokenizer: A function that takes a raw document and returns a processed document.
  • Example: An example is a single datum that is fed into the model. Examples are typically composed of fragments of documents.
    For example, we might take a moving window of tokens from the concatenation of a list of preprocessed documents.

We say there are K input shards, W writers, R readers. We assume K >= W (though typically K is not too large), and W ≈ R.
We produce N chunks. We also define an idealized number of readers R*, which defines the global ordering over the data.
Typically R* should be the maximum number of readers we expect to actually use.

Cache structure

We define a shard cache as a list of "chunks", where each chunk is a parquet file (plus metadata) with an equal
number of documents (except for the last chunks for each shard.)
Each chunk is a list of processed documents. Chunks are ordered round-robin from the input shards, so that the c'th global chunk is the
c//K'th chunk of the c%K'th shard, so long as all shards have at least c//K chunks. (After that, we remove shards that
have been exhausted and continue round-robin.)
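A minimal sketch of that mapping, under the simplifying assumption that no shard has been exhausted yet:

# Sketch of the round-robin chunk ordering; assumes all shards still have chunks.
def global_chunk_to_shard_chunk(c: int, num_shards: int) -> tuple[int, int]:
    """Return (shard index, chunk index within that shard) for global chunk c."""
    return c % num_shards, c // num_shards

# With K = 4 shards, global chunks 0..7 map to:
# (0, 0) (1, 0) (2, 0) (3, 0) (0, 1) (1, 1) (2, 1) (3, 1)
print([global_chunk_to_shard_chunk(c, 4) for c in range(8)])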
We keep the following metadata:

  • For each shard, we keep a list of chunks written so far and whether or not we are done processing that shard.
  • For each chunk, we keep the number of documents, token counts/length of various fields, and the number of bytes.
    (This metadata can be used for seeking.)
  • For the cache overall, we keep the global ordering of chunks, the number of chunks, and the number of documents.

Chunk format

A Chunk is an Apache Parquet file with schema dependent on the task. For example, for language modeling, we might have
just a sequence of input_ids per document. We use Apache Parquet because it's compact and doesn't require us to know
much about the datatypes we're using.

Chunks also have metadata stored in a separate json file. This metadata includes the total number of documents in the
chunk, as well as token counts/lengths of various fields. This metadata is used for seeking.

Cache construction

We use Ray to manage the writers. Readers are managed by the main processes (though they call into Ray to get the data).
At a high level, we create a writer process for each shard, which produces chunks one by one. There is a central
writer coordinator process that receives chunks from each shard and adds them to the global ordering round-robin.
When chunks are added, we make them available to an actor that readers can access to get chunks.

Reproducible Sharded Reading for Training

We want to be able to read from the cache in a way that is deterministic and reproducible, even if the number of readers
changes. We also want readers to only read from the chunks that they need to read from.
We pretend the list of data is infinite by cycling. We cannot track epochs.

NB Our goal is a deterministic ordering over examples, and not merely chunks or even documents.

Given a list of chunks and the idealized number of readers R*, we define the global ordering over chunks as follows:
First define R* iterators over chunks, with chunk_iterators[r] being defined as loop(all_chunks)[r::R*].

Next, define a function mk_examples(chunk_iterator) that takes an iterator over chunks and returns
a list of examples. Define chunk_examples[r] = mk_examples(chunk_iterators[r]).
This function depends on our sequence length, etc. Then the ordering over examples is:

chunk_examples[0][0], chunk_examples[1][0], ..., chunk_examples[R*-1][0], ..., chunk_examples[0][1], chunk_examples[1][1], ..., chunk_examples[R*-1][1], ...
that is, example[i] == chunk_examples[i % R*][i // R*]

If we have R* readers, then each reader_iterator[r][j] == chunk_examples[r][j] == example[j * R* + r].
Moreover, if either R or R* is a multiple of the other, then we still get a nice property where
each reader reads from a strided slice of the chunk_iterators:

(Boring math)
If we have R readers, then reader_iterator[r][j] == example[j * R + r] == chunk_examples[(j * R + r) % R*][(j * R + r) // R*]
If we have R == n * R*, then reader_iterator[r][j] == example[j * R + r] == chunk_examples[(j * R + r) % R*][(j * R + r) // R*] == chunk_examples[r % R*][(j * n * R* + r) // R*] == chunk_examples[r % R*][j * n + r // R*], so each reader reads from
a strided slice (specifically islice(..., r//R*, None, n))
If we have R* == n * R, then reader_iterator[r][j] == example[j * R + r] == chunk_examples[(j * R + r) % R*][(j * R + r) // R*] == chunk_examples[R * (j % n) + r][(j * R + r) // R*], and so each reader reads from n different chunk_examples streams,
round-robining over a slice of them.

For other cases (R and R* don't divide each other), there's no simple relationship between the reader and chunk iterators
and you end up reading from everywhere, but that's ok.
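A toy sketch of this interleaving, with small lists standing in for chunks and a trivial mk_examples, just to make the indexing concrete:

# Toy sketch of the global example ordering; R* and the chunk contents are made up.
import itertools

def chunk_iterator(all_chunks, r, r_star):
    """loop(all_chunks)[r::R*]: every R*'th chunk starting at r, cycling forever."""
    return itertools.islice(itertools.cycle(all_chunks), r, None, r_star)

def mk_examples(chunks):
    """Stand-in for real example construction (sequence packing, etc.)."""
    for chunk in chunks:
        yield from chunk

def global_examples(all_chunks, r_star):
    """Round-robin over the R* per-reader example streams."""
    streams = [mk_examples(chunk_iterator(all_chunks, r, r_star)) for r in range(r_star)]
    for r in itertools.cycle(range(r_star)):
        yield next(streams[r])

chunks = [["a0", "a1"], ["b0", "b1"], ["c0", "c1"]]
print(list(itertools.islice(global_examples(chunks, r_star=2), 8)))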

Single-Pass Reading for Evaluation

When we want to do a single pass over the data, we don't cycle and we don't shuffle. We just read the data in order. Boring
and simple.

Resuming

We need to think about resuming in two cases: resuming writes and resuming reads.

Resuming Writes

Resuming writes is relatively easy, since we can just keep track of the number of chunks written for each shard and the
number of documents written for each chunk. Then you just skip to the appropriate document and start writing.

Resuming Reads

We want to understand how to seek to the b'th batch.

There are two cases of resuming we need to think about:

  1. The "easy" case where 1 example == 1 (preprocessed) document.
  2. The "hard" case where the mapping from examples to documents is not 1:1, but there is some easily computable relationship.

In the first case, each reader r reads documents[r::R]. The bth batch
is documents[b * batch_size:(b+1) * batch_size]. Assuming batch_size % R == 0, then for the b'th batch, reader r
needs to read documents[b * batch_size + r: (b+1) * batch_size + r: R] == docs(chunk_iterator[r])[b * batch_size // R:(b+1) * batch_size // R].
If we know how many documents are in each chunk, then we can seek to the right place in the chunk.

The second case is broadly similar. In particular, we consider the case where we take moving windows of concatenated documents.
If our metadata includes token counts, then we can skip chunks until we pass batch_size * tokens_per_example // R tokens.
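A sketch of the arithmetic for the easy case (the helper names are made up; it assumes batch_size % R == 0):

# Sketch of seeking for the "easy" 1 document == 1 example case.
def reader_doc_range_for_batch(b: int, batch_size: int, num_readers: int, r: int):
    """Global document indices reader r needs for batch b (a strided slice with step R)."""
    assert batch_size % num_readers == 0
    start = b * batch_size + r
    stop = (b + 1) * batch_size + r
    return range(start, stop, num_readers)

def docs_to_skip_in_reader_stream(b: int, batch_size: int, num_readers: int) -> int:
    """How many documents of its own stream a reader skips to reach batch b."""
    return b * batch_size // num_readers

print(list(reader_doc_range_for_batch(b=3, batch_size=8, num_readers=4, r=1)))  # [25, 29]
print(docs_to_skip_in_reader_stream(b=3, batch_size=8, num_readers=4))          # 6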

Shuffling

A brief digression

Why do we shuffle in machine learning? Shuffling reduces variance in the gradients. If we have batches
where every example is from the same document/domain, then the gradients for those batches will be correlated.

That said, in our setting where we use moving windows from documents, if we round-robin from chunks (which are produced
from different documents), and R* is roughly equal to the batch size, then we will read from a different chunk for every
example in a batch, which reduces correlation within a batch.

However, we still have (undesirable) correlation between batches: if we
read from chunks consecutively and our documents are long, then many examples in the next batch will be from the
same document as an example in the previous batch. Ideally this wouldn't happen. I'm not convinced that it matters
that much.

Proper shuffling is incompatible with streaming at a fundamental level. Our choices are something like:

  • Randomly shuffle before preprocessing. Makes life a bit less pleasant for people with a new dataset. Can't be changed after preprocessing. Doesn't solve the problem of correlated batches.
  • Reservoir sampling. Makes resumes hard, but is easy to implement.
  • "Epochal" reservoir sampling, where we periodically "flush" the reservoir. Resumes are easier because you can start from the latest "epoch"
  • No shuffling in the first pass, but shuffle in subsequent passes.
  • Shuffle within a range of chunks that grows as the run progresses.

My hunch is that we can skip this for now, and revisit if we find that it's a problem.
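For reference, a minimal shuffle-buffer sketch along the lines of the options above (the buffer size and seed are arbitrary, and this is not what Levanter does today):

# Minimal shuffle-buffer sketch: approximate streaming shuffle with a bounded buffer.
import random

def shuffle_buffer(stream, buffer_size=10_000, seed=0):
    rng = random.Random(seed)
    buf = []
    for item in stream:
        buf.append(item)
        if len(buf) >= buffer_size:
            i = rng.randrange(len(buf))
            buf[i], buf[-1] = buf[-1], buf[i]  # move a random element to the end
            yield buf.pop()
    rng.shuffle(buf)  # drain the remainder in random order
    yield from buf

print(list(shuffle_buffer(range(10), buffer_size=4, seed=0)))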

Preprocessing: Reproducible batch order even when number of workers changes

Part of #99

A bit tricky to pull off. The design doc has this:

Sharded Reading

We say there are K input shards, W writers, R readers. We assume K >= W (though typically K is not too large), and W ≈ R.
We produce N chunks. We also define an idealized number of readers R*, which defines the global ordering over the data.
Typically R* should be the maximum number of readers we expect to actually use.

We want to be able to read from the cache in a way that is deterministic and reproducible, even if the number of readers
changes. We also want readers to only read from the chunks that they need to read from.
We pretend the list of data is infinite by cycling. We cannot track epochs.

NB Our goal is a deterministic ordering over examples, and not merely chunks or even documents.

Given a list of chunks and the idealized number of readers R*, we define the global ordering over chunks as follows:
First define R* iterators over chunks, with chunk_iterators[r] being defined as loop(all_chunks)[r::R*].

Next, define a function mk_examples(chunk_iterator) that takes an iterator over chunks and returns
a list of examples. Define chunk_examples[r] = mk_examples(chunk_iterators[r]).
This function depends on our sequence length, etc. Then the ordering over examples is:

chunk_examples[0][0], chunk_examples[1][0], ..., chunk_examples[R*-1][0], ..., chunk_examples[0][1], chunk_examples[1][1], ..., chunk_examples[R*-1][1], ...
that is, example[i] == chunk_examples[i % R*][i // R*]

If we have R* readers, then each reader_iterator[r][j] == chunk_examples[r][j] == example[j * R* + r].
Moreover, if either R or R* is a multiple of the other, then we still get a nice property where
each reader reads from a strided slice of the chunk_iterators:

(Boring math)
If we have R readers, then reader_iterator[r][j] == example[j * R + r] == chunk_examples[(j * R + r) % R*][(j * R + r) // R*]
If we have R == n * R*, then reader_iterator[r][j] == example[j * R + r] == chunk_examples[(j * R + r) % R*][(j * R + r) // R*] == chunk_examples[r % R*][(j * n * R* + r) // R*] == chunk_examples[r % R*][j * n + r // R*], so each reader reads from
a strided slice (specifically islice(..., r//R*, None, n))
If we have R* == n * R, then reader_iterator[r][j] == example[j * R + r] == chunk_examples[(j * R + r) % R*][(j * R + r) // R*] == chunk_examples[R * (j % n) + r][(j * R + r) // R*], and so each reader reads from n different chunk_examples streams,
round-robining over a slice of them.

For other cases (R and R* don't divide each other), there's no simple relationship between the reader and chunk iterators
and you end up reading from everywhere, but that's ok.

More thoughts

There are two cases to consider: one where 1 row = 1 example, and our current use case where the number of rows and examples can't be easily determined.

For the former case, you just read round-robin from your assigned chunk slices. For the latter case, a reader needs to be careful to produce one example from each of its chunk streams round-robin. The easiest way to do this is to inject a producer function into the chunk readers...

include directives in configs

Pyrallis doesn't natively support include/inherit directives. We should investigate other options, or https://pypi.org/project/pyyaml-include/, which tweaks PyYAML to support includes.

My constraints on other options are: dataclass-first, arg parse, yaml, include, low ceremony, and more like a library and less like a framework (looking at you hydra)
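As a point of comparison, here is a minimal !include constructor written directly against PyYAML; this is not pyrallis's or pyyaml-include's actual API, it just shows the mechanism (file names are hypothetical):

# Sketch of a custom !include tag using plain PyYAML.
import os
import yaml

class IncludeLoader(yaml.SafeLoader):
    pass

def _include(loader: IncludeLoader, node: yaml.Node):
    # Resolve the included path relative to the including file, when known.
    base = os.path.dirname(getattr(loader, "name", "."))
    path = os.path.join(base, loader.construct_scalar(node))
    with open(path) as f:
        return yaml.load(f, IncludeLoader)

IncludeLoader.add_constructor("!include", _include)

# Usage: a config line like
#   model: !include gpt2_small_model.yaml
# would splice the referenced (hypothetical) file into the parsed dict.
with open("config/my_config.yaml") as f:
    cfg = yaml.load(f, IncludeLoader)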

Remove/hide `simplify_gdas`

simplify_gdas takes an array/pytree of arrays and turns it into a local array if it's not replicated. With jax.Array we can probably avoid this, so if we can do #65 now then we should just do that.

Wire up on-the-fly tokenization during training

Part of #99

With #112 we will have most of what we need to do on-the-fly tokenization, but it's not quite hooked up yet. We need to make the switch from TokenSeqDataset to whatever replaces it, and not block training on the caching process.

Do we need to add upcasting and scaling of the attn layer?

So far on wikitext I haven't needed either. I think the former (upcasting) may be superfluous with bfloat16 on tpus, since it looks like it's already accumulated in fp32

From https://cloud.google.com/blog/products/ai-machine-learning/bfloat16-the-secret-to-high-performance-on-cloud-tpus:

bfloat16 is carefully used within systolic arrays to accelerate matrix multiplication operations on Cloud TPUs. More precisely, each multiply-accumulate operation in a matrix multiplication uses bfloat16 for the multiplication and 32-bit IEEE floating point for accumulation. In this post, we’ll examine the bfloat16 format in detail and discuss how Cloud TPUs use it transparently. Then we’ll take a detailed look at some of the benefits it provides, including higher performance, model portability, and better numerical stability for a wide variety of deep learning workloads.
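For concreteness, explicit upcasting before the softmax would look something like this sketch (shapes are illustrative, and this is not a committed change):

# Sketch of upcasting attention scores to float32 before the softmax.
import jax
import jax.numpy as jnp

def attn_weights(q, k, scale):
    # q: [..., q_len, head_dim], k: [..., k_len, head_dim], typically bfloat16
    scores = jnp.einsum("...qd,...kd->...qk", q, k) * scale
    scores = scores.astype(jnp.float32)  # explicit upcast for softmax stability
    return jax.nn.softmax(scores, axis=-1).astype(q.dtype)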

Finish migrating to jax.Array

Jax deprecated/removed GlobalDeviceArray. https://jax.readthedocs.io/en/latest/jax_array_migration.html

I've already removed almost all references to it, but there are a couple.

  • Data Loading in GlobalBatchDataset
  • in partitioning, i think the GDA branches can be deleted
  • in jax_utils.py: global_key_array
  • tensorstore_serialization uses jax's gda_ser, which probably has a successor

We should make sure we maintain bitwise determinism. When I first did this, something changed and it made me uncomfortable, so I reverted.

Once this is done, we should also unpin/upgrade the jax dependency to >=0.4.7

Make data loading sufficiently random

I wrote a design doc here outlining desiderata and the current status, along with a potential plan for moving forward. (Very open to other designs!)

The basic issue is that if docs are super long or not randomized when they come in, performance seems to suffer substantially.

Sub issues assuming we go with the plan above:

  • implement a shuffle buffer
  • implement seek in IndexedDataset
  • implement JumpingDataset
  • figure out serialization of datasets
