DAC-JAX

Descript Audio Codec (.dac) is a high-fidelity general neural audio codec introduced in the paper
"High-Fidelity Audio Compression with Improved RVQGAN".

This repository is an unofficial JAX implementation of the PyTorch-based DAC and has no affiliation with Descript.

You can read the DAC-JAX paper at https://arxiv.org/abs/2405.11554.

Usage

Installation

  1. Upgrade pip and setuptools:

    pip install --upgrade pip setuptools
  2. Install the CPU version of PyTorch. We strongly suggest the CPU version because trying to install a GPU version can conflict with JAX's CUDA-related installation.
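
    For example (check pytorch.org for the current CPU-only install command, which may change):

    pip install torch --index-url https://download.pytorch.org/whl/cpu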

  3. Install the CPU version of TensorFlow:

    pip install tensorflow-cpu
  4. Install JAX (with GPU support).
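
    For example, for CUDA 12 (see the JAX installation guide for the current command):

    pip install -U "jax[cuda12]"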

  5. Install DAC-JAX with one of the following:

    pip install git+https://github.com/DBraun/DAC-JAX
    

    Or, from a clone of the repository:

    python -m pip install .

    Or, if you intend to contribute, clone and do an editable install:

    python -m pip install -e ".[dev]"

Weights

The original Descript repository releases model weights under the MIT license. These weights are for models that natively support 16 kHz, 24 kHz, and 44.1 kHz sampling rates. Our scripts download these PyTorch weights and load them into JAX. Weights are automatically downloaded the first time you run an encode or decode command. You can download them in advance with one of the following commands:

python -m dac_jax download_model # downloads the default 44kHz variant
python -m dac_jax download_model --model_type 44khz --model_bitrate 16kbps # downloads the 44kHz 16 kbps variant
python -m dac_jax download_model --model_type 44khz # downloads the 44kHz variant
python -m dac_jax download_model --model_type 24khz # downloads the 24kHz variant
python -m dac_jax download_model --model_type 16khz # downloads the 16kHz variant

The default download location is ~/.cache/dac_jax. You can change the location by setting the environment variable DAC_JAX_CACHE to an absolute path. For example, on macOS/Linux:

export DAC_JAX_CACHE=/Users/admin/my-project/dac_jax_models

If you do this, make sure DAC_JAX_CACHE is still set before you use the load_model function.
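
For example, a minimal sketch of setting the cache path from Python (assuming the variable is read when load_model runs):

import os
os.environ["DAC_JAX_CACHE"] = "/Users/admin/my-project/dac_jax_models"  # must be an absolute path

import dac_jax
model, variables = dac_jax.load_model(model_type="44khz")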

Compress audio

python -m dac_jax encode /path/to/input --output /path/to/output/codes

This command will create .dac files with the same names as the input files. It will also preserve the directory structure relative to the input root and re-create it in the output directory, as illustrated below. Please use python -m dac_jax encode --help for more options.
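
For example, a hypothetical input tree would be mirrored like this:

/path/to/input/song.wav        ->  /path/to/output/codes/song.dac
/path/to/input/stems/bass.wav  ->  /path/to/output/codes/stems/bass.dac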

Reconstruct audio from compressed codes

python -m dac_jax decode /path/to/output/codes --output /path/to/reconstructed_input

This command will create .wav files with the same names as the input files. It will also preserve the directory structure relative to the input root and re-create it in the output directory. Please use python -m dac_jax decode --help for more options.

Programmatic usage

Here we use DAC-JAX as a "bound" Flax module, which frees us from repeatedly passing variables as an argument and calling .apply. Note that bound modules are not meant to be used for fine-tuning.

import dac_jax
from dac_jax.audio_utils import volume_norm, db2linear

import jax.numpy as jnp
import librosa

# Download a model and bind variables to it.
model, variables = dac_jax.load_model(model_type="44khz")
model = model.bind(variables)

# Load a mono audio file
signal, sample_rate = librosa.load('input.wav', sr=44100, mono=True, duration=.5)

signal = jnp.array(signal, dtype=jnp.float32)
while signal.ndim < 3:
    signal = jnp.expand_dims(signal, axis=0)

target_db = -16  # Normalize audio to -16 dB
x, input_db = volume_norm(signal, target_db, sample_rate)

# Encode audio signal as one long file (may run out of GPU memory on long files)
x = model.preprocess(x, sample_rate)
z, codes, latents, commitment_loss, codebook_loss = model.encode(x, train=False)

# Decode audio signal
y = model.decode(z, length=signal.shape[-1])

# Undo previous loudness normalization
y = y * db2linear(input_db - target_db)

# Calculate the mean squared error of the reconstruction in the time domain
mse = jnp.square(y - signal).mean()
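
For reference, db2linear converts decibels to an amplitude ratio; a minimal sketch, assuming the standard 20*log10 amplitude convention (the real implementation lives in dac_jax.audio_utils):

import jax.numpy as jnp

def db2linear_sketch(db):
    # dB -> amplitude ratio under the 20*log10 convention (assumed to match dac_jax's db2linear)
    return jnp.power(10.0, db / 20.0)

# Multiplying y by db2linear(input_db - target_db) rescales the decoded audio
# from the normalized level (target_db) back to its original level (input_db).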

Compression with constant GPU memory regardless of input length:

import dac_jax

import jax
import jax.numpy as jnp
import librosa

# Download a model and set padding to False because we will use the jit-compiled chunk functions.
model, variables = dac_jax.load_model(model_type="44khz", padding=False)

# Load a mono audio file
signal, sample_rate = librosa.load('input.wav', sr=44100, mono=True, duration=.5)

signal = jnp.array(signal, dtype=jnp.float32)
while signal.ndim < 3:
    signal = jnp.expand_dims(signal, axis=0)

# Jit-compile these functions because they're used inside a loop over chunks.
@jax.jit
def compress_chunk(x):
    return model.apply(variables, x, method='compress_chunk')

@jax.jit
def decompress_chunk(c):
    return model.apply(variables, c, method='decompress_chunk')

win_duration = 0.5  # Adjust based on your GPU's memory size
dac_file = model.compress(compress_chunk, signal, sample_rate, win_duration=win_duration)

# Save and load to and from disk
dac_file.save("compressed.dac")
dac_file = dac_jax.DACFile.load("compressed.dac")

# Decompress it back to audio
y = model.decompress(decompress_chunk, dac_file)
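
To audition the reconstruction, you can write y to disk. Below is a minimal sketch using the third-party soundfile package (an assumption, not a DAC-JAX dependency); 44100 matches the 44khz model:

import numpy as np
import soundfile as sf

# y still has leading batch/channel dimensions; squeeze a mono signal down to 1-D.
sf.write("reconstructed.wav", np.asarray(y).squeeze(), samplerate=44100)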

Training

The baseline model configuration can be trained with the following command:

python scripts/train.py --args.load conf/base.yml

From the root directory, monitor training with TensorBoard (runs will appear next to the scripts directory):

tensorboard --logdir ./runs

Testing

python -m pytest tests

Limitations

Pull requests—especially ones which address any of the limitations below—are welcome.

  • PyTorch is required as a dependency because it's used to load the model weights before converting them to JAX. As long as this is required, we recommend installing the CPU version of PyTorch.
  • We hope to remove audiotools as a dependency since it's a PyTorch library.
  • We implement the "chunked" compress/decompress methods from the PyTorch repository, although this technique has some problems outlined here.
  • We have not performed a full training run, although we have tested the training scripts, loss functions, and evaluation metrics.
  • If you want to perform a training run—especially one which reproduces the original paper—you should examine any "todo" in train.py or in the config files.
  • We have not run all evaluation scripts in the scripts directory.
  • The model architecture code (model/dac.py) has many static methods to help with finding DAC's delay and output_length. Please help us refactor this so that the code is less duplicated and less prone to typos.
  • In audio_utils.py we use DM_AUX's STFT function instead of jax.scipy.signal.stft.
  • The DAC-JAX source code has several todo: markers indicating (mostly minor) improvements we'd like to make.
  • We don't have a Docker image yet like the original DAC repository does.
  • Please check the limitations of argbind.

Citation

If you use DAC-JAX in your work, please cite the original DAC:

@article{kumar2024high,
  title={High-Fidelity Audio Compression with Improved {RVQGAN}},
  author={Kumar, Rithesh and Seetharaman, Prem and Luebs, Alejandro and Kumar, Ishaan and Kumar, Kundan},
  journal={Advances in Neural Information Processing Systems},
  volume={36},
  year={2024}
}

and DAC-JAX:

@misc{braun2024dacjax,
  title={{DAC-JAX}: A {JAX} Implementation of the Descript Audio Codec}, 
  author={David Braun},
  year={2024},
  eprint={2405.11554},
  archivePrefix={arXiv},
  primaryClass={cs.SD}
}
