druidowm / levanter

This project forked from stanford-crfm/levanter

Legible, Scalable, Reproducible Foundation Models with Named Tensors and Jax

License: Apache License 2.0

Levanter

You could not prevent a thunderstorm, but you could use the electricity; you could not direct the wind, but you could trim your sail so as to propel your vessel as you pleased, no matter which way the wind blew.
- Cora L. V. Hatch

Levanter is a framework for training large language models (LLMs) and other foundation models that strives for legibility, scalability, and reproducibility:

  1. Legible: Levanter uses our named tensor library Haliax to write easy-to-follow, composable deep learning code, while still being high performance.
  2. Scalable: Levanter scales to large models and can train on a variety of hardware, including GPUs and TPUs.
  3. Reproducible: Levanter is bitwise deterministic, meaning that the same configuration will always produce the same results, even in the face of preemption and resumption.
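To illustrate why a deterministic data order survives preemption, here is a minimal sketch in plain Python. It is not Levanter's implementation (which builds on JAX); the epoch-shuffling scheme and the helper names `epoch_permutation`, `batch_for_step`, and `run` are hypothetical. The idea it shows: if the batch for each step is a pure function of the seed and the step number, a run that stops at step 12 and resumes there sees exactly the batches an uninterrupted run would.

```python
import random


def epoch_permutation(seed: int, epoch: int, n: int) -> list[int]:
    # A string seed is hashed deterministically (unaffected by PYTHONHASHSEED).
    rng = random.Random(f"{seed}:{epoch}")
    perm = list(range(n))
    rng.shuffle(perm)
    return perm


def batch_for_step(seed: int, step: int, n: int, batch_size: int) -> list[int]:
    # The batch depends only on (seed, step), never on what this process did before.
    steps_per_epoch = n // batch_size  # drop the ragged final batch of each epoch
    epoch, offset = divmod(step, steps_per_epoch)
    perm = epoch_permutation(seed, epoch, n)  # recomputed per step for clarity
    start = offset * batch_size
    return perm[start:start + batch_size]


def run(seed: int, start_step: int, end_step: int, n: int = 100, batch_size: int = 8):
    return [batch_for_step(seed, s, n, batch_size) for s in range(start_step, end_step)]


uninterrupted = run(seed=0, start_step=0, end_step=30)
# Simulate preemption after step 12, then resumption from a checkpoint at step 12.
preempted_then_resumed = run(seed=0, start_step=0, end_step=12) + run(seed=0, start_step=12, end_step=30)
print(uninterrupted == preempted_then_resumed)  # True
```

The same property is what makes resumption reproducible in general: any randomness consumed at step `k` should be derivable from the seed and `k` alone, not from mutable state accumulated over the run.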

We built Levanter with JAX, Equinox, and Haliax.

Features

  • Distributed Training: We support distributed training on TPUs (and soon, GPUs), including FSDP and tensor parallelism.
  • Cached On-Demand Data Preprocessing: We preprocess corpora online and cache the results, so resumed runs are much faster and subsequent runs are faster still. Levanter starts training as soon as the first part of the cache is complete.
  • Online Visualization: Levanter also provides a feature for visualizing the probability of each token in the validation set during training. This is useful for debugging and for understanding how the model is learning.
  • Export: We support exporting models to the Hugging Face Hub in a format compatible with PyTorch and Transformers, via SafeTensors.
  • Logging: Logging is done with WandB, complete with a fancy online visualization of the validation set during training.
  • Distributed Checkpointing: Distributed checkpointing is supported via Google's TensorStore library. Training can even be resumed on a different number of hosts, though this breaks reproducibility for now.
  • Optimization: Levanter uses Optax for optimization, though our new optimizer, Sophia, is coming to Levanter soon!
  • Stability: The GPT-2 implementation uses the Mistral stability trick to improve stability during training.
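The GPT-2 config later in this README sets `scale_attn_by_inverse_layer_idx: true`. Assuming this flag behaves like the option of the same name in Hugging Face's GPT-2, which divides the attention scores of layer `i` (already scaled by `1/sqrt(d)`) by `i + 1`, the plain-Python sketch below shows its effect on a single query. That this flag is the Mistral trick mentioned above is our assumption, and `attention_weights` is a hypothetical helper, not Levanter's implementation.

```python
import math


def softmax(xs: list[float]) -> list[float]:
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def attention_weights(q, keys, layer_idx, scale_by_inverse_layer_idx=True):
    d = len(q)
    # Standard scaled dot-product scores: (q . k) / sqrt(d)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
    if scale_by_inverse_layer_idx:
        # Divide by (layer_idx + 1): deeper layers get smaller logits,
        # which flattens the softmax.
        scores = [s / (layer_idx + 1) for s in scores]
    return softmax(scores)


q = [2.0, 1.0, 0.0, -1.0]
keys = [[2.0, 1.0, 0.0, -1.0], [0.0, 0.0, 1.0, 0.0], [-2.0, -1.0, 0.0, 1.0]]
shallow = attention_weights(q, keys, layer_idx=0)
deep = attention_weights(q, keys, layer_idx=11)
print([round(w, 3) for w in shallow])
print([round(w, 3) for w in deep])
```

At layer 0 the scaling is a no-op; at layer 11 the same scores are divided by 12, so the distribution over keys is much less peaked.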

Levanter was created by the research engineering team at Stanford's Center for Research on Foundation Models (CRFM). (We're hiring!) You can also find us in the #levanter channel on the unofficial Jax LLM Discord.

Getting Started

Here is a small set of examples to get you started. For more information about the various configuration options, please see the Training Getting Started guide. You can also use --help or poke around other configs to see all the options available to you.

Installing Levanter

After installing JAX with the appropriate configuration for your platform, you can install Levanter with:

pip install levanter

or using the latest version from GitHub:

git clone https://github.com/stanford-crfm/levanter.git
cd levanter
pip install -e .
wandb login  # optional, we use wandb for logging

If you're developing Haliax and Levanter at the same time, you can do something like this:

git clone https://github.com/stanford-crfm/levanter.git
cd levanter
pip install -e .
cd ..
git clone https://github.com/stanford-crfm/haliax.git
cd haliax
pip install -e .
cd ../levanter

Please refer to the Installation Guide for more information on how to install Levanter.

If you're using a TPU, more complete setup documentation is available in the TPU Getting Started guide. GPU support is still in progress; see the CUDA Getting Started guide.

Training a GPT2-nano

As a kind of hello world, here's how you can train a GPT-2 "nano"-sized model on a small dataset.

python -m levanter.main.train_lm --config_path config/gpt2_nano.yaml

# alternatively, if you didn't use -e and are in a different directory
python -m levanter.main.train_lm --config_path gpt2_nano

This will train a GPT2-nano model on the WikiText-103 dataset.

Training a GPT2-small on your own data

You can also change the dataset by changing the data field in the config file. If your dataset is a Hugging Face dataset, you can use the data.id field to specify it:

python -m levanter.main.train_lm --config_path config/gpt2_small.yaml --data.id openwebtext

# optionally, you may specify a tokenizer and/or a cache directory, which may be local or on GCS
python -m levanter.main.train_lm --config_path config/gpt2_small.yaml --data.id openwebtext --data.tokenizer "EleutherAI/gpt-neox-20b" --data.cache_dir "gs://path/to/cache/dir"
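The cache directory holds the results of the cached on-demand preprocessing described under Features. Below is a minimal sketch of that idea in plain Python, assuming a hypothetical `whitespace_tokenize` function and one JSON file per shard; Levanter's real cache format and tokenizers differ. Each shard is tokenized at most once, later runs read it back, and shards are yielded as soon as they are ready, so a consumer could begin training on the first shard before the rest are processed.

```python
import json
import os
import tempfile
from typing import Callable, Iterator


def cached_shards(
    shards: dict[str, list[str]],
    cache_dir: str,
    tokenize: Callable[[str], list[str]],
) -> Iterator[list[list[str]]]:
    """Yield tokenized shards, tokenizing only those missing from cache_dir."""
    os.makedirs(cache_dir, exist_ok=True)
    for name, docs in shards.items():
        path = os.path.join(cache_dir, f"{name}.json")
        if os.path.exists(path):  # cache hit: skip preprocessing entirely
            with open(path) as f:
                yield json.load(f)
            continue
        tokenized = [tokenize(doc) for doc in docs]
        tmp_path = path + ".tmp"
        with open(tmp_path, "w") as f:
            json.dump(tokenized, f)
        os.replace(tmp_path, path)  # atomic rename: no half-written shard after a crash
        yield tokenized  # available to the consumer before later shards are processed


calls: list[str] = []


def whitespace_tokenize(text: str) -> list[str]:
    # Hypothetical stand-in for a real tokenizer; records each call.
    calls.append(text)
    return text.split()


shards = {"shard0": ["hello world", "a b c"], "shard1": ["x y"]}
with tempfile.TemporaryDirectory() as cache_dir:
    first_run = list(cached_shards(shards, cache_dir, whitespace_tokenize))
    calls_after_first = len(calls)
    second_run = list(cached_shards(shards, cache_dir, whitespace_tokenize))
print(first_run == second_run, calls_after_first, len(calls))  # True 3 3
```

The second pass makes no tokenizer calls, which is why resumed and repeated runs skip most of the preprocessing cost.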

If instead your data is a list of URLs, you can use the data.train_urls and data.validation_urls fields to specify them. Data URLs can be local files, GCS files, http(s) URLs, or anything else fsspec supports. Levanter (really, fsspec) will automatically decompress .gz and .zstd files, and likely other compression formats too.

python -m levanter.main.train_lm --config_path config/gpt2_small.yaml --data.train_urls '["https://path/to/train/data_*.jsonl.gz"]' --data.validation_urls '["https://path/to/val/data_*.jsonl.gz"]'
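If you need to produce such files yourself, the sketch below writes and reads back a gzipped JSONL shard using only the Python standard library. It assumes each line is a JSON object whose document text is stored under a `"text"` key; check the data configuration options for the key your Levanter version expects. The helper names and the file name `data_0.jsonl.gz` are illustrative.

```python
import gzip
import json
import os
import tempfile


def write_jsonl_gz(path: str, texts: list[str], text_key: str = "text") -> None:
    # One JSON object per line; gzip.open in text mode handles compression.
    with gzip.open(path, "wt", encoding="utf-8") as f:
        for text in texts:
            f.write(json.dumps({text_key: text}) + "\n")


def read_jsonl_gz(path: str, text_key: str = "text") -> list[str]:
    with gzip.open(path, "rt", encoding="utf-8") as f:
        return [json.loads(line)[text_key] for line in f if line.strip()]


out_dir = tempfile.mkdtemp()  # left on disk so the file can be used afterwards
train_path = os.path.join(out_dir, "data_0.jsonl.gz")
write_jsonl_gz(train_path, ["First training document.", 'Second one, with "quotes".'])
roundtrip = read_jsonl_gz(train_path)
print(train_path)
```

Since local files are among the supported URL types, a file like this could then be passed as, for example, `--data.train_urls '["/path/to/data_0.jsonl.gz"]'`.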

Customizing a Config File

You can modify the config file to change the model, the dataset, the training parameters, and more. Here's the gpt2_small.yaml file:

data:
  train_urls:
      - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz"
  validation_urls:
      - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz"
  cache_dir: "gs://pubmed-mosaic/tokenized/openwebtext/"
model:
  gpt2:
    hidden_dim: 768
    num_heads: 12
    num_layers: 12
    seq_len: 1024
    gradient_checkpointing: true
    scale_attn_by_inverse_layer_idx: true
trainer:
  wandb:
    project: "levanter"
    tags: [ "openwebtext", "gpt2"]

  mp: p=f32,c=bfloat16
  model_axis_size: 1
  per_device_parallelism: 4

  train_batch_size: 512
optimizer:
  learning_rate: 6E-4
  weight_decay: 0.1
  min_lr_ratio: 0.1
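A few values in this config are shorthand whose consequences are easy to misjudge. The plain-Python sketch below (not Levanter code) expands the `{1..128}` shard pattern, computes gradient-accumulation steps from `train_batch_size` and `per_device_parallelism`, and evaluates a cosine schedule that ends at `learning_rate * min_lr_ratio`. Assumptions: braces expand like a shell numeric range; `per_device_parallelism` counts examples per device per microbatch, with every device outside the model axis doing data parallelism; and the schedule is cosine decay without warmup. All helper names are hypothetical.

```python
import math
import re

# Matches a numeric brace range such as "{1..128}".
_RANGE = re.compile(r"\{(\d+)\.\.(\d+)\}")


def expand_numeric_braces(pattern: str) -> list[str]:
    """Expand numeric ranges like {1..3} (no zero-padding, no comma lists)."""
    match = _RANGE.search(pattern)
    if match is None:
        return [pattern]
    lo, hi = int(match.group(1)), int(match.group(2))
    head, tail = pattern[: match.start()], pattern[match.end():]
    urls = []
    for i in range(lo, hi + 1):
        # Recurse so any further ranges in the tail are expanded too.
        urls.extend(expand_numeric_braces(f"{head}{i}{tail}"))
    return urls


def grad_accum_steps(train_batch_size: int, per_device_parallelism: int,
                     num_devices: int, model_axis_size: int = 1) -> int:
    """Microbatches per optimizer step, assuming non-model-axis devices are data-parallel."""
    data_axis_size = num_devices // model_axis_size
    microbatch = per_device_parallelism * data_axis_size
    if train_batch_size % microbatch != 0:
        raise ValueError("train_batch_size must be divisible by the microbatch size")
    return train_batch_size // microbatch


def cosine_lr(step: int, total_steps: int, learning_rate: float, min_lr_ratio: float) -> float:
    """Cosine decay from learning_rate to learning_rate * min_lr_ratio (no warmup)."""
    min_lr = learning_rate * min_lr_ratio
    progress = min(step, total_steps) / total_steps
    return min_lr + (learning_rate - min_lr) * 0.5 * (1 + math.cos(math.pi * progress))


train_urls = expand_numeric_braces(
    "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz"
)
print(len(train_urls), train_urls[0])  # 128 shard URLs
print(grad_accum_steps(512, 4, num_devices=8))  # 16 microbatches on 8 devices
print(cosine_lr(0, 10_000, 6e-4, 0.1), cosine_lr(10_000, 10_000, 6e-4, 0.1))
```

Under these assumptions, the batch size of 512 would need 16 accumulation steps on 8 devices but none on 128, and the learning rate would decay from 6e-4 to 6e-5.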

Other Architectures

Currently, we support GPT-2, Backpacks and MosaicML's MPT architectures. We plan to add more in the future.

A Tiny Backpack Model

python -m levanter.main.train_lm --config_path config/backpack_nano.yaml

Continued Pretraining with MPT

python -m levanter.main.train_lm --config_path config/mpt_7b_continued.yaml

Distributed and Cloud Training

Training on a TPU Cloud VM

Please see the TPU Getting Started guide for more information on how to set up a TPU Cloud VM and run Levanter there.

Training with CUDA

Please see the CUDA Getting Started guide for more information on how to set up a CUDA environment and run Levanter there.

Contributing

We welcome contributions! Please see CONTRIBUTING.md for more information.

License

Levanter is licensed under the Apache License, Version 2.0. See LICENSE for the full license text.

Contributors

dlwh, ivan-zhou, mkly, patrick-kidger, raisin
