Stateful in Pytorch (haste, closed)

lmnt-com commented on May 27, 2024
Stateful in Pytorch

from haste.

Comments (8)

sharvil commented on May 27, 2024

The PyTorch equivalent to your reset_states function looks like this:

def reset_states(model):
  for layer in model.modules():
    if hasattr(layer, 'reset_states'):  # could use `isinstance` instead to check if it's an RNN layer
      layer.reset_states()

I think the right answer here is to use a generic stateful wrapper for RNNs. It could be as simple as:

class StatefulRNN(nn.Module):
  def __init__(self, layer):
    super().__init__()
    self.layer = layer
    self.state = None

  def reset_states(self):
    self.state = None

  def forward(self, x, state=None, lengths=None):
    if state is not None:
      self.state = state
    print(f'Using state {self.state}')
    y, state = self.layer(x, self.state, lengths)
    self.state = self._detach_state(state)
    return y, state

  def _detach_state(self, state):
    if isinstance(state, tuple):
      return tuple(s.detach() for s in state)
    if isinstance(state, list):
      return [s.detach() for s in state]
    return state.detach()

You could wrap any of the Haste RNN layers with this StatefulRNN decorator class and you'd get the stateful behavior you're looking for.

Here's a complete example:

import torch
import torch.nn as nn
import haste_pytorch as haste


class StatefulRNN(nn.Module):
  def __init__(self, layer):
    super().__init__()
    self.layer = layer
    self.state = None

  def reset_states(self):
    self.state = None

  def forward(self, x, state=None, lengths=None):
    if state is not None:
      self.state = state
    print(f'Using state {self.state}')
    y, state = self.layer(x, self.state, lengths)
    self.state = self._detach_state(state)
    return y, state

  def _detach_state(self, state):
    if isinstance(state, tuple):
      return tuple(s.detach() for s in state)
    if isinstance(state, list):
      return [s.detach() for s in state]
    return state.detach()


def reset_states(model):
  for layer in model.modules():
    if hasattr(layer, 'reset_states'):  # could use `isinstance` instead to check if it's an RNN layer
      layer.reset_states()


SEQ_LEN = 250
BATCH_SIZE = 10
INPUT_SIZE = 3
HIDDEN_SIZE = 5

rnn = haste.GRU(INPUT_SIZE, HIDDEN_SIZE)  # or haste.LSTM or ...
model = StatefulRNN(rnn)

learning_rate = 1e-3
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for t in range(10):
  x = torch.rand([SEQ_LEN, BATCH_SIZE, INPUT_SIZE])
  if t % 3 == 0:
    reset_states(model)
  y, _ = model(x)
  loss = loss_fn(y, torch.zeros_like(y))

  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
  print(f'Step {t}, loss {loss}')

sharvil commented on May 27, 2024

From what I can tell, stateful in Keras-land is a flag that aids in implementing truncated backprop through time. It stores the last state of the RNN layer for each batch item and uses that as the initial state for the next batch that passes through the RNN. Is that correct? How does the state get reset when an entirely unrelated batch of data comes along?

OverLordGoldDragon commented on May 27, 2024

@sharvil Not a backprop mechanism; gradients do not flow between batches, including for stateful=True. The rest of your description is correct; details below.

The states are reset to zero per layer via layer.reset_states() - or for all layers in the model via model.reset_states(); if reset_states() isn't called, the states do not reset themselves to zero. States are built and initialized via an overridable method, get_initial_state() - e.g. tailored for LSTM (using another overridable method).

So basically, implementing this requires building a dedicated tensor that captures the last timestep's hidden state and is resettable to zero via a method.
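
As a rough illustration of that idea, here is a toy sketch in plain Python rather than Keras or Haste. The class name `StatefulToyRNN`, its scalar weights, and the recurrence itself are all assumptions made for the example; the only point is that a stored, resettable last state makes chunked input equivalent to one long sequence.

```python
import math


class StatefulToyRNN:
    """Hypothetical scalar RNN: h_t = tanh(w * x_t + u * h_{t-1}).

    The last hidden state is kept in `self.state` between calls and is
    only zeroed when `reset_states()` is called explicitly.
    """

    def __init__(self, w=0.5, u=0.9):
        self.w = w
        self.u = u
        self.state = 0.0

    def reset_states(self):
        self.state = 0.0

    def __call__(self, xs):
        h = self.state  # start from whatever the previous call left behind
        outputs = []
        for x in xs:
            h = math.tanh(self.w * x + self.u * h)
            outputs.append(h)
        self.state = h  # remember the last timestep's hidden state
        return outputs


sequence = [0.1, -0.4, 0.7, 0.2, -0.3, 0.9]

rnn = StatefulToyRNN()
full = rnn(sequence)  # whole sequence in one call

rnn.reset_states()
chunked = rnn(sequence[:3]) + rnn(sequence[3:])  # same data in two chunks
```

Because the second chunk starts from the state left by the first, `chunked` matches `full`. Skipping the `reset_states()` call would instead start the chunked run from the end state of the previous pass.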


When and how does LSTM "pass states" in stateful? Pasting relevant excerpts from here:

  • When: only batch-to-batch; samples are entirely independent
  • How: in Keras, only batch-sample to batch-sample: stateful=True requires you to specify batch_shape instead of input_shape, because Keras builds batch_size separate states of the LSTM when compiling

Per above, this should not be done:

# sampleNM = sample N at timestep(s) M
batch1 = [sample10, sample20, sample30, sample40]
batch2 = [sample21, sample41, sample11, sample31]

This implies 21 causally follows 10 - and will wreck training. Instead do:

batch1 = [sample10, sample20, sample30, sample40]
batch2 = [sample11, sample21, sample31, sample41]
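
The ordering rule above can be sketched in plain Python. This assumes every sample is a list of timesteps of equal length; `make_stateful_batches` and `chunk_len` are illustrative names, not part of Keras or Haste.

```python
def make_stateful_batches(samples, chunk_len):
    """Split equal-length samples into consecutive chunks with a fixed row order.

    Row i of batch k+1 continues row i of batch k, which is what a stateful
    RNN assumes when it carries state from one batch to the next.
    """
    seq_len = len(samples[0])
    if any(len(s) != seq_len for s in samples):
        raise ValueError("all samples must have the same length")
    batches = []
    for start in range(0, seq_len, chunk_len):
        # Same sample order in every batch; only the time window moves.
        batches.append([s[start:start + chunk_len] for s in samples])
    return batches


# Four samples with two timesteps each, labelled as above:
# "sampleNM" = sample N at timestep M.
samples = [[f"sample{n}{m}" for m in range(2)] for n in range(1, 5)]
batches = make_stateful_batches(samples, chunk_len=1)
```

Shuffling is still possible under this scheme, but it would have to be applied to the sample order once, before chunking, so that every batch sees the same permutation.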

OverLordGoldDragon commented on May 27, 2024

As for how an "entirely unrelated batch" is detected in practice: it can be done via callbacks with a counter in on_batch_end(), e.g. calling reset_states() every 4th batch, or via a custom train loop (my approach).
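
A framework-agnostic sketch of the counter idea might look like the following. `ResetEveryN` is a hypothetical helper, not a Keras callback class; in real use `reset_fn` would be something like `model.reset_states`, and `on_batch_end()` would be called from a callback or a custom train loop.

```python
class ResetEveryN:
    """Hypothetical helper: call `reset_fn` after every `every`-th batch."""

    def __init__(self, reset_fn, every):
        if every < 1:
            raise ValueError("every must be >= 1")
        self.reset_fn = reset_fn
        self.every = every
        self.count = 0  # number of batches finished so far

    def on_batch_end(self):
        self.count += 1
        if self.count % self.every == 0:
            self.reset_fn()


# Record the batch count at which each reset fires (stand-in for a real reset).
reset_after = []
hook = ResetEveryN(reset_fn=lambda: reset_after.append(hook.count), every=4)

for _ in range(10):
    # ... train on one batch here ...
    hook.on_batch_end()
```

With `every=4` and ten batches, the reset fires after the 4th and 8th batch, so batches 5 and 9 start from a zero state.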

sharvil commented on May 27, 2024

Thanks so much for the detailed writeup. I think this is fairly easy to achieve in PyTorch with the existing Haste API. Here's some sample code that demonstrates what a stateful LSTM would look like:

import torch
import torch.nn as nn
import haste_pytorch as haste


class StatefulLSTM(nn.Module):
  def __init__(self, layer):
    super().__init__()
    self.layer = layer
    self.state = None

  def forward(self, x, state=None, lengths=None, reset=False):
    if reset:
      self.state = state
    print(f'Using state {self.state}')
    y, state = self.layer(x, self.state, lengths)
    self.state = (state[0].detach(), state[1].detach())
    return y, state


SEQ_LEN = 250
BATCH_SIZE = 1
INPUT_SIZE = 3
HIDDEN_SIZE = 5

lstm = haste.LSTM(INPUT_SIZE, HIDDEN_SIZE)
model = StatefulLSTM(lstm)

loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for t in range(10):
  x = torch.rand([SEQ_LEN, BATCH_SIZE, INPUT_SIZE])
  y, _ = model(x, reset=(t % 3 == 0))
  loss = loss_fn(y, torch.zeros_like(y))

  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
  print(f'Step {t}, loss {loss}')

This code trains an LSTM on sequences of length 750 which are fed in chunks of 250. The LSTM state is squirreled away in self.state of StatefulLSTM and is cleared whenever reset=True is passed in to forward(...).
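
The sample above draws a fresh random `x` at every step, so the chunks are not slices of an actual 750-step sequence. With real data, the chunk boundaries and reset flags could be derived as in this sketch; `chunk_schedule` is a hypothetical helper, and the slicing along the time axis is left to the caller.

```python
def chunk_schedule(full_len, chunk_len, num_steps):
    """Return (step, start, stop, reset) tuples for feeding long sequences in chunks.

    `reset` is True on the first chunk of each full-length sequence, i.e.
    whenever the carried-over state would belong to a different sequence.
    """
    if full_len % chunk_len != 0:
        raise ValueError("full_len must be a multiple of chunk_len")
    chunks_per_sequence = full_len // chunk_len
    schedule = []
    for step in range(num_steps):
        position = step % chunks_per_sequence  # index of the chunk within its sequence
        start = position * chunk_len
        schedule.append((step, start, start + chunk_len, position == 0))
    return schedule


schedule = chunk_schedule(full_len=750, chunk_len=250, num_steps=10)
reset_steps = [step for step, _, _, reset in schedule if reset]
```

For 750-step sequences and 250-step chunks this gives resets at steps 0, 3, 6 and 9, matching the `t % 3 == 0` condition in the training loop.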

Does this solution work for your application?

OverLordGoldDragon commented on May 27, 2024

Thanks for the demo. From what I understand from lstm.py, state=None triggers resetting the cell and hidden states to zero, which StatefulLSTM enables via an additional reset argument. Indeed, that's accurate. However, it's unclear how to reset states if the layers are part of a larger network (e.g. with conv layers). I don't suppose model.reset_states() is doable without modifying the model class instance; a workaround is something like:

def reset_states(model):
    for layer in model.layers:  # tf.keras
        if hasattr(layer, 'reset_states'):
            layer.reset_states()

but I can't tell from your snippet how one would reset states before passing an input.

Also, it'd work best if stateful was part of the base implementation, rather than a dedicated one, else it'll complicate extending functionality (e.g. will also need StatefulLayerNormLSTM).

OverLordGoldDragon commented on May 27, 2024

@sharvil Looks excellent - a wrapper is a fair-enough alternative. To be sure all works as expected, I'll get back to you once I get haste_pytorch working.

OverLordGoldDragon commented on May 27, 2024

@sharvil Ran the script - looks good; I'll compare it more extensively vs. Keras later, but currently all seems to work as expected. Thank you.
