Comments (17)
Thanks @ArthurZucker! Is there a way to replicate this use_cache=False behavior when manually writing the generate function, like this one (based on your code): https://github.com/mobiusml/hqq/blob/master/hqq/utils/generation_hf.py ?
The reason is that it's better to compile the decode_one_token function instead of the whole forward pass, to avoid the annoying recompilation every time the input prompt shape changes.
I guess the place to pass use_cache=False would be here? https://github.com/mobiusml/hqq/blob/master/hqq/utils/generation_hf.py#L72
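For context, here is a minimal sketch of what compiling only the single-token decode step might look like, loosely following the gpt-fast-style pattern used in generation_hf.py. It is not the hqq implementation: the names decode_one_token and generate are illustrative, it assumes a recent transformers with StaticCache and cache_position support, and the exact StaticCache constructor arguments may differ across versions.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache

model_id = "meta-llama/Llama-2-7b-chat-hf"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).cuda().eval()
tokenizer = AutoTokenizer.from_pretrained(model_id)

def decode_one_token(model, cur_token, cache_position, past_key_values):
    # One decoding step: the input is always [batch_size, 1], so the compiled
    # graph is reused for every step instead of recompiling per prompt shape.
    logits = model(cur_token, cache_position=cache_position,
                   past_key_values=past_key_values, use_cache=True).logits
    return torch.argmax(logits[:, -1], dim=-1, keepdim=True)

# Compile only the fixed-shape decode step, not the variable-length prefill.
decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True)

@torch.no_grad()
def generate(prompt, max_new_tokens=100):
    input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(model.device)
    seq_len = input_ids.shape[-1]
    # Static cache sized once up front (constructor arguments may vary by version).
    past_key_values = StaticCache(config=model.config, max_batch_size=1,
                                  max_cache_len=seq_len + max_new_tokens,
                                  device=model.device, dtype=model.dtype)
    # Prefill: run the whole prompt through the uncompiled forward pass.
    cache_position = torch.arange(seq_len, device=model.device)
    logits = model(input_ids, cache_position=cache_position,
                   past_key_values=past_key_values, use_cache=True).logits
    next_token = torch.argmax(logits[:, -1], dim=-1, keepdim=True)
    generated = [next_token]
    cache_position = torch.tensor([seq_len], device=model.device)
    for _ in range(max_new_tokens - 1):
        # Clone tensors: under reduce-overhead the compiled step may reuse CUDA-graph buffers.
        next_token = decode_one_token(model, next_token.clone(), cache_position, past_key_values)
        generated.append(next_token.clone())
        cache_position += 1
    return tokenizer.decode(torch.cat([input_ids] + generated, dim=-1)[0])

Because the decode step always sees a [batch_size, 1] input and a fixed-size cache, its graph is captured once and reused; only the prefill pays the cost of dynamic shapes.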
@gante I get that sometimes as well. I think it's a bit better with the torch nightly build.
@ArthurZucker thanks! I found a hack: warm up with use_cache=False the very first time you compile, then use use_cache=True for generation. It still needs to warm up again with use_cache=True, but at least the output is correct.
Update: the warm-up with the full torch.compile takes a lot of VRAM. The best would be to make it work with decode_one_token. Still haven't found a proper way of doing it.
There's another problem: if you compile using max_new_tokens=100, for example, and use max_new_tokens=1000 after the warm-up, you get RuntimeError: CUDA error: device-side assert triggered. The trick is to use a larger max_new_tokens at compilation time; then it works with any value less than that.
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

prompt = "Write an essay about large language models."

# warm-up
for _ in range(10):
    gen_out = model.generate(**tokenize_prompt(prompt), do_sample=False, cache_implementation="static", max_new_tokens=1000, pad_token_id=tokenizer.pad_token_id, temperature=None, top_p=None, use_cache=False)

prompt = "How do I make a cake"

import time
t1 = time.time()
gen_out = model.generate(**tokenize_prompt(prompt), do_sample=False, cache_implementation="static", max_new_tokens=100, pad_token_id=tokenizer.pad_token_id, temperature=None, top_p=None, use_cache=True)
t2 = time.time()
print(len(gen_out[0]) / (t2 - t1), "tokens/sec")
Wow thanks a lot for all this valuable debugging, would really love to fix this!
cc @gante
Super weird, and I can indeed reproduce. The fix is use_cache=False. It's counterintuitive, but this will work:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-chat-hf"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, attn_implementation="sdpa").cuda().eval()
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.add_bos_token = False

model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tokenizer(["<s> [INST] Write an essay about large language models [/INST]"], return_tensors="pt").to(model.device)
for _ in range(3):
    gen_out = model.generate(**inputs, do_sample=False, cache_implementation="static", max_new_tokens=100, pad_token_id=tokenizer.pad_token_id, temperature=None, top_p=None, use_cache=False)
    print(tokenizer.decode(gen_out[0]))

inputs = tokenizer(["<s> [INST] How to make a chocolate cake? [/INST]"], return_tensors="pt").to(model.device)
gen_out = model.generate(**inputs, do_sample=False, cache_implementation="static", max_new_tokens=100, pad_token_id=tokenizer.pad_token_id, temperature=None, top_p=None, use_cache=False)
print(tokenizer.decode(gen_out[0]))
When trying the original script with torch==2.4.0.dev20240418+cu121 I get an Aborted (core dumped), preceded by RuntimeError: Triton Error [CUDA]: device-side assert triggered and a bunch of out-of-bounds memory accesses 👀
@ArthurZucker's suggested script hits the same exceptions (because the first calls hit the same issue).
Ah! I did not get that and successfully generated, no idea what went wrong with yours
@ArthurZucker use_cache=False is not really a solution, the speed is much slower vs. use_cache=True.
I was not able to make it work properly by setting use_cache=False directly in the model forward pass either.
@gante that CUDA issue mainly happens when you compile the whole forward pass; normally you only need to compile the forward pass for the decoding step (the input is a single token with a fixed shape), not the prefill.
@mobicham, normally you should not have this issue with the script that compiles decode_one_token. I pushed a fix to main that should have solved this: #30380 (the cache was probably not being overwritten). I think reset_cache might not work as expected.
I was not able to test the fix because there's another problem with 4.41.0: #30417
Super weird and we'll fix it asap
Might be related to #30414 as well
I was finally able to make it work without blowing up the VRAM:
- Compile with inputs of size [batch_size, 1]: https://github.com/mobiusml/hqq/blob/master/hqq/utils/generation_hf.py#L57-L72
- Warm up with 3 prompts with use_cache=False
With this approach, a 4-bit Llama2-7B takes ~5.6 GB of VRAM at runtime with a max cache size of 1024. If I try the same with model.generate(), I run out of VRAM after the 2nd or 3rd warm-up prompt.
The only issue is the speed: with the approach above I get 165 tokens/sec, and it should be ~205 tokens/sec.
Update: the speed depends on the size of the initialized cache for some reason.
Update 2: it is actually not fixed, the outputs still mix in content from previous prompts.
Will try the fix as soon as #30417 is fixed.
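For what it's worth, a rough illustration of that warm-up scheme (the prompts are made up, and generate stands for a custom compiled-decode generation function like the sketch earlier in this thread, not model.generate; the use_cache=False detail from the hqq code is not reproduced here):

import time

# Run a few different prompts through the compiled decode path first, so that
# compilation / CUDA-graph capture happens before anything is timed.
warmup_prompts = [
    "Write an essay about large language models.",
    "How do I make a chocolate cake?",
    "Explain quantum computing in one paragraph.",
]
for p in warmup_prompts:
    _ = generate(p, max_new_tokens=16)

t0 = time.time()
out = generate("Summarize the history of the internet.", max_new_tokens=100)
print(100 / (time.time() - t0), "tokens/sec")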
Thanks @ArthurZucker. I spent the whole day playing with this; the latest version is here. Here's what I noticed so far:
- For the warm-up, you need to feed it different prompts sequentially, at least 3 of them, meaning you need to do generate(prompt1), generate(prompt2), generate(prompt3). If you don't do that and use the same prompt, the cache gets totally locked on prompt1.
- Normally, you need to reset the cache before each generation. However, with the compiled version, it crashes if you reset it. When I warm up the compilation with small 1-token inputs, the output still looks a bit strange, so the cache contains information from some previous prompts. Even if I manually delete and re-create it, same issue.
- Cache sizes need to be powers of 2, otherwise it crashes with RuntimeError: CUDA error: device-side assert triggered (a small sizing helper is sketched below).
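On that last point, a tiny hypothetical helper for rounding the cache length up to the next power of two before allocating the static cache:

import math

def round_up_pow2(n: int) -> int:
    # Round n up to the next power of two, e.g. 600 -> 1024, 1024 -> 1024.
    return 1 << math.ceil(math.log2(n))

prompt_len, max_new_tokens = 85, 500
cache_len = round_up_pow2(prompt_len + max_new_tokens)  # 585 -> 1024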
BTW we are gonna move forward with #30476
Thank you for the update!