Comments (6)
Hi @lxianl455, please put your model in train mode with `model.train()` before performing a backprop:
```python
from transformers import AutoTokenizer, RwkvForCausalLM

tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-169m-pile")
model = RwkvForCausalLM.from_pretrained("RWKV/rwkv-4-169m-pile", use_cache=True)
tokenizer.pad_token = tokenizer.eos_token
model.train()
...
```
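The elided lines would be the actual training step. As a minimal sketch of what that step might look like (using a tiny randomly initialised `RwkvConfig` with made-up sizes instead of the pretrained checkpoint, so it runs without a download):

```python
import torch
from transformers import RwkvConfig, RwkvForCausalLM

# Tiny random model as a stand-in for the pretrained checkpoint (sizes are arbitrary).
config = RwkvConfig(vocab_size=100, context_length=64, hidden_size=32, num_hidden_layers=2)
model = RwkvForCausalLM(config)
model.train()  # train mode, as suggested above

input_ids = torch.randint(0, 100, (2, 16))
outputs = model(input_ids, labels=input_ids)  # causal-LM loss over the batch
outputs.loss.backward()                       # backprop succeeds in train mode
```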
Cheers!
from transformers.
@lxianl455 okay, so if you want to use the RWKV model just for inference, or in the default `eval()` mode, and don't want to put it in train mode, then modify your code to this and the error will go away:
```python
from transformers import AutoTokenizer, RwkvForCausalLM

tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-169m-pile")
model = RwkvForCausalLM.from_pretrained("RWKV/rwkv-4-169m-pile", use_cache=False)
...
```
Give `use_cache` a `False` value because during caching the model stores intermediate results in place to speed up computation, which can interfere with gradient computation; `use_cache` is not meant for training or gradient computation anyway.
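For instance, the backward pass goes through in `eval()` mode once `use_cache=False` (a sketch with a tiny randomly initialised config rather than the real checkpoint; the sizes are made up):

```python
import torch
from transformers import RwkvConfig, RwkvForCausalLM

config = RwkvConfig(vocab_size=100, context_length=64, hidden_size=32, num_hidden_layers=2)
model = RwkvForCausalLM(config)
model.eval()  # default inference mode

input_ids = torch.randint(0, 100, (2, 16))
outputs = model(input_ids, labels=input_ids, use_cache=False)  # no state caching
outputs.loss.backward()  # no autograd error without the cache
```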
Cheers!
Actually, I want to combine RWKV blocks with some other modules to predict time-series information. I am not using the whole RWKV model, only its blocks. In this scenario the Transformers `Trainer` cannot be used. How can I solve this backward error?
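When the `Trainer` is not usable, a plain PyTorch training loop still works as long as `use_cache=False` (or the model is in train mode). Here is a sketch of that pattern (the tiny config, the `head` module, and the targets are all made up for illustration) that feeds RWKV hidden states into an extra time-series head:

```python
import torch
from transformers import RwkvConfig, RwkvForCausalLM

config = RwkvConfig(vocab_size=100, context_length=64, hidden_size=32, num_hidden_layers=2)
model = RwkvForCausalLM(config)
model.train()
head = torch.nn.Linear(config.hidden_size, 1)  # hypothetical time-series head
optimizer = torch.optim.AdamW(list(model.parameters()) + list(head.parameters()), lr=1e-3)

input_ids = torch.randint(0, 100, (2, 16))
targets = torch.randn(2, 16)  # dummy per-step targets

out = model(input_ids, use_cache=False, output_hidden_states=True)
features = out.hidden_states[-1]   # (batch, seq_len, hidden_size)
pred = head(features).squeeze(-1)  # (batch, seq_len)
loss = torch.nn.functional.mse_loss(pred, targets)
loss.backward()                    # works: use_cache=False, manual loop
optimizer.step()
optimizer.zero_grad()
```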
Actually, the function I want is for it to work like an LSTM: during training, an LSTM can take in an initial state and can also return the final state afterwards. https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html
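For reference, this is the LSTM pattern being described (the sizes here are arbitrary):

```python
import torch

lstm = torch.nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
x = torch.randn(2, 5, 8)           # (batch, seq_len, features)
h0 = torch.zeros(1, 2, 16)         # initial hidden state
c0 = torch.zeros(1, 2, 16)         # initial cell state
out, (hn, cn) = lstm(x, (h0, c0))  # pass in the init state, get the ending state back
```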
I think `use_cache` in the RWKV code is not like `use_cache` in other models. Its actual function is more like asking for the LSTM-style cell state and hidden state to be returned. Here is the code (line 300 in modeling_rwkv.py):
```python
rwkv, layer_state = rwkv_linear_attention(
    self.time_decay,
    self.time_first,
    key,
    value,
    state=layer_state,
    return_state=use_cache,
)
```
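If the `state` keyword argument and the `state` field on the model output behave as this snippet suggests, inference can carry the recurrent state across calls, LSTM-style. A sketch with a tiny randomly initialised config (the sizes are made up):

```python
import torch
from transformers import RwkvConfig, RwkvForCausalLM

config = RwkvConfig(vocab_size=100, context_length=64, hidden_size=32, num_hidden_layers=2)
model = RwkvForCausalLM(config).eval()

with torch.no_grad():
    first = model(torch.randint(0, 100, (1, 8)), use_cache=True)  # returns the ending state
    # Resume from the returned state, like feeding an LSTM its previous (h, c):
    second = model(torch.randint(0, 100, (1, 8)), state=first.state, use_cache=True)
```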
The recurrentGemma model implements something more along the lines of an RNN (so, close to an LSTM), if you are looking for an equivalent.