george0828zhang / torch_cif

A fast parallel PyTorch implementation of the "CIF: Continuous Integrate-and-Fire for End-to-End Speech Recognition" https://arxiv.org/abs/1905.11235.

License: MIT License

torch-cif

A fast, parallel, pure PyTorch implementation of "CIF: Continuous Integrate-and-Fire for End-to-End Speech Recognition" (https://arxiv.org/abs/1905.11235).

Installation

PyPI

pip install torch-cif

Locally

git clone https://github.com/George0828Zhang/torch_cif
cd torch_cif
python setup.py install

Usage

from typing import Dict, List, Optional

from torch import Tensor


def cif_function(
    inputs: Tensor,
    alpha: Tensor,
    beta: float = 1.0,
    tail_thres: float = 0.5,
    padding_mask: Optional[Tensor] = None,
    target_lengths: Optional[Tensor] = None,
    eps: float = 1e-4,
    unbound_alpha: bool = False
) -> Dict[str, List[Tensor]]:
    r""" A fast parallel implementation of continuous integrate-and-fire (CIF)
    https://arxiv.org/abs/1905.11235

    Shapes:
        N: batch size
        S: source (encoder) sequence length
        C: source feature dimension
        T: target sequence length

    Args:
        inputs (Tensor): (N, S, C) Input features to be integrated.
        alpha (Tensor): (N, S) Weights corresponding to each element in the
            inputs. Expected to be the output of a sigmoid function.
        beta (float): the threshold used to determine firing.
        tail_thres (float): the threshold used to determine firing during tail handling.
        padding_mask (Tensor, optional): (N, S) A binary mask representing
            padded elements in the inputs. 1 is padding, 0 is not.
        target_lengths (Tensor, optional): (N,) Desired length of the targets
            for each sample in the minibatch.
        eps (float, optional): Epsilon to prevent underflow for divisions.
            Default: 1e-4
        unbound_alpha (bool, optional): Whether to check if 0 <= alpha <= 1.

    Returns -> Dict[str, List[Tensor]]: Key/values described below.
        cif_out: (N, T, C) The output integrated from the source.
        cif_lengths: (N,) The output length for each element in batch.
        alpha_sum: (N,) The sum of alpha for each element in batch.
            Can be used to compute the quantity loss.
        delays: (N, T) The expected delay (in terms of source tokens) for
            each target token in the batch.
        tail_weights: (N,) During inference, return the tail.
        scaled_alpha: (N, S) alpha after applying weight scaling.
        cumsum_alpha: (N, S) cumsum of alpha after scaling.
        right_indices: (N, S) right scatter indices, or floor(cumsum(alpha)).
        right_weights: (N, S) right scatter weights.
        left_indices: (N, S) left scatter indices.
        left_weights: (N, S) left scatter weights.
    """
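
As a rough illustration of how alpha_sum and target_lengths might be used together, the sketch below computes a quantity loss as the absolute difference between the summed weights and the target length, following the definition in [1], and rescales one sample's weights to a desired length. The helper names quantity_loss and scale_alpha, the batch averaging, the sample values, and the assumption that scaling makes the weights sum to beta * target_length are illustrative choices and are not taken from this package. Plain Python lists stand in for tensors.

```python
from typing import List


def quantity_loss(alpha_sum: List[float], target_lengths: List[int]) -> float:
    """Mean over the batch of |sum of alpha - target length| (averaging is an assumption)."""
    diffs = [abs(a - t) for a, t in zip(alpha_sum, target_lengths)]
    return sum(diffs) / len(diffs)


def scale_alpha(alpha: List[float], target_length: int, beta: float = 1.0) -> List[float]:
    """Rescale one sample's weights so that they sum to beta * target_length.

    Hypothetical helper: it mirrors the scaling strategy described in [1],
    with beta folded in as an assumption.
    """
    total = sum(alpha)
    return [a * beta * target_length / total for a in alpha]


# Hypothetical sigmoid outputs for one sample; they sum to about 3.0.
alpha = [0.2, 0.9, 0.6, 0.6, 0.7]

# Scaling to a target length of 6 doubles every weight, giving roughly
# (0.4, 1.8, 1.2, 1.2, 1.4), so individual scaled weights may exceed 1.
scaled = scale_alpha(alpha, target_length=6)

# Two samples with alpha_sum 3.0 and 5.5 against target lengths 6 and 5.
loss = quantity_loss([3.0, 5.5], [6, 5])
```

The loss is computed from the unscaled sum, so it still reflects how far the predicted weights are from the target length even when the scaled weights are used for integration.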

Note

  • This implementation uses cumsum and floor to determine the firing positions, and uses scatter to merge the weighted source features. The figure below demonstrates this concept using the scaled weight sequence (0.4, 1.8, 1.2, 1.2, 1.4).

[Figure: drawing of the cumsum, floor, and scatter steps for the weight sequence above]

  • Running the tests requires pip install hypothesis expecttest.
  • If beta != 1, our implementation differs slightly from Algorithm 1 in the paper [1]:
    • When a boundary is located, the original algorithm adds the last feature to the current integration with weight 1 - accumulation (line 11 in Algorithm 1), which causes negative weights in the next integration when alpha < 1 - accumulation.
    • We use beta - accumulation instead, so the weight carried into the next integration, alpha - (beta - accumulation), is never negative.
  • Feel free to contact me if there are bugs in the code.
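
The notes above can be checked against each other with a small sketch. It builds a table weights[t][s], the share of source step s that goes into target t, in two ways: a step-by-step loop that closes each target with weight beta - accumulation, and a version that derives the same table from cumsum and floor. This is not the package's code: the function names fire_sequential and fire_cumsum are hypothetical, Python lists replace tensors, explicit loops replace scatter, and exact fractions are used so that rounding does not blur the comparison.

```python
import math
from fractions import Fraction
from itertools import accumulate


def fire_sequential(alpha, beta=1):
    """Step-by-step integrate-and-fire; the last row is the unfinished tail."""
    rows = [[0] * len(alpha)]
    acc = 0  # weight accumulated in the currently open target
    for s, a in enumerate(alpha):
        remaining = a
        while acc + remaining >= beta:
            used = beta - acc          # "beta - accumulation" closes the target
            rows[-1][s] += used
            remaining -= used          # never negative, since acc + remaining >= beta
            acc = 0
            rows.append([0] * len(alpha))
        rows[-1][s] += remaining       # the rest opens (or extends) the next target
        acc += remaining
    return rows


def fire_cumsum(alpha, beta=1):
    """Same table, derived from the cumulative sum and its floor."""
    csum = list(accumulate(alpha))
    right = [math.floor(c / beta) for c in csum]  # target index after step s
    left = [0] + right[:-1]                       # target index before step s
    rows = [[0] * len(alpha) for _ in range(right[-1] + 1)]
    for s, a in enumerate(alpha):
        if left[s] == right[s]:                   # no boundary crossed at this step
            rows[left[s]][s] += a
            continue
        right_w = csum[s] - right[s] * beta       # part past the last boundary
        extra = right[s] - left[s] - 1            # targets filled entirely by step s
        left_w = a - right_w - extra * beta       # part that closes the open target
        rows[left[s]][s] += left_w
        for t in range(left[s] + 1, right[s]):
            rows[t][s] += beta
        rows[right[s]][s] += right_w
    return rows


# The scaled weight sequence from the note above, as exact fractions.
weights = [Fraction(w) for w in ("0.4", "1.8", "1.2", "1.2", "1.4")]
seq_rows = fire_sequential(weights)
par_rows = fire_cumsum(weights)
assert seq_rows == par_rows
```

Each completed row sums to beta, and multiplying a row by the source features and summing over s gives the weighted merge that scatter performs in a single batched operation.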

References

  1. CIF: Continuous Integrate-and-Fire for End-to-End Speech Recognition
  2. Exploring Continuous Integrate-and-Fire for Adaptive Simultaneous Speech Translation
