Comments (4)
Sorry, and thank you. We stopped using Haste because, in our case, it was slower than PyTorch, and the performance difference wasn't significant enough to justify the additional overhead.
The PyTorch API does support bidirectional RNNs, and we use them.
In order to support reverse_sequence, we used roll and flip. We implemented our own roll function, see here:
import torch

def roll(tensor, shift, dim=-1):
    """Shift a tensor along the specified dimension.

    TODO: Create a `Roll` module so that `indices` are not recomputed each time.

    Args:
        tensor (torch.Tensor [*, dim, *]): The tensor to shift.
        shift (torch.Tensor [*]): The number of elements to shift `dim`. This tensor must have
            one fewer dimension than `tensor`.
        dim (int): The dimension to shift.

    Returns:
        tensor (torch.Tensor [*, dim, *]): The shifted tensor.
    """
    shift = shift.unsqueeze(dim)
    assert shift.dim() == tensor.dim(), (
        'The `shift` tensor must be the same size as `tensor` without the `dim` dimension.')
    indices = torch.arange(0, tensor.shape[dim], device=tensor.device)
    dim = tensor.dim() + dim if dim < 0 else dim
    # EXAMPLE:
    # indices.shape == (3,)
    # tensor.shape == (1, 2, 3, 4, 5)
    # indices_shape == [1, 1, 3, 1, 1]
    indices_shape = [1] * dim + [-1] + [1] * (tensor.dim() - dim - 1)
    indices = indices.view(*indices_shape).expand(*tensor.shape)
    indices = (indices - shift) % tensor.shape[dim]
    return torch.gather(tensor, dim, indices)
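To make the per-row behavior concrete, here is a quick standalone check (the function is repeated so the snippet runs on its own; unlike torch.roll, it takes a separate shift per slice):

```python
import torch

# Copy of the `roll` above so this check runs standalone.
def roll(tensor, shift, dim=-1):
    shift = shift.unsqueeze(dim)
    assert shift.dim() == tensor.dim()
    indices = torch.arange(0, tensor.shape[dim], device=tensor.device)
    dim = tensor.dim() + dim if dim < 0 else dim
    indices_shape = [1] * dim + [-1] + [1] * (tensor.dim() - dim - 1)
    indices = indices.view(*indices_shape).expand(*tensor.shape)
    indices = (indices - shift) % tensor.shape[dim]
    return torch.gather(tensor, dim, indices)

t = torch.arange(6).view(2, 3)   # [[0, 1, 2], [3, 4, 5]]
shifts = torch.tensor([1, 2])    # a different shift per row
rolled = roll(t, shifts)
# row 0 shifted right by 1 -> [2, 0, 1]
# row 1 shifted right by 2 -> [4, 5, 3]
```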
PyTorch does not provide a lengths parameter; therefore, we needed to implement masking on our own in order to use the PyTorch LSTM. Thanks for providing that mechanism!
We were using LSTMCell with unrolled sequences; therefore, we needed to convert our lengths to binary masks for each call.
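For what it's worth, a minimal sketch of that masking pattern with the standard nn.LSTMCell (masked_unroll, its shapes, and its zero-initialized state are our own illustrative choices, not a Haste or PyTorch API):

```python
import torch
import torch.nn as nn

# Illustrative sketch: unroll nn.LSTMCell over time and freeze the state of
# sequences that have already ended, using a binary mask built from lengths.
def masked_unroll(cell, inputs, lengths):
    # inputs: (seq_len, batch, input_size); lengths: (batch,)
    seq_len, batch, _ = inputs.shape
    h = inputs.new_zeros(batch, cell.hidden_size)
    c = inputs.new_zeros(batch, cell.hidden_size)
    outputs = []
    for t in range(seq_len):
        h_new, c_new = cell(inputs[t], (h, c))
        # mask is 1.0 while t < length and 0.0 afterwards
        mask = (t < lengths).float().unsqueeze(-1)
        # keep the old state for sequences that have already ended
        h = mask * h_new + (1.0 - mask) * h
        c = mask * c_new + (1.0 - mask) * c
        outputs.append(h)
    return torch.stack(outputs), (h, c)
```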
from haste.
Furthermore, can this be used like an LSTMCell (https://pytorch.org/docs/stable/nn.html#torch.nn.LSTMCell), with the lengths parameter being used more like a binary mask?
The PyTorch API doesn't provide bidirectional RNNs out of the box (the TensorFlow API does). Unfortunately, PyTorch doesn't have a reverse_sequence function, so you'll either need to write your own or use the implementation here: pytorch/pytorch#1794.
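One possible sketch of such a function, assuming batch-first inputs of shape (batch, seq_len, features); this is our own illustration, not the implementation from the linked issue:

```python
import torch

# Hedged sketch of reverse_sequence: reverse each sequence along the time
# axis up to its own length, leaving any padding positions in place.
def reverse_sequence(inputs, lengths):
    # inputs: (batch, seq_len, features); lengths: (batch,)
    seq_len = inputs.shape[1]
    t = torch.arange(seq_len, device=inputs.device).unsqueeze(0)  # (1, seq_len)
    L = lengths.unsqueeze(1)                                      # (batch, 1)
    # Within the valid prefix, position t maps to length - 1 - t;
    # padding positions keep their original index.
    idx = torch.where(t < L, L - 1 - t, t)                        # (batch, seq_len)
    idx = idx.unsqueeze(-1).expand(-1, -1, inputs.shape[2])
    return torch.gather(inputs, 1, idx)

x = torch.arange(12.0).view(2, 3, 2)
lengths = torch.tensor([3, 2])
y = reverse_sequence(x, lengths)
# sequence 0 (length 3) is fully reversed; sequence 1 (length 2) is
# reversed in its first two steps, with the padding step untouched
```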
Suppose you have reverse_sequence. Then the bidirectional RNN would look something like this:
reversed_inputs = reverse_sequence(inputs, lengths)
rnn_fwd = LSTM(...)
rnn_bwd = LSTM(...)
y_fwd, _ = rnn_fwd(inputs, lengths)
y_bwd, _ = rnn_bwd(reversed_inputs, lengths)
y_bwd = reverse_sequence(y_bwd, lengths)
y = torch.cat([y_fwd, y_bwd], dim=-1)
I'm not sure what you mean by the lengths parameter being used like a binary mask. AFAICT, the standard PyTorch LSTMCell doesn't accept a lengths parameter. If you have a one-hot vector indicating end-of-sequence, you can just do argmax(one_hot) + 1 to convert to lengths.
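Concretely, with a batched one-hot end-of-sequence marker:

```python
import torch

# Converting a one-hot end-of-sequence marker to lengths, as suggested above.
one_hot = torch.tensor([[0, 0, 1, 0],   # EOS at step 2 -> length 3
                        [0, 1, 0, 0]])  # EOS at step 1 -> length 2
lengths = one_hot.argmax(dim=-1) + 1
# lengths == tensor([3, 2])
```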
Closing due to inactivity.