transducer

A fast RNN-Transducer implementation on the CPU and GPU (CUDA) with Python bindings and a PyTorch extension. The RNN-T loss function was introduced in Sequence Transduction with Recurrent Neural Networks.

The code has been tested with Python 3.9 and PyTorch 1.9.

Install and Test

To install from the top level of the repo run:

python setup.py install

To use the PyTorch extension, install PyTorch and test with:

python torch_test.py

Usage

The easiest way to use the transducer loss is with the PyTorch bindings:

criterion = transducer.TransducerLoss()
loss = criterion(emissions, predictions, labels, input_lengths, label_lengths)

The loss will run on the same device as the input tensors. For more information, see the criterion documentation.
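For a more complete, self-contained example, here is a minimal sketch. The shapes and dtypes below are assumptions (log-probability emissions from the transcription network of shape (B, T, V), predictions from the prediction network of shape (B, U + 1, V), int32 lengths); check the criterion documentation for the exact interface:

import torch
import transducer

B, T, U, V = 4, 50, 10, 100  # batch, input length, label length, token set size

# Assumed shapes and normalization: log-probabilities over the token set.
emissions = torch.randn(B, T, V).log_softmax(dim=-1).requires_grad_()
predictions = torch.randn(B, U + 1, V).log_softmax(dim=-1).requires_grad_()
labels = torch.randint(1, V, (B, U), dtype=torch.int32)
input_lengths = torch.full((B,), T, dtype=torch.int32)
label_lengths = torch.full((B,), U, dtype=torch.int32)

criterion = transducer.TransducerLoss()
loss = criterion(emissions, predictions, labels, input_lengths, label_lengths)
loss.backward()  # gradients flow back to emissions and predictions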

To get the "teacher forced" best path (the predictions come from the prediction network conditioned on the ground-truth labels):

predicted_labels = criterion.viterbi(emissions, predictions, input_lengths, label_lengths)

Memory Use and Benchmarks

This transducer is designed to be much lighter in memory use than most alternatives. Most implementations use memory that scales with the product B * T * U * V, where B is the batch size, T is the maximum input length in the batch, U is the maximum output length in the batch, and V is the token set size. The memory of this implementation scales with the product B * T * U and does not grow with the token set size. This is particularly important for the large token sets commonly used with word pieces. (Note that this implementation does not support a "joiner" network to combine the outputs of the transcription and prediction models; the algorithm hardcodes the assumption that they are combined additively.)
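To make the savings concrete, here is a back-of-the-envelope estimate for the largest benchmark configuration below (float32 lattice activations only; an illustration, not a measurement):

# Rough memory estimate for the T*U lattice at 4 bytes per float32 element.
B, T, U, V = 32, 2000, 100, 10000

full = B * T * U * V * 4  # B*T*U*V scaling of typical implementations
ours = B * T * U * 4      # B*T*U scaling of this implementation

print(f"B*T*U*V: {full / 1e9:.1f} GB")  # 256.0 GB
print(f"B*T*U:   {ours / 1e6:.1f} MB")  # 25.6 MB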

Performance benchmarks for the CUDA version, run on an A100 GPU, are below. We compare against the Torch Audio RNN-T loss on the same GPU. An entry of "OOM" means the implementation ran out of memory (20 GB in this case).

Times are reported in milliseconds.

T=2000, U=100, B=8

V        Transducer    Torch Audio
100            8.18         139.26
1000          13.64            OOM
2000          18.83            OOM
10000         59.18            OOM

T=2000, U=100, B=32

V        Transducer    Torch Audio
100           20.58         555.00
1000          38.42            OOM
2000          58.19            OOM
10000        223.33            OOM
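A hedged sketch of how such numbers might be collected with CUDA events (the shapes mirror the B=32, V=1000 row; this harness is an illustration, not the script used to produce the tables):

import torch
import transducer

B, T, U, V = 32, 2000, 100, 1000
device = "cuda"

# Assumed shapes, as in the usage sketch above.
emissions = torch.randn(B, T, V, device=device).log_softmax(dim=-1)
predictions = torch.randn(B, U + 1, V, device=device).log_softmax(dim=-1)
labels = torch.randint(1, V, (B, U), dtype=torch.int32, device=device)
input_lengths = torch.full((B,), T, dtype=torch.int32, device=device)
label_lengths = torch.full((B,), U, dtype=torch.int32, device=device)

criterion = transducer.TransducerLoss()
criterion(emissions, predictions, labels, input_lengths, label_lengths)  # warm-up

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(10):
    criterion(emissions, predictions, labels, input_lengths, label_lengths)
end.record()
torch.cuda.synchronize()
print(f"{start.elapsed_time(end) / 10:.2f} ms per forward call")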
