Comments (6)

JohnnyOpcode commented on August 28, 2024

With Colab Pro, the default TPU lib (and JAX) is now at 0.3.25. I jumped through these hoops as well and have run with:

!pip install mesh-transformer-jax/ jax==0.3.15 tensorflow==2.8.2 chex==0.1.4 jaxlib==0.3.15

Your mileage may vary.
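Since Colab ships newer JAX by default, it can help to confirm that the pins above actually landed before running anything. This is a minimal sketch, not part of mesh-transformer-jax; `PINS` and `find_mismatches` are hypothetical names, and it assumes Python 3.8+ (for `importlib.metadata`) and exact `==` pins only.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Callable, Dict, Optional

# Exact pins taken from the pip command above (hypothetical helper data).
PINS = {
    "jax": "0.3.15",
    "jaxlib": "0.3.15",
    "tensorflow": "2.8.2",
    "chex": "0.1.4",
}


def find_mismatches(
    pins: Dict[str, str],
    lookup: Callable[[str], str] = version,
) -> Dict[str, Optional[str]]:
    """Return {package: installed version, or None if missing} for unmet pins."""
    mismatches: Dict[str, Optional[str]] = {}
    for name, wanted in pins.items():
        try:
            installed: Optional[str] = lookup(name)
        except PackageNotFoundError:
            installed = None  # distribution is not installed at all
        if installed != wanted:
            mismatches[name] = installed
    return mismatches


if __name__ == "__main__":
    problems = find_mismatches(PINS)
    for name, installed in problems.items():
        print(f"{name}: wanted {PINS[name]}, found {installed}")
    if not problems:
        print("all pins satisfied")
```

Run it in the same interpreter (or Colab runtime) that will import JAX; an empty result means every pin is met. The `lookup` parameter exists only so the check can be exercised without the real packages installed.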

Johnny

from mesh-transformer-jax.

mosmos6 commented on August 28, 2024

@rinapch Worked perfectly for me. Thank you very much.

mosmos6 commented on August 28, 2024

That said, it worked perfectly for fine-tuning but not for inference on Colab (it caused an optax error).
To set up the model, I had to revert the requirements to

numpy~=1.19.5
tqdm~=4.45.0
wandb>=0.11.2
einops~=0.3.0
requests~=2.25.1
fabric~=2.6.0
optax==0.0.6
git+https://github.com/deepmind/dm-haiku
git+https://github.com/EleutherAI/lm-evaluation-harness@c406a62047
ray[default]==1.4.1
jax~=0.2.12
Flask~=1.1.2
cloudpickle~=1.3.0
tensorflow-cpu~=2.5.0
google-cloud-storage~=1.36.2
transformers
smart_open[gcs]
func_timeout
ftfy
fastapi
uvicorn
lm_dataformat
pathy

and

!pip install chex==0.1.2
!pip install jaxlib==0.1.68
!pip install dm-haiku==0.0.5

Just as a note.
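Several lines in the list above use `~=`, the PEP 440 compatible-release operator, rather than an exact pin: `numpy~=1.19.5` means `>=1.19.5` and `==1.19.*`. The sketch below is a simplified illustration of that rule for purely numeric dotted versions only; `parse_version` and `compatible_release` are hypothetical helpers, and the `packaging` library's `SpecifierSet` implements the full specification.

```python
from typing import Tuple


def parse_version(version: str) -> Tuple[int, ...]:
    """Split a purely numeric dotted version such as '1.19.5' into integers."""
    return tuple(int(part) for part in version.split("."))


def compatible_release(installed: str, spec: str) -> bool:
    """Simplified '~=' check: '~=X.Y.Z' means '>=X.Y.Z' and '==X.Y.*'."""
    spec_parts = parse_version(spec)
    if len(spec_parts) < 2:
        raise ValueError("'~=' needs at least two version components")
    inst_parts = parse_version(installed)
    # Pad both with zeros so that (1, 19) compares like (1, 19, 0).
    width = max(len(spec_parts), len(inst_parts))
    padded_inst = inst_parts + (0,) * (width - len(inst_parts))
    padded_spec = spec_parts + (0,) * (width - len(spec_parts))
    # Lower bound: the installed version must be >= the spec.
    if padded_inst < padded_spec:
        return False
    # Upper bound: all but the last spec component must match exactly.
    prefix = spec_parts[:-1]
    return padded_inst[: len(prefix)] == prefix
```

Under this rule `jax~=0.2.12` admits 0.2.18 or 0.2.20 but not any 0.3.x release, while `tensorflow-cpu~=2.5.0` does not admit 2.6.0.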

AidanShipperley commented on August 28, 2024

Thank you so much for this post; it helped me resolve all of my dependency issues. I have never worked with Poetry before, but I was able to get a model training in a conda environment using only install commands.

If anybody is interested, I wrote out the steps I took from scratch that are currently working based on my test run.

-- First, install conda on the TPU VM

mkdir conda_install
cd conda_install
sudo apt-get update
sudo apt-get install wget
wget https://repo.anaconda.com/archive/Anaconda3-2022.10-Linux-x86_64.sh
bash Anaconda3-2022.10-Linux-x86_64.sh

-- Update path to include conda

export PATH=~/anaconda3/bin:$PATH

-- Create an env with mamba and Python 3.8

conda create -n gpt -c conda-forge mamba python==3.8

-- Close and reopen the terminal, then SSH back in

gcloud compute tpus tpu-vm ssh YOUR_TPU_NAME --zone YOUR_ZONE_NAME

-- Leave the base env

conda deactivate 

-- Activate the new env

conda activate gpt

-- Install requirements available through conda first

mamba install -c conda-forge numpy==1.19.5 tqdm==4.45.0 einops==0.3.0 requests==2.25.1 fabric==2.6.0 optax==0.0.9 dm-haiku==0.0.5 jax==0.2.18 cloudpickle==1.3.0 tensorflow-cpu==2.6.0 google-cloud-storage==1.36.2 transformers==4.16.2 smart_open==5.2.1 ftfy==6.1.1 pathy==0.10.1 func_timeout==4.3.5

-- Install remaining requirements not available through conda with pip

pip install ray[default]==1.4.1 wandb==0.13.7 chex==0.0.5 lm-dataformat==0.0.20 typing-extensions==4.2.0 protobuf==3.19.5

-- NOTE: pip will report an error that tensorflow 2.6.0 is not compatible with typing-extensions 4.2.0. This is fine; ignore it.

-- JAX 0.2.12 does NOT work with TPUs anymore, but 0.2.18 or 0.2.20 can be used instead

pip install "jax[tpu]==0.2.18" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

-- If you have issues with protobuf (they may come from importing wandb), run this

python3 -m pip uninstall protobuf
python3 -m pip install protobuf==3.19.5

-- Finally, you can run this and fine-tune your model

cd ./mesh-transformer-jax/
python3 device_train.py --config=./configs/YOUR_CONFIG_NAME.json --tune-model-path=gs://YOUR_BUCKET_NAME/step_383500/

mosmos6 commented on August 28, 2024

@JohnnyOpcode How did you run inference with JAX 0.3.15? I think it only runs with 0.2.12.

JohnnyOpcode commented on August 28, 2024

I was using Colab Pro (paid) and experimented with different library versions via pip. The key takeaway is compatibility with the TPUv2 ASIC. I'll try to find some time to go through those motions again and come up with a newer working requirements.txt for everybody.

Python sucks btw. Just like JS and TS. Too many brittle dependencies, but it does create lots of BS positions and salaries.
