
Comments (17)

mobicham commented on May 22, 2024

Thanks @ArthurZucker !
Is there a way to replicate this use_cache=False behavior when manually writing the generate function, like this one (based on your code): https://github.com/mobiusml/hqq/blob/master/hqq/utils/generation_hf.py

The reason is that it's better to compile the decode_one_token function instead of the whole forward pass, to avoid annoying recompilation every time the input prompt shape changes.

I guess the right place to pass use_cache=False is here? https://github.com/mobiusml/hqq/blob/master/hqq/utils/generation_hf.py#L72
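
For context, here is a rough sketch of what compiling only the single-token decode step looks like with transformers' StaticCache; the decode_one_token name follows the linked file, but the helper below is an illustrative assumption, not the actual hqq implementation:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache

model_id  = "meta-llama/Llama-2-7b-chat-hf"
model     = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).cuda().eval()
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Pre-allocate a static cache so all tensor shapes stay fixed during decoding (1024 is a placeholder length)
past_key_values = StaticCache(config=model.config, max_batch_size=1, max_cache_len=1024, device=model.device, dtype=torch.float16)

def decode_one_token(cur_token, cache_position, past_key_values):
    # Forward pass over exactly one token: the input shape is always [batch_size, 1], so it compiles once
    logits = model(cur_token, position_ids=cache_position.unsqueeze(0), cache_position=cache_position,
                   past_key_values=past_key_values, return_dict=False, use_cache=True)[0]
    return torch.argmax(logits[:, -1], dim=-1, keepdim=True)

# Compile only the decode step; the variable-length prefill stays uncompiled
decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True)

In a generation loop, the prefill would then run once through the uncompiled forward to fill the cache, and decode_one_token would be called repeatedly with cache_position advanced by one each step.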

@gante I get that sometimes as well. I think it's a bit better with the torch nightly build.


mobicham commented on May 22, 2024

@ArthurZucker thanks! I found a hack: warm up with use_cache=False the very first time you compile, then use use_cache=True for generation. It still needs to warm up again with use_cache=True, but at least the output is correct.

Update: the warm-up with the full torch.compile takes a lot of VRAM. The best approach would be to make it work with decode_one_token. Still haven't found a proper way of doing it.

There's another problem: if you compile using, for example, max_new_tokens=100 and then use max_new_tokens=1000 after the warm-up, you get RuntimeError: CUDA error: device-side assert triggered. The trick is to use a larger max_new_tokens at compilation time; then it works with any value less than that.

import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Model/tokenizer setup and the tokenize_prompt helper are assumed here (same model as in the snippet below)
model_id  = "meta-llama/Llama-2-7b-chat-hf"
model     = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, attn_implementation="sdpa").cuda().eval()
tokenizer = AutoTokenizer.from_pretrained(model_id)

def tokenize_prompt(prompt):
    return tokenizer(prompt, return_tensors="pt").to(model.device)

model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

# Warm-up: compile with a larger max_new_tokens than will be used later, and with use_cache=False so the outputs stay correct
prompt = "Write an essay about large language models."
for _ in range(10):
	gen_out = model.generate(**tokenize_prompt(prompt), do_sample=False, cache_implementation="static", max_new_tokens=1000, pad_token_id=tokenizer.pad_token_id, temperature=None, top_p=None, use_cache=False)

# Timed generation with the cache enabled
prompt = "How do I make a cake"
t1 = time.time()
gen_out = model.generate(**tokenize_prompt(prompt), do_sample=False, cache_implementation="static", max_new_tokens=100, pad_token_id=tokenizer.pad_token_id, temperature=None, top_p=None, use_cache=True)
t2 = time.time()
print(len(gen_out[0]) / (t2 - t1), "tokens/sec")


ArthurZucker commented on May 22, 2024

Wow thanks a lot for all this valuable debugging, would really love to fix this!


amyeroberts commented on May 22, 2024

cc @gante


ArthurZucker commented on May 22, 2024

Super weird, and I can indeed reproduce it.
The fix is use_cache=False. It's counterintuitive, but this will work:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id  = "meta-llama/Llama-2-7b-chat-hf"
model     = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, attn_implementation="sdpa").cuda().eval()
tokenizer = AutoTokenizer.from_pretrained(model_id) 
tokenizer.add_bos_token = False

model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tokenizer(["<s> [INST] Write an essay about large language models [/INST]"], return_tensors="pt").to(model.device)
for _ in range(3):
	gen_out = model.generate(**inputs, do_sample=False, cache_implementation="static", max_new_tokens=100, pad_token_id=tokenizer.pad_token_id, temperature=None, top_p=None, use_cache=False)
print(tokenizer.decode(gen_out[0]))

inputs = tokenizer(["<s> [INST] How to make a chocolate cake? [/INST]"], return_tensors="pt").to(model.device)
gen_out = model.generate(**inputs, do_sample=False, cache_implementation="static", max_new_tokens=100, pad_token_id=tokenizer.pad_token_id, temperature=None, top_p=None, use_cache=False)
print(tokenizer.decode(gen_out[0]))


gante commented on May 22, 2024

When trying the original script with torch==2.4.0.dev20240418+cu121 I get an Aborted (core dumped), preceded by RuntimeError: Triton Error [CUDA]: device-side assert triggered and a bunch of out-of-bounds memory access errors 👀

@ArthurZucker's suggested script hits the same exceptions (because the first calls run into the same issue).


ArthurZucker commented on May 22, 2024

Ah! I did not get that and generated successfully; no idea what went wrong on your end.


ArthurZucker commented on May 22, 2024

[screenshot of generation output] That's what I got, and it was pretty fast.


mobicham commented on May 22, 2024

@ArthurZucker use_cache=False is not really a solution: the speed is much slower than with use_cache=True.
I was not able to make it work properly by setting use_cache=False directly in the model forward pass either.

@gante that CUDA issue mainly happens when you compile the whole forward pass. Normally you only need to compile the forward pass for the decoding part (the input is a single token with a fixed shape), not the prefill.
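
To make the split concrete, here is a rough sketch of such a loop, assuming the decode_one_token helper and the pre-allocated past_key_values static cache from the sketch above (the prompt and the 100-token budget are placeholders):

inputs    = tokenizer("How to make a chocolate cake?", return_tensors="pt").to(model.device)
input_ids = inputs["input_ids"]
seq_len   = input_ids.shape[1]

# Prefill: one uncompiled forward pass over the whole (variable-length) prompt fills the static cache
cache_position = torch.arange(seq_len, device=model.device)
out        = model(input_ids, cache_position=cache_position, past_key_values=past_key_values, use_cache=True)
next_token = out.logits[:, -1].argmax(dim=-1, keepdim=True)

# Decode: the compiled single-token step always sees an input of shape [batch_size, 1]
generated = [next_token]
for i in range(seq_len, seq_len + 100):
    cache_position = torch.tensor([i], device=model.device)
    next_token     = decode_one_token(next_token, cache_position, past_key_values)
    generated.append(next_token)

print(tokenizer.decode(torch.cat(generated, dim=-1)[0]))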


ArthurZucker commented on May 22, 2024

@mobicham, normally you should not have this issue with the script that compiles decode_one_token. I pushed a fix to main that should have solved this (#30380); the cache was probably not being overwritten.
I think reset_cache might not work as expected.


mobicham commented on May 22, 2024

I was not able to test the fix because there's another problem with 4.41.0: #30417


ArthurZucker commented on May 22, 2024

Super weird and we'll fix it asap


ArthurZucker commented on May 22, 2024

Might be related to #30414 as well


mobicham commented on May 22, 2024

I was finally able to make it work without blowing up the VRAM:

  1. Compile with inputs of size [batch_size, 1]: https://github.com/mobiusml/hqq/blob/master/hqq/utils/generation_hf.py#L57-L72
  2. Warm up with 3 prompts with use_cache=False

With this approach, a 4-bit Llama2-7B takes ~5.6GB of VRAM at runtime with a maximum cache size of 1024.
If I try the same with model.generate() I run out of VRAM after the second or third warm-up prompt.

The only issue is the speed: with the approach above I get 165 tokens/sec, while it should be ~205 tokens/sec.

Update: the speed depends on the size of the initialized cache for some reason.

Update 2: it is actually not fixed; the outputs still mix in content from previous prompts.
Will try the fix as soon as #30417 is resolved.


mobicham commented on May 22, 2024

Thanks @ArthurZucker
I spent the whole day playing with this; the latest version is here. Here's what I noticed so far:

  • For the warm-up, you need to feed it different prompts sequentially, at least 3 of them, meaning you need to do generate(prompt1), generate(prompt2), generate(prompt3). If you use the same prompt instead, the cache gets totally locked to prompt1 (see the sketch after this list).
  • Normally, you need to reset the cache before each generation. However, with the compiled version, it crashes if you reset it. When I warm up the compilation with small 1-token inputs, the output still looks a bit strange, so the cache contains information from some previous prompts. Even if I manually delete and re-create it, same issue.
  • Cache sizes need to be powers of 2, otherwise it crashes with RuntimeError: CUDA error: device-side assert triggered.
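
A minimal sketch of that warm-up recipe, written with plain model.generate for illustration (the real setup uses the custom hqq generator; the prompt texts and token counts below are placeholders):

warmup_prompts = [
    "Write an essay about large language models.",
    "How to make a chocolate cake?",
    "Explain the difference between CPUs and GPUs.",
]
# Warm up with at least 3 *different* prompts; reusing the same prompt locks the cache to it
for p in warmup_prompts:
    model.generate(**tokenizer(p, return_tensors="pt").to(model.device),
                   do_sample=False, cache_implementation="static", max_new_tokens=100,
                   pad_token_id=tokenizer.pad_token_id, temperature=None, top_p=None, use_cache=False)
# If the static cache is allocated manually instead, keep its length a power of 2 (e.g. 1024), per the last point above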


ArthurZucker commented on May 22, 2024

BTW, we are going to move forward with #30476


mobicham commented on May 22, 2024

Thank you for the update!

