
Comments (9)

dlwh commented on August 15, 2024

Oof, thanks for the report. We don't currently use JAX 0.4.28 ourselves, only 0.4.26, so there could be an issue there. Given that the hang is in _value, I'm guessing it's an incompatibility with the latest JAX. You could also try --model.use_flash_attention=false to see whether something went wrong with the TPU splash attention kernel. (We started using it recently, so it could be incompatible with the latest version of the kernel.)
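For readers mapping the command-line flag onto a YAML config: a dotted override like --model.use_flash_attention=false addresses the same nested key as use_flash_attention under the model: section. The toy parser below is only a sketch of that mapping on a plain dict; apply_override is a hypothetical helper, not Levanter's actual config loader, and it assumes true/false strings are read case-insensitively.

```python
def apply_override(config: dict, override: str) -> dict:
    """Apply one '--a.b.c=value' style override to a nested dict (toy parser, hypothetical)."""
    dotted_key, raw_value = override.removeprefix("--").split("=", 1)
    *parents, leaf = dotted_key.split(".")
    node = config
    for key in parents:
        node = node.setdefault(key, {})  # walk (or create) the nested sections
    # Interpret true/false case-insensitively; leave every other value as a string.
    node[leaf] = {"true": True, "false": False}.get(raw_value.lower(), raw_value)
    return config


config = {"model": {"type": "llama", "use_flash_attention": True}}
apply_override(config, "--model.use_flash_attention=false")
print(config)  # {'model': {'type': 'llama', 'use_flash_attention': False}}
```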

Can you also verify the commit hash? We fixed a couple of things recently and want to be sure, though your logs don't look consistent with either of those bugs.

The two issues: a timeout while saving checkpoints inside TensorStore could cause either a hang or a crash, and I missed an error case when processing data online. You're not seeing any messages about waiting for chunks?

I will try to repro soon!

Could you also try our setup scripts, just to see? https://github.com/stanford-crfm/levanter/blob/main/docs/Getting-Started-TPU-VM.md (They're kinda janky, but they work pretty well.) They don't currently use queued-resources, but you could at least run infra/helpers/setup-tpu-vm.sh on each machine.

Good to know about pip non-determinism. I didn't realize that! We should update to something more sane then.

For whatever reason, the TPU splash attention kernel only works with head dim 128, so we fall back to my "plain JAX" version if that fails.
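The fallback described above amounts to a try-then-fall-back dispatch. This is a minimal sketch, assuming the fast kernel signals an unsupported shape by raising; it uses NumPy instead of JAX so it runs anywhere, and every name here (splash_attention_stub, plain_attention, SPLASH_HEAD_DIM) is hypothetical rather than Levanter's actual code.

```python
import numpy as np

SPLASH_HEAD_DIM = 128  # per the comment above: the splash kernel only handled head dim 128


def plain_attention(q: np.ndarray, k: np.ndarray, v: np.ndarray) -> np.ndarray:
    """Reference scaled dot-product attention; q, k, v have shape [seq, head_dim]."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)  # subtract row max for numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)  # softmax over keys
    return weights @ v


def splash_attention_stub(q, k, v):
    """Hypothetical stand-in for the TPU kernel: rejects head dims it cannot handle."""
    if q.shape[-1] != SPLASH_HEAD_DIM:
        raise NotImplementedError(f"splash kernel needs head_dim={SPLASH_HEAD_DIM}, got {q.shape[-1]}")
    return plain_attention(q, k, v)  # placeholder math; a real kernel would run on the TPU


def attention(q, k, v):
    """Try the fast path first; fall back to the plain implementation if it fails."""
    try:
        return splash_attention_stub(q, k, v), "splash"
    except NotImplementedError:
        return plain_attention(q, k, v), "plain"
```

Catching a specific exception (rather than checking the shape up front) keeps the dispatcher unaware of the kernel's exact constraints, which is convenient if those constraints change between kernel versions.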


rjpower commented on August 15, 2024

FWIW, I tried disabling flash attention in the model with:

model:
  use_flash_attention: False
  type: llama

But I still see references to flash attention in the logs, e.g.

 flash_attention.cc:1872] number of matches: 0


dlwh commented on August 15, 2024

I'm guessing disabling FA didn't fix in that case?

I've never seen the flash_attention.cc thing, but if I had to guess it's a compiler optimization pass in XLA?


rjpower commented on August 15, 2024

Thanks for the quick reply. Yeah, disabling flash didn't seem to help -- I tried editing the code just in case it was a configuration issue but got the same result.

You're right, the message is probably from an XLA pass. I wouldn't be surprised if XLA is trying to discover and automatically turn on flash attention, or something similar. I tried switching JAX back to 0.4.26, but I still saw it stalling. It's possible my environment is messed up on the VMs at this point, though. I'm trying the infra script now to get a clean set of TPUs and see if that helps.


dlwh commented on August 15, 2024

(If you want to do a live debug, ping my email, which is my GH handle @ stanford.)


rjpower commented on August 15, 2024

Sorry for the delay. Using the infra setup worked! Thanks for the tip:

34.125.20.95: 2024-05-15T21:52:45 - 0 - levanter.eval - eval.py:120 - INFO :: eval loss: 7.331
34.125.183.145: 2024-05-15T21:52:45 - 3 - levanter.eval - eval.py:120 - INFO :: eval loss: 7.331
34.125.183.145: 2024-05-15T21:52:45 - 3 - levanter.eval - eval.py:131 - INFO ::  loss: 7.331
34.125.20.95: 2024-05-15T21:52:45 - 0 - levanter.eval - eval.py:131 - INFO ::  loss: 7.331
34.16.196.50: 2024-05-15T21:52:45 - 1 - levanter.eval - eval.py:120 - INFO :: eval loss: 7.331
34.125.242.121: 2024-05-15T21:52:45 - 2 - levanter.eval - eval.py:120 - INFO :: eval loss: 7.331
34.125.242.121: 2024-05-15T21:52:45 - 2 - levanter.eval - eval.py:131 - INFO ::  loss: 7.331
34.16.196.50: 2024-05-15T21:52:45 - 1 - levanter.eval - eval.py:131 - INFO ::  loss: 7.331
train:   0%|          | 201/100000 [13:39<99:38:53,  3.59s/it, loss=6.86]
train:   0%|          | 201/100000 [13:39<99:38:53,  3.59s/it, loss=6.86]
train:   0%|          | 201/100000 [13:39<99:38:55,  3.59s/it, loss=6.86]

Now I'm curious: I didn't realize the progress bar was only updating after 50 steps. I'm going to retry with the original slice to see whether I just wasn't waiting long enough... :/


rjpower commented on August 15, 2024

Hrm, seems to be working on the original slice with 0.4.26 and 0.4.28. Sorry for the noise!

It appears the progress bar is only updated when we transition to eval. I had originally adjusted steps_per_eval to reduce the amount of time spent on eval, which probably exacerbated the issue: with it set to a high number, it looked like the system was hanging. I suspect everything was actually working fine, but because the computation is so TPU-bound, every time I captured a stack trace it was waiting on PJRT.
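To put rough numbers on how long the console can look frozen: taking the 3.59 s/it rate from the progress bar above, and assuming progress lines only surface at each eval, the silent stretch is about steps_per_eval × seconds per step. The helper names and the steps_per_eval values below are illustrative only.

```python
def format_hms(seconds: float) -> str:
    """Format a duration as H:MM:SS, rounding to the nearest second."""
    total = int(round(seconds))
    hours, rem = divmod(total, 3600)
    minutes, secs = divmod(rem, 60)
    return f"{hours}:{minutes:02d}:{secs:02d}"


def silent_window(steps_per_eval: int, sec_per_step: float) -> str:
    """Time between visible updates if output only appears at each eval."""
    return format_hms(steps_per_eval * sec_per_step)


def eta(total_steps: int, done_steps: int, sec_per_step: float) -> str:
    """Remaining wall-clock time at a constant per-step rate."""
    return format_hms((total_steps - done_steps) * sec_per_step)


if __name__ == "__main__":
    rate = 3.59  # s/it, as shown in the progress bar above
    for steps_per_eval in (200, 1000):  # illustrative values, not the ones used in this thread
        print(steps_per_eval, silent_window(steps_per_eval, rate))  # 0:11:58, then 0:59:50
    print("eta", eta(100_000, 201, rate))  # 99:31:18
```

The ETA here comes out a few minutes below the 99:38:53 that the progress bar printed, presumably because tqdm computes it from its smoothed, unrounded rate (about 3.595 s/it in that snapshot) rather than the displayed 3.59.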

It looks like the code should update the progress bar on each step, but I'm guessing there's some sort of buffering issue (with Ray?). When I added extra logging to the hook callback, I saw all the calls show up in one big dump at the end, once eval started:

34.16.214.62: Running hooks:  5 1
34.16.228.1: Running hooks:  5 1
34.16.228.1: Running hooks:  5 2
34.16.149.229: Running hooks:  5 1
34.16.149.229: Running hooks:  5 2
34.16.149.229: Running hooks:  5 3
34.16.149.229: Running hooks:  5 4

I'll quickly try to see if some generous use of stdout.flush will help at all.
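A small, self-contained sketch of why output can arrive "in one big dump": when stdout is not a terminal (for example, a pipe read by an ssh-based tool), Python block-buffers it, so lines sit in memory until the buffer fills or something flushes it. Here an in-memory io.BytesIO stands in for that pipe, layered under the same TextIOWrapper/BufferedWriter stack Python uses for stdout. The helper names are illustrative only and are not part of Levanter.

```python
import io
import sys


def make_stream(line_buffering: bool):
    """Build a text stream over an in-memory 'pipe', layered the way stdout is."""
    raw = io.BytesIO()  # stands in for the pipe an ssh tool would read from
    text = io.TextIOWrapper(
        io.BufferedWriter(raw),
        encoding="utf-8",
        newline="\n",  # no newline translation, so the bytes are the same on every OS
        line_buffering=line_buffering,
    )
    return raw, text


def log_step(stream, step: int, loss: float, flush: bool = False) -> None:
    """Write one progress line; optionally force it out of the buffers immediately."""
    stream.write(f"step {step} loss={loss:.2f}\n")
    if flush:
        stream.flush()


def enable_line_buffering() -> None:
    """Make sys.stdout flush on every newline (Python 3.7+), if the stream supports it."""
    if hasattr(sys.stdout, "reconfigure"):
        sys.stdout.reconfigure(line_buffering=True)


if __name__ == "__main__":
    raw, text = make_stream(line_buffering=False)
    log_step(text, 201, 6.86)
    print("block-buffered, before flush:", raw.getvalue())  # b'' : nothing reached the 'pipe' yet
    text.flush()
    print("after flush:", raw.getvalue())  # b'step 201 loss=6.86\n'
```

Running the whole process with python -u, or with the environment variable PYTHONUNBUFFERED=1, has a similar effect to flushing on every write.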


dlwh commented on August 15, 2024

Yeah, it does update after each step, but it really depends on which ssh-y tool you're using. gcloud's TPU ssh updates continuously, but IIRC pdsh doesn't? (I used to use pdsh but stopped at some point.) Probably just buffering, like you say.

And no worries about noise!


rjpower commented on August 15, 2024

Thanks for the help! I sent a PR to remove the tip about pdsh from the getting-started guide. It's nice to use, but in the interest of having something that works reliably... :)

