Comments (10)

gante commented on June 15, 2024

@danielhanchen the inv_freq buffer is cast along with the rest of the model by .to(), e.g.

from transformers import AutoModelForCausalLM
import torch

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = model.to(device="cuda", dtype=torch.bfloat16)
# .to() casts floating-point buffers as well as parameters, so this shows the bfloat16 dtype
print(model.model.layers[0].self_attn.rotary_emb.inv_freq.dtype)

On Llama and Gemma that's not a problem, since we recently updated the code to cast inv_freq to float() before it is used to compute sin and cos (e.g. here). However, other RoPE models, such as Mistral, have yet to receive the same treatment.

We'll gladly take PRs to fix it ;) We will be touching the other RoPE models soon anyway, to migrate them to a Llama-like structure (which, unlike the structure of the other models, is compatible with torch.compile).
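
A minimal, self-contained sketch of the upcast-before-sin/cos pattern described in this comment, using a hypothetical `ToyRotaryEmbedding` (a simplification for illustration, not the transformers class, and it avoids downloading the 7B checkpoint). It assumes a Llama-style non-persistent `inv_freq` buffer and shows that `.to(dtype=torch.bfloat16)` casts that buffer, while the explicit `.float()` in `forward` keeps the angle and cos/sin arithmetic in float32. Note that the upcast cannot restore precision already lost when the buffer's values were rounded to bfloat16.

```python
import torch
from torch import nn


class ToyRotaryEmbedding(nn.Module):
    """Hypothetical, simplified rotary embedding (not the transformers implementation)."""

    def __init__(self, dim: int = 64, base: float = 10000.0):
        super().__init__()
        # Computed in float32 from an int64 arange, mirroring the Llama formula.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
        # Non-persistent: excluded from state_dict, but still cast/moved by .to().
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, x: torch.Tensor, position_ids: torch.Tensor):
        # Explicit upcast: angle arithmetic runs in float32 even if the buffer is bf16.
        inv_freq = self.inv_freq.float()
        freqs = position_ids.float()[:, None] * inv_freq[None, :]  # (seq, dim // 2)
        emb = torch.cat((freqs, freqs), dim=-1)                    # (seq, dim)
        # Only the final cos/sin are cast back to the activation dtype.
        return emb.cos().to(x.dtype), emb.sin().to(x.dtype)


rope = ToyRotaryEmbedding().to(dtype=torch.bfloat16)
print(rope.inv_freq.dtype)  # the buffer was cast by .to()

x = torch.zeros(1, dtype=torch.bfloat16)
cos, sin = rope(x, torch.arange(8))
print(cos.dtype, tuple(cos.shape))
```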

from transformers.

ArthurZucker commented on June 15, 2024

Do you want to open a PR to propagate the changes we made to Llama and Gemma?

ArthurZucker commented on June 15, 2024

cc @gante

danielhanchen commented on June 15, 2024

@avnermay I'm not entirely certain, but I think inv_freq is always calculated in float32. For example, in Gemma:

self.inv_freq = 1.0 / (
self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim))

And for Llama:

inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))

Downcasting only applies to matrix multiplications and to explicit downcasts, such as the one I found in Keras.

I haven't run the code to confirm, but it would be great if you could print the dtype during a fine-tuning run to check whether inv_freq is actually bfloat16.
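
One way to check the distinction drawn in this comment without a full fine-tuning run is a minimal CPU sketch, assuming a hypothetical `ToyRope` module that holds an `inv_freq` buffer built with the Llama formula above. Under `torch.autocast` the matrix multiplication produces bfloat16, but the stored buffer stays float32; only casting the module itself (here with `.to(torch.bfloat16)`, as some bf16 training setups may do) changes the buffer's dtype.

```python
import torch
from torch import nn


class ToyRope(nn.Module):
    # Hypothetical minimal module holding a Llama-style inv_freq buffer.
    def __init__(self, dim: int = 64, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)


rope = ToyRope()
positions = torch.arange(16, dtype=torch.float32)[:, None]  # shape (16, 1)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    # matmul is on autocast's lower-precision op list, so its output is bfloat16...
    freqs = positions @ rope.inv_freq[None, :]
    # ...but autocast does not rewrite the stored buffer.
    buffer_dtype = rope.inv_freq.dtype

print(freqs.dtype, buffer_dtype)

# Casting the module is different: the buffer itself becomes bfloat16.
rope.to(torch.bfloat16)
print(rope.inv_freq.dtype)
```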

danielhanchen commented on June 15, 2024

@gante Whoops, sorry, just saw this. Apologies!

Oh, fair points! Hmm, is there some sort of lock-in mechanism that prevents the conversion from happening? Maybe an override mechanism, i.e. overriding tensor.to itself?

avnermay commented on June 15, 2024

Why not use the approach taken by the other models, which force inv_freq to be float32? The key is to avoid cases where cos and sin are recomputed from a low-precision inv_freq tensor. This happens, for example, during mixed-precision training, because inv_freq is automatically downcast to bfloat16 in that case.
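
A rough way to quantify this concern is a sketch comparing cos values computed from a float32 `inv_freq` and from one that was rounded to bfloat16 and then upcast again (as an explicit `.float()` would do), both measured against a float64 reference. The values `dim=128`, `base=10000` and `seq_len=4096` are illustrative assumptions; the exact error magnitudes depend on them, but the bfloat16-rounded frequencies accumulate angle error proportional to the position.

```python
import torch

dim, base, seq_len = 128, 10000.0, 4096  # illustrative assumptions

# Float64 reference and a float32 inv_freq rounded from it.
inv_freq64 = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float64) / dim))
inv_freq32 = inv_freq64.float()
# Simulate a buffer that was downcast to bfloat16 and later upcast again.
inv_freq_bf16 = inv_freq32.to(torch.bfloat16).float()

pos64 = torch.arange(seq_len, dtype=torch.float64)
pos32 = pos64.float()


def rope_cos(pos: torch.Tensor, inv_freq: torch.Tensor) -> torch.Tensor:
    # Angles are the outer product of positions and inverse frequencies.
    return torch.outer(pos, inv_freq).cos()


ref = rope_cos(pos64, inv_freq64)
err_fp32 = (rope_cos(pos32, inv_freq32).double() - ref).abs().max().item()
err_bf16 = (rope_cos(pos32, inv_freq_bf16).double() - ref).abs().max().item()
print(f"max |cos error|, float32 inv_freq:  {err_fp32:.2e}")
print(f"max |cos error|, bfloat16 inv_freq: {err_bf16:.2e}")
```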

gante commented on June 15, 2024

@danielhanchen the only solution is to explicitly upcast 😬 Some frameworks, like DeepSpeed, can hijack tensor creation and force tensors to be initialized in a specific dtype (which has also caused issues with RoPE).

@avnermay that is the solution. The change is simple, but we are working on other overlapping problems, so bear with us 🤗

avnermay commented on June 15, 2024

Just commenting on this so that it is not marked as stale. Thanks!

ArthurZucker commented on June 15, 2024

#30642 will fix this! 🤗

github-actions commented on June 15, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.