FRITO: FREQUENCY-REGULARIZED TRANSFORMER

This is the implementation of the paper "Improving Domain Generalization for Sound Classification with Sparse Frequency-Regularized Transformer".

We propose FRITO, an effective regularization technique for the Transformer's self-attention that improves the model's generalization ability by limiting the attention receptive field of each sequence position along the frequency dimension of the spectrogram. Experiments show that our method helps Transformer models achieve SOTA generalization performance on the TAU 2020 and Nsynth datasets while reducing inference time by 20%.

Our scheme includes local and global attention, similar to efficient Transformers. It differs from previous work in that it restricts the receptive field along the frequency dimension to improve the model's generalization ability, rather than to handle long sequences.
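The frequency restriction can be pictured as a boolean attention mask. The sketch below is a minimal illustration, not the repo's implementation: it assumes one global token (for example a classification token) placed before the spectrogram patches, that patches are flattened row by row, and that frequency rows are grouped into clusters of r consecutive rows, as described for per_row_{r}. The function name frequency_band_mask and this token layout are hypothetical.

```python
from typing import List


def frequency_band_mask(n_freq: int, n_time: int, r: int, n_global: int = 1) -> List[List[bool]]:
    """Boolean attention mask, True where query token q may attend to key token k.

    Hypothetical layout: n_global global tokens first, then n_freq * n_time
    spectrogram patches flattened row-major, so patch p lies in frequency row p // n_time.
    """
    if r <= 0 or n_time <= 0:
        raise ValueError("r and n_time must be positive")
    n = n_global + n_freq * n_time

    def cluster(token: int) -> int:
        # Frequency rows are grouped into clusters of r consecutive rows.
        return ((token - n_global) // n_time) // r

    mask = []
    for q in range(n):
        row = []
        for k in range(n):
            if q < n_global or k < n_global:
                row.append(True)  # global attention: no frequency restriction
            else:
                row.append(cluster(q) == cluster(k))  # local: same frequency cluster only
        mask.append(row)
    return mask


if __name__ == "__main__":
    m = frequency_band_mask(n_freq=4, n_time=3, r=2)
    # A patch in frequency row 0 sees the global token plus the 6 patches in rows 0-1.
    print(sum(m[1]))  # 7
```

Masking rather than reshaping keeps the sketch independent of any particular attention kernel; a real model would typically turn such a mask into additive -inf biases on the attention scores.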

Setting up the experiment environment

This repo uses forked versions of sacred for configuration and logging, and pytorch-lightning for training.

For setting up the environment, Mamba is recommended since it is faster than conda:

conda install mamba -n base -c conda-forge

You can now create the environment from environment.yml:

mamba env create -f environment.yml

This creates an environment named ba3l. Next, activate it and install the forked versions of ba3l, pytorch-lightning and sacred:

# dependencies
conda activate ba3l
pip install -e 'git+https://github.com/kkoutini/[email protected]#egg=ba3l'
pip install -e 'git+https://github.com/kkoutini/[email protected]#egg=pytorch-lightning'
pip install -e 'git+https://github.com/kkoutini/[email protected]#egg=sacred' 
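To confirm that the installs worked, a small check like the following can be run inside the activated ba3l environment. It is a sketch that assumes the installed distribution names match the #egg= fragments in the commands above; if a fork registers under a different name, REQUIRED would need adjusting.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional

# Assumption: distribution names equal the #egg= fragments used in the pip commands.
REQUIRED = ("ba3l", "pytorch-lightning", "sacred")


def installed_versions(names: Iterable[str] = REQUIRED) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if it is missing."""
    versions: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'MISSING'}")
```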

Getting started

Each dataset has an experiment file, such as ex_dcase20.py or ex_nsynth.py, and a dataset folder with a readme file. In general, you can probe the experiment file for help:

python ex_dcase20.py help

You can override any configuration option using the sacred syntax. To see the available options, either use Omniboard or run:

python ex_dcase20.py print_config

Many predefined configurations are provided in config_updates.py, covering different models, setups, etc. You can list these named configurations with:

python ex_dcase20.py print_named_configs
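When scripting many runs, it can help to assemble these command lines in Python. The helper build_command below is hypothetical and only reproduces the argument order used throughout this README (script, an optional command such as print_config, then with, named configs, key=value updates, and trailing flags). It does not know which named configs or options actually exist.

```python
import shlex
import sys
from typing import Dict, Iterable, List, Optional


def build_command(
    script: str,
    command: Optional[str] = None,
    named_configs: Iterable[str] = (),
    updates: Optional[Dict[str, object]] = None,
    flags: Iterable[str] = ("-p",),
) -> List[str]:
    """Return argv for: python <script> [command] [with <named configs> <key=value>...] [flags]."""
    argv = [sys.executable, script]
    if command:
        argv.append(command)  # e.g. print_config or print_named_configs
    named = list(named_configs)
    pairs = [f"{key}={value}" for key, value in (updates or {}).items()]
    if named or pairs:
        argv += ["with", *named, *pairs]
    argv += list(flags)
    return argv


if __name__ == "__main__":
    cmd = build_command(
        "ex_nsynth.py",
        updates={"models.net.rf_norm_t": "per_row_6", "trainer.use_tensorboard_logger": True},
        flags=("-p", "--debug"),
    )
    print(shlex.join(cmd))
```

The resulting list can be passed to subprocess.run(cmd, check=True) to launch the run.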

Many options can be updated from the command line. In short:

  • All configuration options under trainer correspond to the PyTorch Lightning Trainer API. For example, to turn off cuDNN benchmarking, add trainer.benchmark=False to the command line.
  • models.net holds the FRITO (or other chosen network) options.
  • models.mel holds the preprocessing options (mel spectrograms).
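The dotted names above address a nested configuration tree (trainer, models.net, models.mel). As a simplified model of how a with key=value update lands in that tree (sacred's own parser is more elaborate), the sketch below walks the dotted path, creating intermediate levels as needed, and reads the value as a Python literal when possible, falling back to a plain string. apply_overrides and parse_value are hypothetical helpers.

```python
import ast
from typing import Any, Dict, Iterable


def parse_value(text: str) -> Any:
    """Read a value as a Python literal (True, 0.3, 8, ...) or fall back to the raw string."""
    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError):
        return text


def apply_overrides(config: Dict[str, Any], overrides: Iterable[str]) -> Dict[str, Any]:
    """Apply dotted 'a.b.c=value' overrides to a nested dict in place and return it."""
    for item in overrides:
        key, sep, raw = item.partition("=")
        if not sep:
            raise ValueError(f"expected key=value, got {item!r}")
        *parents, leaf = key.split(".")
        node = config
        for part in parents:
            node = node.setdefault(part, {})  # walk or create trainer, models, models.net, ...
        node[leaf] = parse_value(raw)
    return config


if __name__ == "__main__":
    cfg = apply_overrides({}, ["trainer.benchmark=False", "models.net.rf_norm_t=per_row_6"])
    print(cfg)
```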

Training on dcase20

Download and prepare the dataset as explained in the TAU2020 dataset readme. The base FRITO model can be trained using:

python ex_dcase20_dev.py with models.net.rf_norm_t=row_overlap_8 use_mixup=True mixup_alpha=0.3 trainer.use_tensorboard_logger=True -p --debug

where models.net.rf_norm_t can be set to row_overlap_{v} or per_row_{r}, which correspond to the overlap factor $v$ and the row cluster size $r$ in the paper, respectively.
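The rf_norm_t values follow a name_number pattern. A minimal, hypothetical parser that separates the mode from its parameter (v for row_overlap, r for per_row) might look like this. It rejects anything else, including named presets such as high_low that appear elsewhere in this README, and it does not check which parameter values the repo accepts.

```python
import re
from typing import Tuple

# Only the parameterised forms described in this README; named presets are not handled.
_RF_NORM_T = re.compile(r"(row_overlap|per_row)_(\d+)")


def parse_rf_norm_t(value: str) -> Tuple[str, int]:
    """Split 'row_overlap_8' into ('row_overlap', 8) and 'per_row_6' into ('per_row', 6)."""
    match = _RF_NORM_T.fullmatch(value)
    if match is None:
        raise ValueError(f"unsupported rf_norm_t value: {value!r}")
    return match.group(1), int(match.group(2))


if __name__ == "__main__":
    print(parse_rf_norm_t("row_overlap_8"))
```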

Multi-GPU training can be enabled by setting the environment variable DDP, for example, with 2 GPUs:

DDP=2 CUDA_VISIBLE_DEVICES=0,1 python ex_dcase20.py with models.net.rf_norm_t=high_low_branch trainer.use_tensorboard_logger=True -p --debug
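The same launch can be driven from Python. This sketch assumes, as in the example above, that DDP is set to the number of GPUs listed in CUDA_VISIBLE_DEVICES; ddp_env and launch are hypothetical helpers, not part of the repo.

```python
import os
import subprocess
import sys
from typing import Dict, Mapping, Optional, Sequence


def ddp_env(gpu_ids: Sequence[int], base: Optional[Mapping[str, str]] = None) -> Dict[str, str]:
    """Copy of the environment with DDP=<number of GPUs> and CUDA_VISIBLE_DEVICES=<ids>."""
    if not gpu_ids:
        raise ValueError("at least one GPU id is required")
    env = dict(os.environ if base is None else base)
    env["DDP"] = str(len(gpu_ids))  # assumption: one process per listed GPU
    env["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in gpu_ids)
    return env


def launch(script_args: Sequence[str], gpu_ids: Sequence[int]) -> None:
    """Run `python <script_args...>` with the multi-GPU environment variables set."""
    subprocess.run([sys.executable, *script_args], env=ddp_env(gpu_ids), check=True)
```

For example, launch(["ex_dcase20.py", "with", "models.net.rf_norm_t=high_low_branch", "-p", "--debug"], gpu_ids=[0, 1]) mirrors the shell command above.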

Training on nsynth

Download and prepare the dataset as explained in the nsynth dataset readme. The base FRITO model can be trained using:

python ex_nsynth.py with models.net.rf_norm_t=high_low trainer.use_tensorboard_logger=True -p --debug

which is equivalent to:

python ex_nsynth.py with models.net.rf_norm_t=per_row_6 trainer.use_tensorboard_logger=True -p --debug

where models.net.rf_norm_t can be set to row_overlap_{v} or per_row_{r}, which correspond to the overlap factor $v$ and the row cluster size $r$ in the paper, respectively.
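The two families differ in which frequency rows each row may attend to. The sketch below contrasts them at the level of frequency rows. per_row_rows follows the cluster description (rows grouped into clusters of r); row_overlap_rows is only one plausible reading of the overlap factor v, namely a sliding band of half-width v, and the paper's exact definition may differ. All function names are hypothetical.

```python
from typing import List


def per_row_rows(n_freq: int, r: int) -> List[List[bool]]:
    """Row-level mask for clusters of r rows: row i sees row j iff i // r == j // r."""
    return [[i // r == j // r for j in range(n_freq)] for i in range(n_freq)]


def row_overlap_rows(n_freq: int, v: int) -> List[List[bool]]:
    """Hypothetical reading of row_overlap_{v}: row i sees every row j with |i - j| <= v."""
    return [[abs(i - j) <= v for j in range(n_freq)] for i in range(n_freq)]


def receptive_field(mask: List[List[bool]]) -> List[int]:
    """Number of frequency rows each row may attend to."""
    return [sum(row) for row in mask]


if __name__ == "__main__":
    print(receptive_field(per_row_rows(6, 2)))
    print(receptive_field(row_overlap_rows(6, 1)))
```

Under this reading, clusters have hard boundaries (a row at the edge of a cluster cannot see its neighbour across the boundary), while the band lets every row see its immediate neighbours.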

Sparse Attention

We also implement models.net.rf_norm_t=sparse_row_{r} as the sparse version of models.net.rf_norm_t=per_row_{r}, for example:

python ex_nsynth.py with models.net.rf_norm_t=sparse_row_6 trainer.use_tensorboard_logger=True -p --debug
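A sparse variant can skip work because attention restricted to frequency clusters is block-diagonal: scores between tokens in different clusters are masked out anyway, so they need not be computed. The pure-Python sketch below is not the repo's sparse_row implementation. It checks that computing attention cluster by cluster gives the same output as dense attention with a mask, and counts the scores each approach computes. Cluster ids are assigned per token here for brevity; in FRITO they would come from the frequency row of each patch.

```python
import math
import random
from typing import Dict, List, Sequence, Tuple

Vec = List[float]


def _dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def _attend(scores: Sequence[float], values: Sequence[Vec]) -> Vec:
    """Softmax over scores (-inf entries get zero weight), then a weighted sum of values."""
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(len(values[0]))]


def dense_masked_attention(qs: Sequence[Vec], ks: Sequence[Vec], vs: Sequence[Vec],
                           clusters: Sequence[int]) -> List[Vec]:
    """Reference: compute all n * n scores, then mask pairs from different clusters."""
    scale = 1.0 / math.sqrt(len(qs[0]))
    out = []
    for i, q in enumerate(qs):
        scores = [scale * _dot(q, k) for k in ks]
        masked = [s if clusters[j] == clusters[i] else -math.inf for j, s in enumerate(scores)]
        out.append(_attend(masked, vs))
    return out


def block_sparse_attention(qs: Sequence[Vec], ks: Sequence[Vec], vs: Sequence[Vec],
                           clusters: Sequence[int]) -> List[Vec]:
    """Sparse variant: only compute scores between tokens of the same cluster."""
    scale = 1.0 / math.sqrt(len(qs[0]))
    groups: Dict[int, List[int]] = {}
    for idx, c in enumerate(clusters):
        groups.setdefault(c, []).append(idx)
    out: List[Vec] = [[] for _ in qs]
    for members in groups.values():
        for i in members:
            scores = [scale * _dot(qs[i], ks[j]) for j in members]
            out[i] = _attend(scores, [vs[j] for j in members])
    return out


def score_counts(clusters: Sequence[int]) -> Tuple[int, int]:
    """Number of query-key scores computed by the (dense, sparse) versions."""
    sizes: Dict[int, int] = {}
    for c in clusters:
        sizes[c] = sizes.get(c, 0) + 1
    return len(clusters) ** 2, sum(s * s for s in sizes.values())


def demo(n_tokens: int = 6, dim: int = 4, tokens_per_cluster: int = 2, seed: int = 0):
    """Toy random tokens; consecutive groups of tokens_per_cluster share a cluster id."""
    rng = random.Random(seed)

    def rand_matrix() -> List[Vec]:
        return [[rng.uniform(-1.0, 1.0) for _ in range(dim)] for _ in range(n_tokens)]

    qs, ks, vs = rand_matrix(), rand_matrix(), rand_matrix()
    clusters = [i // tokens_per_cluster for i in range(n_tokens)]
    dense = dense_masked_attention(qs, ks, vs, clusters)
    sparse = block_sparse_attention(qs, ks, vs, clusters)
    return dense, sparse, clusters


if __name__ == "__main__":
    dense, sparse, clusters = demo()
    print(score_counts(clusters))  # (36, 12)
```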

Contact

The repo will be updated. In the meantime, if you have any questions or problems, feel free to open an issue on GitHub or contact the authors directly.
