Comments (2)

albertz commented on May 24, 2024

Hi,

When you run the training (or inference; there you will get a somewhat different output), the log gives you the exact shape information for every layer, because of the option debug_print_layer_output_template. It will look something like this:

layer root/'data' output: Data(name='data', shape=(None, 40))
layer root/'source' output: Data(name='source_output', shape=(None, 40))
layer root/'lstm0_fw' output: Data(name='lstm0_fw_output', shape=(None, 1024), batch_dim_axis=1)
layer root/'lstm0_bw' output: Data(name='lstm0_bw_output', shape=(None, 1024), batch_dim_axis=1)
layer root/'lstm1_pool' output: Data(name='lstm1_pool_output', shape=(None, 2048))
layer root/'lstm2_fw' output: Data(name='lstm2_fw_output', shape=(None, 1024), batch_dim_axis=1)
layer root/'lstm2_bw' output: Data(name='lstm2_bw_output', shape=(None, 1024), batch_dim_axis=1)
layer root/'lstm2_pool' output: Data(name='lstm2_pool_output', shape=(None, 2048))
layer root/'lstm3_fw' output: Data(name='lstm3_fw_output', shape=(None, 1024), batch_dim_axis=1)
layer root/'lstm3_bw' output: Data(name='lstm3_bw_output', shape=(None, 1024), batch_dim_axis=1)
layer root/'lstm3_pool' output: Data(name='lstm3_pool_output', shape=(None, 2048), batch_dim_axis=1)
layer root/'lstm4_fw' output: Data(name='lstm4_fw_output', shape=(None, 1024), batch_dim_axis=1)
layer root/'lstm4_bw' output: Data(name='lstm4_bw_output', shape=(None, 1024), batch_dim_axis=1)
layer root/'lstm4_pool' output: Data(name='lstm4_pool_output', shape=(None, 2048), batch_dim_axis=1)
layer root/'lstm5_fw' output: Data(name='lstm5_fw_output', shape=(None, 1024), batch_dim_axis=1)
layer root/'lstm5_bw' output: Data(name='lstm5_bw_output', shape=(None, 1024), batch_dim_axis=1)
layer root/'encoder' output: Data(name='encoder_output', shape=(None, 2048), batch_dim_axis=1)
layer root/'ctc' output: Data(name='ctc_output', shape=(None, 10026), batch_dim_axis=1)
layer root/'enc_ctx' output: Data(name='enc_ctx_output', shape=(None, 1024), batch_dim_axis=1)
layer root/'inv_fertility' output: Data(name='inv_fertility_output', shape=(None, 1), batch_dim_axis=1)
layer root/'enc_value' output: Data(name='enc_value_output', shape=(None, 1, 2048), batch_dim_axis=1)
layer root/'output' output: Data(name='output_output', shape=(None,), dtype='int32', sparse=True, dim=10025, batch_dim_axis=1)
Rec layer sub net:
  Input layers moved out of loop: (#: 2)
    output
    target_embed
  Output layers moved out of loop: (#: 3)
    output_prob
    readout
    readout_in
  Layers in loop: (#: 10)
    att
    att0
    att_weights
    energy
    energy_tanh
    energy_in
    s_transformed
    s
    weight_feedback
    accum_att_weights
  Unused layers: (#: 1)
    end
layer root/output:rec-subnet-input/'output' output: Data(name='classes', shape=(None,), dtype='int32', sparse=True, dim=10025)
layer root/output:rec-subnet-input/'target_embed' output: Data(name='target_embed_output', shape=(None, 621))
layer root/output:rec-subnet/'weight_feedback' output: Data(name='weight_feedback_output', shape=(None, 1024), time_dim_axis=None)
layer root/output:rec-subnet/'prev:target_embed' output: Data(name='target_embed_output', shape=(621,), time_dim_axis=None)
layer root/output:rec-subnet/'s' output: Data(name='s_output', shape=(1000,), time_dim_axis=None)
layer root/output:rec-subnet/'s_transformed' output: Data(name='s_transformed_output', shape=(1024,), time_dim_axis=None)
layer root/output:rec-subnet/'energy_in' output: Data(name='energy_in_output', shape=(None, 1024), batch_dim_axis=1)
layer root/output:rec-subnet/'energy_tanh' output: Data(name='energy_tanh_output', shape=(None, 1024), batch_dim_axis=1)
layer root/output:rec-subnet/'energy' output: Data(name='energy_output', shape=(None, 1), batch_dim_axis=1)
layer root/output:rec-subnet/'att_weights' output: Data(name='att_weights_output', shape=(None, 1), batch_dim_axis=1)
layer root/output:rec-subnet/'att0' output: Data(name='att0_output', shape=(1, 2048), time_dim_axis=None)
layer root/output:rec-subnet/'att' output: Data(name='att_output', shape=(2048,), time_dim_axis=None)
layer root/output:rec-subnet/'accum_att_weights' output: Data(name='accum_att_weights_output', shape=(None, 1), time_dim_axis=None)
layer root/output:rec-subnet-output/'s' output: Data(name='s_output', shape=(None, 1000), batch_dim_axis=1)
layer root/output:rec-subnet-output/'prev:target_embed' output: Data(name='target_embed_output', shape=(None, 621), batch_dim_axis=1)
layer root/output:rec-subnet-output/'att' output: Data(name='att_output', shape=(None, 2048), batch_dim_axis=1)
layer root/output:rec-subnet-output/'readout_in' output: Data(name='readout_in_output', shape=(None, 1000), batch_dim_axis=1)
layer root/output:rec-subnet-output/'readout' output: Data(name='readout_output', shape=(None, 500), batch_dim_axis=1)
layer root/output:rec-subnet-output/'output_prob' output: Data(name='output_prob_output', shape=(None, 10025), batch_dim_axis=1)
layer root/'decision' output: Data(name='output_output', shape=(None,), dtype='int32', sparse=True, dim=10025, batch_dim_axis=1)

Note that the shape information here is slightly confusing: the batch dim is excluded from shape in this output. If batch_dim_axis is not shown for a layer, it is 0 (the default).
All the layers inside the recurrent loop of the output layer, i.e. those printed as root/output:rec-subnet/..., show the shape of the data inside the loop. Accumulating this data over the loop would add the decoder time dim as the 0th axis. So e.g. this layer:
layer root/output:rec-subnet/'weight_feedback' output: Data(name='weight_feedback_output', shape=(None, 1024), time_dim_axis=None)
means that, inside the decoder loop, it is of shape (batch, encoder-time, 1024) (time_dim_axis=None just means that this axis is not interpreted as a time dim for some reason, but that doesn't really matter). Accumulating it over the loop would thus end up as (decoder-time, batch, encoder-time, 1024).
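If it helps, here is that reading as a tiny self-contained Python sketch (not RETURNN code; full_batch_shape is a made-up helper, just for illustration):

def full_batch_shape(shape, batch_dim_axis=0):
    # Re-insert the batch dim into the batch-excluded `shape` tuple from the log.
    # None stands for any dynamic dim (batch, or a variable time axis).
    s = list(shape)
    s.insert(batch_dim_axis, None)
    return tuple(s)

# 'weight_feedback' inside the loop: shape=(None, 1024), default batch_dim_axis=0
print(full_batch_shape((None, 1024)))  # (None, None, 1024), i.e. (batch, encoder-time, 1024)
# Accumulated over the decoder loop, the decoder time dim would be prepended:
# (decoder-time, batch, encoder-time, 1024).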

The comment in the code about enc_value is wrong. Maybe some initial version of SplitDimsLayer enforced batch-major (I don't remember anymore...). But it doesn't matter too much: every layer checks what kind of input it gets, and if it wants to operate on it in some specific format, it automatically converts it. You can see in this output:
layer root/'enc_value' output: Data(name='enc_value_output', shape=(None, 1, 2048), batch_dim_axis=1)
that the format is (encoder-time, batch, 1, 2048).
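Read with the same made-up helper from above:

print(full_batch_shape((None, 1, 2048), batch_dim_axis=1))
# (None, None, 1, 2048), i.e. (encoder-time, batch, 1, 2048)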

This output format information (via debug_print_layer_output_template) is built at construction time. There are a few checks to enforce it, so normally you can rely on it.
There is also debug_print_layer_output_shape, which adds tf.Print ops to print the actual shapes at runtime. While debug_print_layer_output_template can always be enabled, as it doesn't really hurt in any way, debug_print_layer_output_shape adds a lot of noise to every single session.run step.
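For reference, both are plain flags in the config; a minimal sketch (just the debug flags; everything else a real config needs is omitted):

# in the RETURNN config
debug_print_layer_output_template = True  # cheap: prints shape templates once, at construction time
# debug_print_layer_output_shape = True   # noisy: tf.Print of actual shapes in every session.run step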

iankur commented on May 24, 2024

Thanks for taking the time to write such a detailed explanation. It makes everything very clear.
