Plans for RNN (onednn, 23 comments, closed)

sbodenstein commented on May 21, 2024
Plans for RNN

Comments (23)

emfomenk commented on May 21, 2024

Hi @taliesinb,

None of the following changes are finalized yet, but they should give a sense of how things will look...

mkldnn_types.h:

/** A descriptor of an RNN operation. */
typedef struct {
    /** The kind of primitive. Used for self identifying the primitive
     * descriptor. Must be #mkldnn_rnn. */
    mkldnn_primitive_kind_t primitive_kind;
    /** The kind of propagation. Possible values: #mkldnn_forward_training,
     * #mkldnn_forward_inference, #mkldnn_backward_data,
     * and #mkldnn_backward_weights. */
    mkldnn_prop_kind_t prop_kind;
    /** The kind of the RNN algorithm. Possible values:
     * #mkldnn_rnn_relu, #mkldnn_rnn_tanh, #mkldnn_rnn_lstm, #mkldnn_rnn_gru. */
    mkldnn_alg_kind_t alg_kind;
    /** The direction of the RNN. Possible values:
     * #mkldnn_rnn_unidirectional, #mkldnn_rnn_bidirectional.*/
    mkldnn_rnn_direction_t direction;
    /** The input mode of the RNN. Possible values:
     * #mkldnn_rnn_linear_input, #mkldnn_rnn_skip_input.*/
    mkldnn_rnn_input_mode_t input_mode;
    /** The number of hidden states in one cell */
    size_t num_states;
    /** The number of layers in the entire RNN network */
    size_t num_layers;
    /** The length of the sequences in the entire RNN network */
    size_t num_seqs;
    /** Whether to output the state and cell for the entire RNN network */
    int state_outputs;
    /** Input(x) memory descriptor. [seq, batch, input_size] */
    mkldnn_memory_desc_t x_desc;
    /** State input(hx) memory descriptor. [layer, batch, hidden_size] */
    mkldnn_memory_desc_t hx_desc;
    /** Output(y) memory descriptor. [seq, batch, hidden_size] */
    mkldnn_memory_desc_t y_desc;
    /** Weights memory descriptor. */
    mkldnn_memory_desc_t weights_desc;

    // @TODO check if we need dropout descriptor
} mkldnn_rnn_desc_t;

mkldnn.h:

/** @addtogroup c_api_rnn RNN (including vanilla RNN, LSTM, and GRU)
 * A primitive to compute an RNN.
 * @{ */

/** Initializes an rnn descriptor @p rnn_desc for forward propagation using
 * @p prop_kind (possible values are #mkldnn_forward_training or
 * #mkldnn_forward_inference), @p alg_kind (possible values are
 * #mkldnn_rnn_relu, #mkldnn_rnn_tanh, #mkldnn_rnn_lstm or #mkldnn_rnn_gru),
 * @p direction (possible values are #mkldnn_rnn_unidirectional or
 * #mkldnn_rnn_bidirectional), @p input_mode for the input mode,
 * @p num_states for the number of hidden states, @p num_layers
 * for the number of stacked layers, @p num_seqs for the length of the
 * sequences, @p state_outputs for whether to produce the state outputs,
 * and the memory descriptors. */
mkldnn_status_t MKLDNN_API mkldnn_rnn_forward_desc_init(
        mkldnn_rnn_desc_t *rnn_desc, mkldnn_prop_kind_t prop_kind,
        mkldnn_alg_kind_t alg_kind, mkldnn_rnn_direction_t direction,
        mkldnn_rnn_input_mode_t input_mode, size_t num_states,
        size_t num_layers, size_t num_seqs, int state_outputs,
        const mkldnn_memory_desc_t *x_desc,
        const mkldnn_memory_desc_t *hx_desc,
        const mkldnn_memory_desc_t *y_desc,
        const mkldnn_memory_desc_t *weights_desc);

/** Initializes an rnn descriptor @p rnn_desc for backward propagation using
 * @p prop_kind (possible values are #mkldnn_backward_data and
 * #mkldnn_backward_weights), @p alg_kind (possible values are
 * #mkldnn_rnn_relu, #mkldnn_rnn_tanh, #mkldnn_rnn_lstm or #mkldnn_rnn_gru),
 * @p direction (possible values are #mkldnn_rnn_unidirectional or
 * #mkldnn_rnn_bidirectional), @p input_mode for the input mode,
 * @p num_states for the number of hidden states, @p num_layers
 * for the number of stacked layers, @p num_seqs for the length of the
 * sequences, @p state_outputs for whether to produce the state outputs,
 * and the memory descriptors. */
mkldnn_status_t MKLDNN_API mkldnn_rnn_backward_desc_init(
        mkldnn_rnn_desc_t *rnn_desc, mkldnn_prop_kind_t prop_kind,
        mkldnn_alg_kind_t alg_kind, mkldnn_rnn_direction_t direction,
        mkldnn_rnn_input_mode_t input_mode, size_t num_states,
        size_t num_layers, size_t num_seqs, int state_outputs,
        const mkldnn_memory_desc_t *x_desc,
        const mkldnn_memory_desc_t *hx_desc,
        const mkldnn_memory_desc_t *y_desc,
        const mkldnn_memory_desc_t *weights_desc);

/** @} */
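
For concreteness, here is a minimal sketch of how this proposed (and, again, not finalized) C API might be called to set up a unidirectional LSTM for forward training. The memory formats, the flat weights blob and its size, and the CHECK helper are illustrative assumptions, not part of the proposal:

#include "mkldnn.h"

#define CHECK(f) do { if ((f) != mkldnn_success) return 1; } while (0)

int main(void) {
    const int seq = 50, batch = 32, input_size = 128;
    const int hidden_size = 256, layers = 2;

    /* Plain f32 descriptors; mkldnn_any lets the library pick the layout. */
    mkldnn_memory_desc_t x_desc, hx_desc, y_desc, weights_desc;
    int x_dims[3]  = {seq, batch, input_size};
    int hx_dims[3] = {layers, batch, hidden_size};
    int y_dims[3]  = {seq, batch, hidden_size};
    CHECK(mkldnn_memory_desc_init(&x_desc, 3, x_dims, mkldnn_f32, mkldnn_any));
    CHECK(mkldnn_memory_desc_init(&hx_desc, 3, hx_dims, mkldnn_f32, mkldnn_any));
    CHECK(mkldnn_memory_desc_init(&y_desc, 3, y_dims, mkldnn_f32, mkldnn_any));

    /* The weights layout is exactly what is debated below; a flat 1-D blob
     * with an assumed LSTM size (4 gates) stands in for it here. */
    int w_dims[1] = {4 * hidden_size * (input_size + hidden_size + 2) * layers};
    CHECK(mkldnn_memory_desc_init(&weights_desc, 1, w_dims, mkldnn_f32, mkldnn_x));

    mkldnn_rnn_desc_t rnn_desc;
    CHECK(mkldnn_rnn_forward_desc_init(&rnn_desc, mkldnn_forward_training,
            mkldnn_rnn_lstm, mkldnn_rnn_unidirectional,
            mkldnn_rnn_linear_input, /*num_states=*/2, layers, seq,
            /*state_outputs=*/1, &x_desc, &hx_desc, &y_desc, &weights_desc));
    return 0;
}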

apaszke commented on May 21, 2024

I think at this point we're pretty much stuck with packed sequence/padded inputs in PyTorch, so it would be cool if you supported something similar. cuDNN API is quite good, except for weight format management. Please, unless it is absolutely necessary, don't require frameworks to give you weights as a single chunk of memory, and if this is needed, then at least define a format openly. Right now cuDNN's answer is "use our API to query where to put each weight", which is terribly inconvenient.
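
For illustration, a hypothetical sketch of what an openly defined weight format could look like once written down; the [layer][direction][gate] ordering and the helper below are assumptions made for this example, not a format cuDNN or MKL-DNN actually specifies:

#include <stddef.h>

/* Hypothetical documented layout: weights stored contiguously as
 * [layer][direction][gate][W_input | W_hidden]. With such a spec a framework
 * can compute every offset itself instead of querying an opaque API.
 * For brevity this assumes every layer sees the same input_size. */
static size_t weight_offset(size_t layer, size_t dir, size_t gate,
        size_t n_dirs, size_t n_gates,
        size_t input_size, size_t hidden_size) {
    const size_t gate_block = hidden_size * (input_size + hidden_size);
    const size_t layer_block = n_dirs * n_gates * gate_block;
    return layer * layer_block + dir * n_gates * gate_block + gate * gate_block;
}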

mgouicem commented on May 21, 2024

Thank you for the clarification and the input. We will take that into account when designing our API.

emfomenk commented on May 21, 2024

Hi @sbodenstein,

The work is in progress; hopefully RNN support will be available soon.

sbodenstein commented on May 21, 2024

Fantastic, this will be super useful!

fightbob commented on May 21, 2024

Hi @emfomenk, when will the RNN feature be released?

sbodenstein commented on May 21, 2024

@emfomenk: is the planned RNN API going to be compatible with the cuDNN version?

emfomenk commented on May 21, 2024

Hi @fightbob and @sbodenstein,

RNN is slightly postponed, as some other urgent work came up.
Unfortunately there is no ETA at the moment :( I will ping the team for the latest status and get back to you.

Yes, the API is going to be very close to the cuDNN one.
We ran into some problems with how it would look in the C++ API (there are too many constructors for the different configurations), but the C API should be pretty straightforward.

taliesinb commented on May 21, 2024

Great. If you can drop any details about how the API might differ before you actually ship, that would be helpful to us for planning purposes.

taliesinb commented on May 21, 2024

@emfomenk Thanks so much, that's very useful to know!

piiswrong commented on May 21, 2024

Any updates on this?

sbodenstein commented on May 21, 2024

@emfomenk: will there be support for variable-length sequences (i.e. a batch of sequences with different lengths)? cuDNN supports this, but I don't see it in the above design.

taliesinb commented on May 21, 2024

Specifically, the concern is that because NVIDIA's design outputs only the final cell state (not a sequence of cell states), you cannot add variable-length support after the fact: the final cell states for batch elements whose sequences are shorter than the full batch length will be invalid. So we simply can't use the optimization at all for variable-length problems unless variable-length support is baked into the design.

ykim362 commented on May 21, 2024

@taliesinb @sbodenstein The current design can output all the outputs (h) of the last stacked layer and the cell state (c) at the last time step. But do you need the cell states in the middle of the sequences?

taliesinb commented on May 21, 2024

@ykim362 Yes, but you have a choice. If the RNN layer wants to support variable-length operation†, it can either:

  1. provide the entire history of cell states, so that the correct per-element last time step can be selected for each batch element, e.g. using MXNet's SequenceLast layer; or

  2. accept another input that contains the sequence lengths and then expect the sequences to be densely packed (this is what cuDNN does), so that the cell state output is already correct.

† To be clear about what I mean by variable-length operation: I'm referring to the case where a batch contains multiple unequal sequence lengths, and most sequence problems are like this. Older frameworks just pad the shorter sequences with zeros and expect the network to learn to deal with the zeros, but this fundamentally changes the problem. By far the better approach is to pad with junk and carefully make sure that you take the 'correct' outputs and states from just before the junk, using pick operations etc. We want to make sure that the MKL implementation makes this possible. Option 1 does the pick externally; option 2 does it internally (a sketch of option 1's external pick follows).
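
A minimal sketch of option 1's external pick, assuming a contiguous [seq][batch][hidden] layout for the state history; the helper and its layout are illustrative assumptions, not part of any proposed API:

#include <stddef.h>

/* Given the full history of states with layout [seq][batch][hidden] and the
 * per-element sequence lengths, copy the state at each element's last valid
 * time step into out, of shape [batch][hidden]. */
static void pick_last_state(const float *states, float *out,
        size_t seq, size_t batch, size_t hidden, const size_t *lengths) {
    (void)seq; /* bounds are implied by lengths[b] <= seq */
    for (size_t b = 0; b < batch; ++b) {
        const size_t t = lengths[b] - 1; /* last valid time step */
        for (size_t h = 0; h < hidden; ++h)
            out[b * hidden + h] = states[(t * batch + b) * hidden + h];
    }
}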

mgouicem commented on May 21, 2024

@taliesinb @sbodenstein Thanks for the comments! I think I am missing something...

I understand that cuDNN's current interface enables option 1 (parameters yDesc and y in the cuDNN doc).

For option 2, I did not find any documentation related to that. The only element I see that accommodates variable length in the cuDNN API is that the input for each time step can have a different minibatch size (in decreasing order). I guess this assumes that the user has to sort the sequences in the minibatch first (e.g. the input from the longest sequence comes first in each minibatch), but there is not much detail in their doc. Could you please elaborate on the use case?
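
To make the scheme concrete, a small sketch of the packing being described, based purely on this thread's reading of cuDNN rather than its documentation: sequences are sorted by length in decreasing order, and time step t carries only the batch elements whose length exceeds t, so the per-step minibatch sizes decrease:

#include <stdio.h>

int main(void) {
    /* lengths must be sorted in decreasing order */
    const int lengths[] = {4, 3, 1};
    const int batch = 3, max_len = 4;
    for (int t = 0; t < max_len; ++t) {
        int bs = 0;
        while (bs < batch && lengths[bs] > t)
            ++bs; /* elements still alive at step t */
        printf("step %d: minibatch size %d\n", t, bs); /* prints 3, 2, 2, 1 */
    }
    return 0;
}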

sbodenstein commented on May 21, 2024

> I understand that cuDNN's current interface enables option 1 (parameters yDesc and y in the cuDNN doc)

This is correct for GRU and standard RNN, but untrue for LSTM, which has a second state (the cell state) that is not returned in y. You have to use option 2 in cuDNN. Also, for bidirectional RNNs, only option 2 will work.

> I guess this assumes that the user has to sort the sequences in the minibatch first

The framework/user does indeed have to sort the sequences by length and pack them. This is annoying, and it would be good if the Intel version could avoid it.

> Could you please elaborate on the use case?

Frameworks that support variable-length RNNs require this (e.g. PyTorch, pytorch/pytorch#873), and we wish to add this support to MXNet as well. Including @apaszke and @jekbradbury, as this discussion about the MKL RNN design seems very relevant for variable-length RNNs in PyTorch as well (I think PyTorch will also want to use this MKL RNN implementation).

taliesinb commented on May 21, 2024

> The framework/user does indeed have to sort the sequences by length and pack them. This is annoying, and it would be good if the Intel version could avoid it.

Or better yet, provide that as an optional feature.

sbodenstein commented on May 21, 2024

@mgouicem: for us, the cleanest approach to supporting variable-length sequences is a bit different from the cuDNN approach:

  • Accept an extra input to mkldnn_rnn_forward_desc_init that provides a list of sequence lengths at runtime (see the hypothetical sketch after this list).
  • The input can either keep the usual {batch, max seq len, feature size} shape, with values beyond the runtime sequence lengths interpreted as padding,
  • or the input can be like cuDNN's: effectively a packed piece of memory.
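
A hypothetical sketch of how the first variant might look as an extension of the proposed C API; the extra seq_lengths parameter and its semantics are assumptions drawn from this thread, not part of any published MKL-DNN header:

/* Hypothetical, for illustration only: identical to the proposed
 * mkldnn_rnn_forward_desc_init, plus a runtime array of per-batch-element
 * sequence lengths. Values of x beyond seq_lengths[b] would be treated as
 * padding and ignored. */
mkldnn_status_t MKLDNN_API mkldnn_rnn_forward_desc_init_v(
        mkldnn_rnn_desc_t *rnn_desc, mkldnn_prop_kind_t prop_kind,
        mkldnn_alg_kind_t alg_kind, mkldnn_rnn_direction_t direction,
        mkldnn_rnn_input_mode_t input_mode, size_t num_states,
        size_t num_layers, size_t num_seqs, int state_outputs,
        const mkldnn_memory_desc_t *x_desc,
        const mkldnn_memory_desc_t *hx_desc,
        const mkldnn_memory_desc_t *y_desc,
        const mkldnn_memory_desc_t *weights_desc,
        const size_t *seq_lengths);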

taliesinb commented on May 21, 2024

> Right now cuDNN's answer is "use our API to query where to put each weight", which is terribly inconvenient.

Yeah, it really sucks; it made compilation much more complicated for us. And the practice of not publicly defining properties, sizes, etc., and putting them behind an API makes scratch memory much harder to share across buckets, because the workspace size cannot be determined without querying CUDA at compile time, which MXNet does not support.

EDIT: clarified my complaint.

xhzhao commented on May 21, 2024

@apaszke @taliesinb Totally agree with you about the cuDNN weight format; I think a clearly defined weight format is very important for frameworks and users.

BenjaminJurke commented on May 21, 2024

Is there any update on the timeline for the release of the RNN primitives at this point? Just curious, but very much looking forward to it.

mgouicem commented on May 21, 2024

@BenjaminJurke, unfortunately there is no precise timeline for the feature yet, but we are working on it.
