Comments (8)
The PyTorch equivalent to your reset_states function looks like this:

def reset_states(model):
    for layer in model.modules():
        if hasattr(layer, 'reset_states'):  # could use `isinstance` instead to check if it's an RNN layer
            layer.reset_states()
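If you'd rather dispatch on type than on attribute, a minimal variant could look like this (it assumes the StatefulRNN wrapper introduced below):

def reset_states(model):
    for layer in model.modules():
        # type-based check; StatefulRNN is the wrapper class defined below
        if isinstance(layer, StatefulRNN):
            layer.reset_states()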
I think the right answer here is to use a generic stateful wrapper for RNNs. It could be as simple as:
class StatefulRNN(nn.Module):
    def __init__(self, layer):
        super().__init__()
        self.layer = layer
        self.state = None

    def reset_states(self):
        self.state = None

    def forward(self, x, state=None, lengths=None):
        if state is not None:
            self.state = state
        print(f'Using state {self.state}')
        y, state = self.layer(x, self.state, lengths)
        self.state = self._detach_state(state)
        return y, state

    def _detach_state(self, state):
        if isinstance(state, tuple):
            return tuple(s.detach() for s in state)
        if isinstance(state, list):
            return [s.detach() for s in state]
        return state.detach()
You could wrap any of the Haste RNN layers with this StatefulRNN wrapper class and you'd get the stateful behavior you're looking for.
Here's a complete example:
import torch
import torch.nn as nn
import haste_pytorch as haste

class StatefulRNN(nn.Module):
    def __init__(self, layer):
        super().__init__()
        self.layer = layer
        self.state = None

    def reset_states(self):
        self.state = None

    def forward(self, x, state=None, lengths=None):
        if state is not None:
            self.state = state
        print(f'Using state {self.state}')
        y, state = self.layer(x, self.state, lengths)
        self.state = self._detach_state(state)
        return y, state

    def _detach_state(self, state):
        if isinstance(state, tuple):
            return tuple(s.detach() for s in state)
        if isinstance(state, list):
            return [s.detach() for s in state]
        return state.detach()

def reset_states(model):
    for layer in model.modules():
        if hasattr(layer, 'reset_states'):  # could use `isinstance` instead to check if it's an RNN layer
            layer.reset_states()

SEQ_LEN = 250
BATCH_SIZE = 10
INPUT_SIZE = 3
HIDDEN_SIZE = 5

rnn = haste.GRU(INPUT_SIZE, HIDDEN_SIZE)  # or haste.LSTM or ...
model = StatefulRNN(rnn)

learning_rate = 1e-3
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

for t in range(10):
    x = torch.rand([SEQ_LEN, BATCH_SIZE, INPUT_SIZE])
    if t % 3 == 0:
        reset_states(model)
    y, _ = model(x)
    loss = loss_fn(y, torch.zeros_like(y))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f'Step {t}, loss {loss}')
From what I can tell, stateful in Keras-land is a flag that aids in implementing truncated backprop through time. It stores the last state of the RNN layer for each batch item and uses that as the initial state for the next batch that passes through the RNN. Is that correct? How does the state get reset when an entirely unrelated batch of data comes along?
@sharvil Not a backprop mechanism; gradients do not flow between batches, including for stateful=True. The rest of your description is correct; details below.
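To make that concrete, here's a small PyTorch sketch (plain nn.GRU, illustrative shapes) showing that a detached carry passes the state value forward while truncating the autograd graph at the batch boundary:

import torch
import torch.nn as nn

rnn = nn.GRU(3, 5)
x1 = torch.rand(250, 8, 3)
x2 = torch.rand(250, 8, 3)

y1, h = rnn(x1)
y2, _ = rnn(x2, h.detach())  # the state value carries over; gradients do not
y2.sum().backward()          # backprop stops at the detach, so x1's timesteps get no gradient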
The states are reset to zero per layer via layer.reset_states(), or for all layers in the model via model.reset_states(); if reset_states() isn't called, the states do not reset themselves to zero. States are built and initialized via an overridable method, get_initial_state() - e.g. tailored for LSTM (using another overridable method).

So basically, implementing stateful requires building a dedicated tensor that captures the last timestep's hidden state and is resettable to zero via a method.
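As a sketch of that overridable initializer (assuming TF2-style Keras, where RNN layers expose get_initial_state(inputs); the subclass and its behavior are hypothetical):

import tensorflow as tf

# hypothetical subclass: start each fresh sequence from ones instead of zeros
class OnesInitLSTM(tf.keras.layers.LSTM):
    def get_initial_state(self, inputs):
        # the parent returns the zero-initialized states ([h, c] for an LSTM)
        states = super().get_initial_state(inputs)
        return [tf.ones_like(s) for s in states]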
When and how does LSTM "pass states" in stateful? Pasting relevant excerpts from here:

- When: only batch-to-batch; samples are entirely independent
- How: in Keras, only batch-sample to batch-sample: stateful=True requires you to specify batch_shape instead of input_shape, because Keras builds batch_size separate states of the LSTM at compiling
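For illustration, a stateful layer built this way keeps one state slot per batch index (shapes are illustrative; TF2-style Keras assumed):

import tensorflow as tf

# stateful layers need a fixed batch size up front, hence batch_input_shape
lstm = tf.keras.layers.LSTM(5, stateful=True, return_sequences=True,
                            batch_input_shape=(8, 250, 3))
model = tf.keras.Sequential([lstm])

lstm.reset_states()   # zero this layer's states
model.reset_states()  # zero every stateful layer in the model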
Per above, this should not be done:

# sampleNM = sample N at timestep(s) M
batch1 = [sample10, sample20, sample30, sample40]
batch2 = [sample21, sample41, sample11, sample31]

This implies sample21 causally follows sample10 - and will wreck training. Instead do:

batch1 = [sample10, sample20, sample30, sample40]
batch2 = [sample11, sample21, sample31, sample41]
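A small numpy sketch of the correct chunking, using the sampleNM notation above (array shapes are illustrative):

import numpy as np

# 4 independent sequences, 500 timesteps each, 3 features: (batch, time, features)
data = np.random.rand(4, 500, 3)

batch1 = data[:, :250]     # [sample10, sample20, sample30, sample40]
batch2 = data[:, 250:500]  # [sample11, sample21, sample31, sample41]
# each sequence stays at the same batch index, so the state handoff is causal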
As for how an "entirely unrelated batch" is detected in practice: it can be done via callbacks with a counter in on_batch_end() - e.g. call reset_states() every 4th batch - or via a custom train loop (my approach).
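A minimal sketch of the callback approach (the class name and reset interval are illustrative):

import tensorflow as tf

class ResetStatesEveryN(tf.keras.callbacks.Callback):  # hypothetical helper
    def __init__(self, n=4):
        super().__init__()
        self.n = n
        self.counter = 0

    def on_batch_end(self, batch, logs=None):
        self.counter += 1
        if self.counter % self.n == 0:
            self.model.reset_states()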
Thanks so much for the detailed writeup. I think this is fairly easy to achieve in PyTorch with the existing Haste API. Here's some sample code that demonstrates what a stateful LSTM would look like:
import torch
import torch.nn as nn
import haste_pytorch as haste

class StatefulLSTM(nn.Module):
    def __init__(self, layer):
        super().__init__()
        self.layer = layer
        self.state = None

    def forward(self, x, state=None, lengths=None, reset=False):
        if reset:
            self.state = state
        print(f'Using state {self.state}')
        y, state = self.layer(x, self.state, lengths)
        self.state = (state[0].detach(), state[1].detach())
        return y, state

SEQ_LEN = 250
BATCH_SIZE = 1
INPUT_SIZE = 3
HIDDEN_SIZE = 5

lstm = haste.LSTM(INPUT_SIZE, HIDDEN_SIZE)
model = StatefulLSTM(lstm)

loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for t in range(10):
    x = torch.rand([SEQ_LEN, BATCH_SIZE, INPUT_SIZE])
    y, _ = model(x, reset=(t % 3 == 0))
    loss = loss_fn(y, torch.zeros_like(y))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f'Step {t}, loss {loss}')
This code trains an LSTM on sequences of length 750 which are fed in chunks of 250. The LSTM state is squirreled away in self.state of StatefulLSTM and is cleared whenever reset=True is passed in to forward(...).

Does this solution work for your application?
Thanks for the demo. From what I understand from lstm.py, state=None triggers resetting the cell and hidden states to zero, which StatefulLSTM enables via an additional reset argument - indeed, that's accurate. However, it's unclear how to reset states if the layers are part of a larger network (e.g. with conv layers). I don't suppose model.reset_states() is doable without modifying the model class instance; a workaround is something like:

def reset_states(model):
    for layer in model.layers:  # tf.keras
        if hasattr(layer, 'reset_states'):
            layer.reset_states()

but I can't tell from your snippet how one would reset states before passing an input.

Also, it'd work best if stateful were part of the base implementation rather than a dedicated wrapper, else it'll complicate extending functionality (e.g. we'd also need a StatefulLayerNormLSTM).
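For what it's worth, a sketch of how the module-walking reset_states at the top of this thread reaches a wrapped layer buried inside a larger network (the ConvRNN class and its shapes are hypothetical; StatefulRNN and reset_states are as defined above):

import torch
import torch.nn as nn
import haste_pytorch as haste

class ConvRNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv1d(3, 8, kernel_size=3, padding=1)
        self.rnn = StatefulRNN(haste.GRU(8, 5))  # wrapper from earlier in the thread

    def forward(self, x):
        # x: [seq_len, batch, channels]; Conv1d expects [batch, channels, seq_len]
        z = self.conv(x.permute(1, 2, 0)).permute(2, 0, 1)
        y, _ = self.rnn(z)
        return y

model = ConvRNN()
reset_states(model)  # model.modules() recurses into self.rnn and finds reset_states()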
@sharvil Looks excellent - a wrapper is a fair-enough alternative. To be sure all works as expected, I'll get back to you once I get haste_pytorch working.
@sharvil Ran the script - looks good; I'll compare it more extensively vs. Keras later, but currently all seems to work as expected. Thank you.