
Comments (6)

RUFFY-369 commented on July 4, 2024

Hi @lxianl455, please put your model in train mode with model.train() before performing a backprop:

from transformers import AutoTokenizer, RwkvForCausalLM

tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-169m-pile")
model = RwkvForCausalLM.from_pretrained("RWKV/rwkv-4-169m-pile", use_cache=True)
tokenizer.pad_token = tokenizer.eos_token
model.train()
...

Cheers!


RUFFY-369 commented on July 4, 2024

@lxianl455 okay, so if you want to use the RWKV model just for inference, i.e. in the default eval() mode, and don't want to put it in train mode, then modify your code as follows and the error will go away:

from transformers import AutoTokenizer, RwkvForCausalLM

tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-169m-pile")
model = RwkvForCausalLM.from_pretrained("RWKV/rwkv-4-169m-pile", use_cache=False)
...

Give use_cache a False value: during caching the model stores intermediate results to speed up computation, which can interfere with gradient computation, and use_cache is not meant for training or gradient computation anyway.

Cheers!
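For completeness, here is a minimal end-to-end sketch of what that looks like with a backward pass; the example sentence, labels, and optimizer are illustrative and not part of the original report:

import torch
from transformers import AutoTokenizer, RwkvForCausalLM

tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-169m-pile")
model = RwkvForCausalLM.from_pretrained("RWKV/rwkv-4-169m-pile", use_cache=False)
tokenizer.pad_token = tokenizer.eos_token
model.train()

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

# Illustrative batch: passing labels makes the model return a causal-LM loss.
input_ids = tokenizer("Hello, my dog is cute", return_tensors="pt").input_ids
outputs = model(input_ids=input_ids, labels=input_ids)

outputs.loss.backward()  # backprop works with caching disabled
optimizer.step()
optimizer.zero_grad()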

from transformers.

amyeroberts commented on July 4, 2024

cc @ArthurZucker


lxianl455 commented on July 4, 2024

> Hi @lxianl455, please put your model in train mode with model.train() before performing a backprop [...]

Actually, I want to combine the RWKV block with some other modules to predict time-series information. I am not using the whole RWKV model, only its blocks. In this scenario the Transformers Trainer cannot be used. How can I solve this backward error?
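One option is to wrap the pretrained RWKV backbone (or just its blocks) in your own nn.Module and train it with a plain PyTorch loop instead of the Trainer. A rough sketch, assuming a simple regression head and dummy data (both purely illustrative):

import torch
import torch.nn as nn
from transformers import RwkvModel

class RwkvRegressor(nn.Module):
    # Hypothetical wrapper: RWKV backbone plus a small head for time-series prediction.
    def __init__(self, checkpoint="RWKV/rwkv-4-169m-pile", horizon=1):
        super().__init__()
        self.backbone = RwkvModel.from_pretrained(checkpoint, use_cache=False)
        self.head = nn.Linear(self.backbone.config.hidden_size, horizon)

    def forward(self, input_ids):
        hidden = self.backbone(input_ids=input_ids).last_hidden_state
        return self.head(hidden[:, -1])  # predict from the last position

model = RwkvRegressor()
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Dummy batch: token ids in, one target value per sequence out.
input_ids = torch.randint(0, model.backbone.config.vocab_size, (4, 32))
targets = torch.randn(4, 1)

loss = nn.functional.mse_loss(model(input_ids), targets)
loss.backward()
optimizer.step()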


lxianl455 commented on July 4, 2024

Actually, what I want is for this to work like an LSTM. During training, an LSTM can take in an initial state and can also return the final state afterwards. https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html

I think use_cache in the RWKV code is not like use_cache in other models. Its actual function is more like asking for the LSTM cell state and hidden state to be returned. Here is the code (line 300 in modeling_rwkv.py):

rwkv, layer_state = rwkv_linear_attention(
    self.time_decay,
    self.time_first,
    key,
    value,
    state=layer_state,
    return_state=use_cache,
)
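For the LSTM-style state in/out at the model level, the state returned when use_cache=True can be fed back into the next forward call via the state argument; a rough sketch (the text chunks are illustrative, and as discussed above, backpropagating through this cached path is what raises the original error):

from transformers import AutoTokenizer, RwkvModel

tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-169m-pile")
model = RwkvModel.from_pretrained("RWKV/rwkv-4-169m-pile").eval()

chunks = ["The quick brown fox", " jumps over the lazy dog"]
state = None  # analogous to passing (h_0, c_0) to an LSTM
for text in chunks:
    input_ids = tokenizer(text, return_tensors="pt").input_ids
    outputs = model(input_ids=input_ids, state=state, use_cache=True)
    state = outputs.state  # carried into the next call, like (h_n, c_n)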


ArthurZucker commented on July 4, 2024

The RecurrentGemma model implements something more along the lines of an RNN (so close to an LSTM), if you are looking for an equivalent.

