Comments (10)
@danielhanchen the inv_freq buffer can be cast with .to when casting the whole model, e.g.
from transformers import AutoModelForCausalLM
import torch
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = model.to(device="cuda", dtype=torch.bfloat16)
print(model.model.layers[0].self_attn.rotary_emb.inv_freq.dtype)
On Llama and Gemma that's no problem, since we've recently updated the code to cast inv_freq to float() before it is applied to get sin and cos (e.g. here). However, other RoPE models like Mistral have yet to receive the same treatment.
We'll gladly take PRs to fix it ;) We will be touching the other RoPE models soon anyway, to migrate them to a Llama-like structure (which, unlike other models, is compatible with torch.compile)
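A minimal standalone sketch of the fix described above (a hypothetical module, not the actual transformers code): upcast inv_freq to float32 inside the rotary embedding's forward, so a model-wide cast to bfloat16 cannot degrade the sin/cos tables.

```python
import torch

class RotaryEmbedding(torch.nn.Module):
    """Hypothetical RoPE module illustrating the float32 upcast pattern."""

    def __init__(self, dim, base=10000.0):
        super().__init__()
        # Created in float32, mirroring the Llama/Gemma snippets in this thread.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, position_ids):
        # The fix: even if .to() downcast the buffer, compute in float32.
        freqs = position_ids.float()[:, None] * self.inv_freq.float()[None, :]
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos(), emb.sin()

rope = RotaryEmbedding(dim=8).to(dtype=torch.bfloat16)
print(rope.inv_freq.dtype)  # the buffer itself is downcast by .to()
cos, sin = rope(torch.arange(4))
print(cos.dtype)            # but sin/cos are still computed in float32
```

Note that the buffer still ends up in bfloat16 after `.to()`; the upcast in `forward` is what protects the tables.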
from transformers.
Do you want to open a PR to propagate the changes we made to Llama and Gemma?
cc @gante
@avnermay I'm not too certain, but I think inv_freq
will always be calculated in float32. E.g. for Gemma:
self.inv_freq = 1.0 / (
self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim))
And for Llama:
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
The downcast only applies to matrix multiplications and explicit downcasts, like what I found they did in Keras.
I haven't run the code to confirm, but it would be great if you could print the dtype during a finetuning run to confirm inv_freq
is actually bfloat16.
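A quick way to check the question above without a full finetuning run (a standalone sketch, not the actual Llama module): a buffer created in float32 is still downcast by a later model-wide `.to(dtype=...)` call, which is how inv_freq can end up in bfloat16.

```python
import torch

class Toy(torch.nn.Module):
    """Toy module with an inv_freq-style buffer created in float32."""

    def __init__(self, dim=8, base=10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

m = Toy()
print(m.inv_freq.dtype)   # float32 at creation time
m = m.to(dtype=torch.bfloat16)
print(m.inv_freq.dtype)   # downcast to bfloat16 by the model-wide cast
```

So creating the buffer in float32 is not enough on its own; `nn.Module.to` casts all floating-point parameters and buffers.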
@gante Whoops, sorry, just saw this - apologies!
Oh, fair points! Hmm, is there some sort of lock-in mechanism to prevent the conversion from occurring? Maybe some sort of overriding mechanism, i.e. writing over tensor.to
itself?
Why not use the approach taken by the other models, which forces inv_freq to be float32? The key is avoiding cases where cos and sin are recomputed using a low-precision inv_freq
tensor. This occurs (for example) during mixed-precision training, because inv_freq is automatically downcast to bfloat16 in that case.
@danielhanchen the only solution is to explicitly upcast 😬 some frameworks like deepspeed explicitly can hijack tensor creation and force them to be initialized in a certain type (which has also caused issues with RoPE).
@avnermay that is the solution. The change is simple, but we are working on other overlapping problems -- bear with us 🤗
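One way to sketch the "explicitly upcast" idea (a hypothetical helper, not the actual transformers implementation): recompute the cos/sin tables with autocast disabled and in float32, so mixed-precision training cannot silently produce low-precision tables even if the buffer itself was downcast.

```python
import torch

def rope_tables(inv_freq, position_ids):
    """Compute RoPE cos/sin tables in float32, regardless of buffer dtype."""
    device_type = position_ids.device.type
    # Disable autocast so the float32 upcast is not undone by mixed precision.
    with torch.autocast(device_type=device_type, enabled=False):
        freqs = position_ids.float()[:, None] * inv_freq.float()[None, :]
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos(), emb.sin()

# Even with a bfloat16 inv_freq buffer, the tables come out in float32.
inv_freq = torch.ones(4, dtype=torch.bfloat16)
cos, sin = rope_tables(inv_freq, torch.arange(8))
print(cos.dtype)
```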
Just commenting on this so that it is not marked as stale. Thanks!
#30642 will fix this! 🤗
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.