
axlearn's Introduction

The AXLearn Library for Deep Learning

This library is under active development and the API is subject to change.

Table of Contents

Introduction: What is AXLearn?
Getting Started: Getting up and running with AXLearn.
Concepts: Core concepts and design principles.
CLI User Guide: How to use the CLI.
Infrastructure: Core infrastructure components.

Introduction

AXLearn is a library built on top of JAX and XLA to support the development of large-scale deep learning models.

AXLearn takes an object-oriented approach to the software engineering challenges that arise from building, iterating, and maintaining models. The configuration system of the library lets users compose models from reusable building blocks and integrate with other libraries such as Flax and Hugging Face transformers.
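As a rough sketch of how such composition looks in practice, the snippet below assumes the default_config()/set()/instantiate() pattern and a Linear layer with input_dim/output_dim fields; these names are illustrative assumptions rather than a verified excerpt of the current API.

# Hypothetical sketch of AXLearn-style config composition; layer and field names are assumptions.
from axlearn.common.layers import Linear

cfg = Linear.default_config().set(
    name="proj",
    input_dim=1024,
    output_dim=4096,
)
layer = cfg.instantiate(parent=None)  # build the layer from its config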

AXLearn is built to scale. It supports the training of models with up to hundreds of billions of parameters across thousands of accelerators at high utilization. It is also designed to run on public clouds and provides tools to deploy and manage jobs and data. Built on top of GSPMD, AXLearn adopts a global computation paradigm to allow users to describe computation on a virtual global computer rather than on a per-accelerator basis.
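To make the global-computation idea concrete, the following is a minimal plain-JAX/GSPMD sketch (not AXLearn-specific code): the arrays and the matmul are written once against a logically global view, and the compiler decides the per-device partitioning.

# Minimal JAX sketch of the global computation view; not AXLearn code.
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Arrange all available devices into a 1-D mesh with a single "data" axis.
devices = mesh_utils.create_device_mesh((jax.device_count(),))
mesh = Mesh(devices, axis_names=("data",))

# Batch size is a multiple of the device count so the batch axis shards evenly.
x = jnp.ones((jax.device_count() * 4, 128))
w = jnp.ones((128, 16))

# Shard the batch dimension of x across the "data" axis; w stays replicated.
x = jax.device_put(x, NamedSharding(mesh, P("data", None)))

# The matmul is written against the global arrays; XLA emits the per-device program.
y = jax.jit(lambda x, w: x @ w)(x, w)
print(y.shape)  # (jax.device_count() * 4, 16)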

AXLearn supports a wide range of applications, including natural language processing, computer vision, and speech recognition, and contains baseline configurations for training state-of-the-art models.

Please see Concepts for more details on the core components and design of AXLearn, or Getting Started if you want to get your hands dirty.

axlearn's People

Contributors

alex8937, altimofeev, amcw7777, apghml, ethanlm, fnan, gyin94, haijingfu, jianyuwangv, jiarui-lu2, jinhaolei, jiya-zhang, kelvin-zou, madrob, markblee, ruomingp, snehanb, swiseman, taolei87, tgunter, tombstone, tuzhucheng, weiliu89, wwu137, xianzhidu, ya5ut, yqwangustc, zbwglory, zhiyun, zhzhyi


axlearn's Issues

BUG: axlearn.common.utils.as_tensor calls .numpy() which doesn't work with bfloat16

I noticed this problem when I was developing Medusa+, which depends on axlearn. Here is the raw trace of the error when I am converting an Ajax tensor in bfloat16 to torch.tensor: https://rio.apple.com/projects/ai-medusa-plus/pipeline-specs/ai-medusa-plus-unit_tests/pipelines/f4345942-bddc-47ba-ba51-f3c1019290d3/log#L718-L738

The problem is that numpy does not natively support bfloat16, so if the source tensor is bfloat16, the call to .numpy() raises a TypeError.

The following script reveals this problem:

import torch

# float32: conversion to numpy succeeds.
x = torch.rand((1,), dtype=torch.float32)
print(x.numpy())

# float16: conversion to numpy succeeds.
x = torch.rand((1,), dtype=torch.float16)
print(x.numpy())

# bfloat16: numpy has no native bfloat16 dtype, so .numpy() raises TypeError.
x = torch.rand((1,), dtype=torch.bfloat16)
print(x.numpy())

The first two calls to .numpy() succeed; the last one fails:

22:36 $ python3 medusa_plus/numpy_bf16.py
[0.686854]
[0.6177]
Traceback (most recent call last):
  File "/mnt/medusa-plus/medusa_plus/numpy_bf16.py", line 21, in <module>
    print(x.numpy())
TypeError: Got unsupported ScalarType BFloat16
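
A possible caller-side workaround (not a fix inside as_tensor itself) is to upcast bfloat16 tensors to float32 before the numpy conversion:

import torch

x = torch.rand((1,), dtype=torch.bfloat16)
# numpy has no native bfloat16 dtype, so go through float32 first.
print(x.to(torch.float32).numpy())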

configure gcloud when configuring axlearn

I ran axlearn gcp config activate to activate a project in us-west1. However, this doesn't change the config for my gcloud, as is evident when I run gcloud config list.

If I don't manually change the gcloud config, I run into this error when using axlearn gcp dataflow start ...:

E0319 20:55:05.981639 140440028389440 config.py:118] Unknown settings for project=<project> and zone=us-west1-a; You may want to configure this project first; Please refer to the docs for details.

Missing googleapiclient as dependency

I get ModuleNotFoundError: No module named 'googleapiclient' when calling:

from axlearn.cloud.gcp.vm import ..

It seems that google-api-python-client needs to be added as a dependency of axlearn.

axlearn on GPU started failing during init after upgrade

This is the error message I see when launching like this:

timeout -k 60s 900s python3 -m axlearn.common.launch_trainer_main --module=gke_fuji --config=fuji-7B-b512-fsdp8 --trainer_dir=/tmp/test_trainer --data_dir=gs://axlearn-public/tensorflow_datasets --jax_backend=gpu --num_processes=8 --distributed_coordinator=stoelinga-may13-1-j-0-0.stoelinga-may13-1 --process_id=0 --trace_at_steps=25

Error message:

2024-05-13 16:17:05.732984: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2211] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/tmp/axlearn/axlearn/common/launch_trainer_main.py", line 16, in <module>
    app.run(main)
  File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/tmp/axlearn/axlearn/common/launch_trainer_main.py", line 10, in main
    launch.setup()
  File "/tmp/axlearn/axlearn/common/launch.py", line 92, in setup
    setup_spmd(
  File "/tmp/axlearn/axlearn/common/utils_spmd.py", line 118, in setup
    jax.distributed.initialize(**init_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/distributed.py", line 196, in initialize
    global_state.initialize(coordinator_address, num_processes, process_id,
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/distributed.py", line 72, in initialize
    default_coordinator_bind_address = '[::]:' + coordinator_address.rsplit(':', 1)[1]
IndexError: list index out of range
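
A likely cause, judging from the traceback (this is an assumption, not a confirmed diagnosis): the --distributed_coordinator value has no :port suffix, so the rsplit(':', 1) inside jax.distributed.initialize yields a single-element list and indexing [1] fails. This can be reproduced in isolation:

# The address below is copied from the launch command above; note it has no ":<port>".
coordinator_address = "stoelinga-may13-1-j-0-0.stoelinga-may13-1"
print(coordinator_address.rsplit(":", 1))     # ['stoelinga-may13-1-j-0-0.stoelinga-may13-1']
print(coordinator_address.rsplit(":", 1)[1])  # IndexError: list index out of range

If that is indeed the cause, appending an explicit port (e.g. host:1234) to --distributed_coordinator should get past this error.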
