Comments (6)
@Zoeyyao27 hey! This is occurring because you have to pass the previously generated text back into generate() along with the cache. It will not be used to compute key-values, but we need it to infer the actual sequence length and build the correct attention_mask. I modified your code slightly, see below:
# model, tokenizer and device are assumed to be set up as in the full scripts further down
prefix_list = ["Hello, my name is yy", "What is your name?"]
dialog_history = ""
cache = SinkCache(window_length=1024, num_sink_tokens=4)

for prefix in prefix_list:
    # grow the history with the new user turn, then re-tokenize the whole dialog
    dialog_history += prefix
    inputs = tokenizer(dialog_history, return_tensors='pt').to(device)
    input_length = inputs.input_ids.shape[-1]
    gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=64,
                             use_cache=True,
                             past_key_values=cache,
                             pad_token_id=tokenizer.pad_token_id,
                             return_dict_in_generate=True,
                             )
    # decode only the newly generated tokens and append them to the history
    decoded = tokenizer.decode(gen_out.sequences[0, input_length:], skip_special_tokens=True)
    dialog_history += decoded
    cache = gen_out.past_key_values
    print(decoded)
@zucchini-nlp please have a go at it :)
A note adding to @zucchini-nlp's comment above: the line cache = gen_out.past_key_values is not needed. The cache object is updated in-place; the only operation you need to do manually is to instantiate a new cache for a brand-new chat/prompt :)
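To make the in-place behavior concrete, here is a minimal sketch of the loop above without the reassignment (assuming the model/tokenizer/device setup from the scripts in this thread):

cache = SinkCache(window_length=1024, num_sink_tokens=4)  # instantiate once per chat

for prefix in prefix_list:
    dialog_history += prefix
    inputs = tokenizer(dialog_history, return_tensors='pt').to(device)
    input_length = inputs.input_ids.shape[-1]
    # generate() mutates `cache` in-place, so no `cache = gen_out.past_key_values` afterwards
    gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=64,
                             use_cache=True, past_key_values=cache,
                             pad_token_id=tokenizer.pad_token_id,
                             return_dict_in_generate=True)
    dialog_history += tokenizer.decode(gen_out.sequences[0, input_length:],
                                       skip_special_tokens=True)

# for a brand-new chat, drop the old object and build a fresh cache
cache = SinkCache(window_length=1024, num_sink_tokens=4)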
Hmm, right, there's a bug in how we crop input_ids when continuing generation from SinkCache. @gante, will you fix it, or should I open a PR later this week? I'm not sure how it will fit into the current API, though; we probably need to update the caching API soon and add more rigorous tests for all cache types. WDYT?
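For readers landing here, a rough sketch of the failure mode as I read the traceback below (an assumption, not a confirmed diagnosis): SinkCache never grows past window_length, so once the accumulated prompt is longer than the window, the mask width derived from the cache's cap and the positions derived from the uncropped input_ids disagree:

# Illustrative numbers taken from the repro below; the variable names are ad hoc.
window_length = 50   # SinkCache(window_length=50, num_sink_tokens=4)
prompt_len = 68      # tokens in the accumulated chat history

target_length = window_length   # the causal mask is sized from the cache's cap
num_positions = prompt_len      # cache_position is sized from the uncropped prompt

# 68 positions against a 50-wide mask is exactly the
# "size of tensor a (68) must match the size of tensor b (50)" RuntimeError.
assert num_positions != target_length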
Thank you for your reply!
However, when I use a chat model with tokenizer.apply_chat_template, I get the following error:
Traceback (most recent call last):
File "/data/yaoy/long_context/repeat_sirllm/main.py", line 35, in <module>
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=64,
File "/data/yaoy/anaconda/envs/long/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/data/yaoy/long_context/repeat_sirllm/transformers/src/transformers/generation/utils.py", line 1896, in generate
result = self._sample(
File "/data/yaoy/long_context/repeat_sirllm/transformers/src/transformers/generation/utils.py", line 2633, in _sample
outputs = self(
File "/data/yaoy/anaconda/envs/long/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data/yaoy/anaconda/envs/long/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/data/yaoy/anaconda/envs/long/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
output = module._old_forward(*args, **kwargs)
File "/data/yaoy/long_context/repeat_sirllm/transformers/src/transformers/models/llama/modeling_llama.py", line 1162, in forward
outputs = self.model(
File "/data/yaoy/anaconda/envs/long/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data/yaoy/anaconda/envs/long/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/data/yaoy/long_context/repeat_sirllm/transformers/src/transformers/models/llama/modeling_llama.py", line 938, in forward
causal_mask = self._update_causal_mask(
File "/data/yaoy/long_context/repeat_sirllm/transformers/src/transformers/models/llama/modeling_llama.py", line 1060, in _update_causal_mask
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
RuntimeError: The size of tensor a (68) must match the size of tensor b (50) at non-singleton dimension 0
Here is the code:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.cache_utils import SinkCache

model_id = "01-ai/Yi-1.5-6B-Chat"
model = AutoModelForCausalLM.from_pretrained(
    model_id, low_cpu_mem_usage=True, device_map='auto',
    torch_dtype=torch.bfloat16, cache_dir="cache")
model = model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir="cache")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'left'
device = model.device

prefix_list = ["Hello, my name is yy", "What is your name?"]
dialog_history = []
cache = SinkCache(window_length=50, num_sink_tokens=4)  # ,recent_ratio=0.3) # 1024

for prefix in prefix_list:
    dialog_history.append({"role": "user", "content": prefix})
    input_text = tokenizer.apply_chat_template(dialog_history, tokenize=False)
    inputs = tokenizer(input_text, return_tensors='pt').to(device)
    input_length = inputs.input_ids.shape[-1]
    gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=64,
                             use_cache=True,
                             past_key_values=cache,
                             pad_token_id=tokenizer.pad_token_id,
                             return_dict_in_generate=True,
                             )
    decoded = tokenizer.decode(gen_out.sequences[0, input_length:], skip_special_tokens=True)
    dialog_history.append({"role": "assistant", "content": decoded})
    print(decoded)
    print(cache)
In fact, I don't think apply_chat_template causes the problem: if I use a smaller window_length in SinkCache, I get the same error. Here is the code:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.cache_utils import SinkCache

model_id = "meta-llama/Llama-2-7b-hf"
model = AutoModelForCausalLM.from_pretrained(
    model_id, low_cpu_mem_usage=True, device_map='auto',
    torch_dtype=torch.bfloat16, cache_dir="cache")
model = model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir="cache")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'left'
device = model.device

prefix_list = ["Hello, my name is yy", "What is your name?"]
dialog_history = ""
cache = SinkCache(window_length=20, num_sink_tokens=4)

for prefix in prefix_list:
    dialog_history += prefix
    inputs = tokenizer(dialog_history, return_tensors='pt').to(device)
    input_length = inputs.input_ids.shape[-1]
    gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=64,
                             use_cache=True,
                             past_key_values=cache,
                             pad_token_id=tokenizer.pad_token_id,
                             return_dict_in_generate=True,
                             )
    decoded = tokenizer.decode(gen_out.sequences[0, input_length:], skip_special_tokens=True)
    dialog_history += decoded
    cache = gen_out.past_key_values
    print(decoded)
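A hedged pre-flight check (my suggestion, not an official workaround) that makes the trigger visible: the crash appears exactly when the tokenized history outgrows the cache window, which happens almost immediately with window_length=20:

# Hypothetical check inside the loop, before calling generate();
# SinkCache stores its constructor argument as `window_length`.
if inputs.input_ids.shape[-1] > cache.window_length:
    print(f"prompt ({inputs.input_ids.shape[-1]} tokens) exceeds the "
          f"SinkCache window ({cache.window_length}); expect the causal-mask "
          f"size mismatch until the input_ids cropping bug is fixed")

Until that fix lands, keeping window_length comfortably larger than the longest expected prompt (as in the window_length=1024 runs above) appears to avoid the error.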