
mqtts's Introduction

MQTTS

  • Official implementation for the paper: A Vector Quantized Approach for Text to Speech Synthesis on Real-World Spontaneous Speech.
  • Audio samples (40 per system) can be accessed here.
  • A quick demo can be accessed here (some parts are still TODO).
  • The paper appendix is here.

Setup the environment

  1. Setup conda environment:
conda create --name mqtts python=3.9
conda activate mqtts
conda install pytorch==1.10.1 torchvision==0.11.2 torchaudio==0.10.1 cudatoolkit=11.3 -c pytorch -c conda-forge
pip install -r requirements.txt

(Update) You may need to create an access token to use the pyannote speaker embedding, as pyannote updated their access policy. If that's the case, follow the pyannote repo and change every Inference("pyannote/embedding", window="whole") call accordingly.
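
As a minimal sketch (assuming a recent pyannote.audio release where the pretrained model is gated behind a Hugging Face access token; the token string below is a placeholder), the change typically looks like:

from pyannote.audio import Inference, Model

# Assumption: load the gated model with your Hugging Face access token first,
# then wrap it in Inference instead of passing the model name directly.
model = Model.from_pretrained("pyannote/embedding", use_auth_token="YOUR_HF_TOKEN")
emb_inference = Inference(model, window="whole")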

  2. Download the pretrained phonemizer checkpoint
wget https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/DeepPhonemizer/en_us_cmudict_forward.pt

Preprocess the dataset

  1. Get the GigaSpeech dataset from the official repo
  2. Install FFmpeg, then
conda install ffmpeg=4.3=hf484d3e_0
conda update ffmpeg
  3. Run the preprocessing script
python preprocess.py --giga_speech_dir GIGASPEECH --outputdir datasets 

Train the quantizer and inference

  1. Train
cd quantizer/
python train.py --input_wavs_dir ../datasets/audios \
                --input_training_file ../datasets/training.txt \
                --input_validation_file ../datasets/validation.txt \
                --checkpoint_path ./checkpoints \
                --config config.json
  2. Run inference to get codes for training the second stage
python get_labels.py --input_json ../datasets/train.json \
                     --input_wav_dir ../datasets/audios \
                     --output_json ../datasets/train_q.json \
                     --checkpoint_file ./checkpoints/g_{training_steps}
python get_labels.py --input_json ../datasets/dev.json \
                     --input_wav_dir ../datasets/audios \
                     --output_json ../datasets/dev_q.json \
                     --checkpoint_file ./checkpoints/g_{training_steps}

Train the transformer (below is an example for the 100M version)

cd ..
mkdir ckpt
python train.py \
     --distributed \
     --saving_path ckpt/ \
     --sampledir logs/ \
     --vocoder_config_path quantizer/checkpoints/config.json \
     --vocoder_ckpt_path quantizer/checkpoints/g_{training_steps} \
     --datadir datasets/audios \
     --metapath datasets/train_q.json \
     --val_metapath datasets/dev_q.json \
     --use_repetition_token \
     --ar_layer 4 \
     --ar_ffd_size 1024 \
     --ar_hidden_size 256 \
     --ar_nheads 4 \
     --speaker_embed_dropout 0.05 \
     --enc_nlayers 6 \
     --dec_nlayers 6 \
     --ffd_size 3072 \
     --hidden_size 768 \
     --nheads 12 \
     --batch_size 200 \
     --precision bf16 \
     --training_step 800000 \
     --layer_norm_eps 1e-05

You can view the progress using:

tensorboard --logdir logs/

Run batched inference

You'll have to change speaker_to_text.json; the provided file is just an example.
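
The exact schema is defined by the example file shipped with the repo; purely as a hypothetical illustration (the paths and structure below are assumptions, not the actual format), a mapping from a reference speaker wav to the sentences to synthesize could be generated like this:

import json

# Hypothetical layout: reference speaker wav -> list of sentences to synthesize.
# Check the bundled speaker_to_text.json for the real schema before using this.
# Trailing spaces mirror the provided example, which pads sentences with extra spaces.
speaker_to_text = {
    "datasets/audios/example_speaker.wav": [
        "Hello, this is a test sentence. ",
        "Here is another sentence to synthesize. ",
    ],
}

with open("speaker_to_text.json", "w") as f:
    json.dump(speaker_to_text, f, indent=2)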

mkdir infer_samples
CUDA_VISIBLE_DEVICES=0 python infer.py \
    --phonemizer_dict_path en_us_cmudict_forward.pt \
    --model_path ckpt/last.ckpt \
    --config_path ckpt/config.json \
    --input_path speaker_to_text.json \
    --outputdir infer_samples \
    --batch_size {batch_size} \
    --top_p 0.8 \
    --min_top_k 2 \
    --max_output_length {Maximum Output Frames to prevent infinite loop} \
    --phone_context_window 3 \
    --clean_speech_prior

Pretrained checkpoints

  1. Quantizer (put it under quantizer/checkpoints/): here

  2. Transformer (100M version) (put it under ckpt/): model, config

mqtts's People

Contributors

b04901014

mqtts's Issues

Error running inference from pre-trained models: inp = torch.cat([spkr, inp], 1) RuntimeError: Tensors must have same number of dimensions: got 4 and 3

Full output from running batch inference with pre-trained models:

_IncompatibleKeys(missing_keys=[], unexpected_keys=['vocoder.quantizer.quantizer_modules.0.embedding.weight', 'vocoder.quantizer.quantizer_modules.1.embedding.weight', 'vocoder.quantizer.quantizer_modules.2.embedding.weight', 'vocoder.quantizer.quantizer_modules.3.embedding.weight', 'vocoder.generator.conv_pre.bias', 'vocoder.generator.conv_pre.weight', 'vocoder.generator.ups.0.bias', 'vocoder.generator.ups.0.weight', 'vocoder.generator.ups.1.bias', 'vocoder.generator.ups.1.weight', 'vocoder.generator.ups.2.bias', 'vocoder.generator.ups.2.weight', 'vocoder.generator.ups.3.bias', 'vocoder.generator.ups.3.weight', 'vocoder.generator.resblocks.0.convs1.0.bias', 'vocoder.generator.resblocks.0.convs1.0.weight', 'vocoder.generator.resblocks.0.convs1.1.bias', 'vocoder.generator.resblocks.0.convs1.1.weight', 'vocoder.generator.resblocks.0.convs1.2.bias', 'vocoder.generator.resblocks.0.convs1.2.weight', 'vocoder.generator.resblocks.0.convs2.0.bias', 'vocoder.generator.resblocks.0.convs2.0.weight', 'vocoder.generator.resblocks.0.convs2.1.bias', 'vocoder.generator.resblocks.0.convs2.1.weight', 'vocoder.generator.resblocks.0.convs2.2.bias', 'vocoder.generator.resblocks.0.convs2.2.weight', 'vocoder.generator.resblocks.1.convs1.0.bias', 'vocoder.generator.resblocks.1.convs1.0.weight', 'vocoder.generator.resblocks.1.convs1.1.bias', 'vocoder.generator.resblocks.1.convs1.1.weight', 'vocoder.generator.resblocks.1.convs1.2.bias', 'vocoder.generator.resblocks.1.convs1.2.weight', 'vocoder.generator.resblocks.1.convs2.0.bias', 'vocoder.generator.resblocks.1.convs2.0.weight', 'vocoder.generator.resblocks.1.convs2.1.bias', 'vocoder.generator.resblocks.1.convs2.1.weight', 'vocoder.generator.resblocks.1.convs2.2.bias', 'vocoder.generator.resblocks.1.convs2.2.weight', 'vocoder.generator.resblocks.2.convs1.0.bias', 'vocoder.generator.resblocks.2.convs1.0.weight', 'vocoder.generator.resblocks.2.convs1.1.bias', 'vocoder.generator.resblocks.2.convs1.1.weight', 'vocoder.generator.resblocks.2.convs1.2.bias', 'vocoder.generator.resblocks.2.convs1.2.weight', 'vocoder.generator.resblocks.2.convs2.0.bias', 'vocoder.generator.resblocks.2.convs2.0.weight', 'vocoder.generator.resblocks.2.convs2.1.bias', 'vocoder.generator.resblocks.2.convs2.1.weight', 'vocoder.generator.resblocks.2.convs2.2.bias', 'vocoder.generator.resblocks.2.convs2.2.weight', 'vocoder.generator.resblocks.3.convs1.0.bias', 'vocoder.generator.resblocks.3.convs1.0.weight', 'vocoder.generator.resblocks.3.convs1.1.bias', 'vocoder.generator.resblocks.3.convs1.1.weight', 'vocoder.generator.resblocks.3.convs1.2.bias', 'vocoder.generator.resblocks.3.convs1.2.weight', 'vocoder.generator.resblocks.3.convs2.0.bias', 'vocoder.generator.resblocks.3.convs2.0.weight', 'vocoder.generator.resblocks.3.convs2.1.bias', 'vocoder.generator.resblocks.3.convs2.1.weight', 'vocoder.generator.resblocks.3.convs2.2.bias', 'vocoder.generator.resblocks.3.convs2.2.weight', 'vocoder.generator.resblocks.4.convs1.0.bias', 'vocoder.generator.resblocks.4.convs1.0.weight', 'vocoder.generator.resblocks.4.convs1.1.bias', 'vocoder.generator.resblocks.4.convs1.1.weight', 'vocoder.generator.resblocks.4.convs1.2.bias', 'vocoder.generator.resblocks.4.convs1.2.weight', 'vocoder.generator.resblocks.4.convs2.0.bias', 'vocoder.generator.resblocks.4.convs2.0.weight', 'vocoder.generator.resblocks.4.convs2.1.bias', 'vocoder.generator.resblocks.4.convs2.1.weight', 'vocoder.generator.resblocks.4.convs2.2.bias', 'vocoder.generator.resblocks.4.convs2.2.weight', 'vocoder.generator.resblocks.5.convs1.0.bias', 
'vocoder.generator.resblocks.5.convs1.0.weight', 'vocoder.generator.resblocks.5.convs1.1.bias', 'vocoder.generator.resblocks.5.convs1.1.weight', 'vocoder.generator.resblocks.5.convs1.2.bias', 'vocoder.generator.resblocks.5.convs1.2.weight', 'vocoder.generator.resblocks.5.convs2.0.bias', 'vocoder.generator.resblocks.5.convs2.0.weight', 'vocoder.generator.resblocks.5.convs2.1.bias', 'vocoder.generator.resblocks.5.convs2.1.weight', 'vocoder.generator.resblocks.5.convs2.2.bias', 'vocoder.generator.resblocks.5.convs2.2.weight', 'vocoder.generator.resblocks.6.convs1.0.bias', 'vocoder.generator.resblocks.6.convs1.0.weight', 'vocoder.generator.resblocks.6.convs1.1.bias', 'vocoder.generator.resblocks.6.convs1.1.weight', 'vocoder.generator.resblocks.6.convs1.2.bias', 'vocoder.generator.resblocks.6.convs1.2.weight', 'vocoder.generator.resblocks.6.convs2.0.bias', 'vocoder.generator.resblocks.6.convs2.0.weight', 'vocoder.generator.resblocks.6.convs2.1.bias', 'vocoder.generator.resblocks.6.convs2.1.weight', 'vocoder.generator.resblocks.6.convs2.2.bias', 'vocoder.generator.resblocks.6.convs2.2.weight', 'vocoder.generator.resblocks.7.convs1.0.bias', 'vocoder.generator.resblocks.7.convs1.0.weight', 'vocoder.generator.resblocks.7.convs1.1.bias', 'vocoder.generator.resblocks.7.convs1.1.weight', 'vocoder.generator.resblocks.7.convs1.2.bias', 'vocoder.generator.resblocks.7.convs1.2.weight', 'vocoder.generator.resblocks.7.convs2.0.bias', 'vocoder.generator.resblocks.7.convs2.0.weight', 'vocoder.generator.resblocks.7.convs2.1.bias', 'vocoder.generator.resblocks.7.convs2.1.weight', 'vocoder.generator.resblocks.7.convs2.2.bias', 'vocoder.generator.resblocks.7.convs2.2.weight', 'vocoder.generator.resblocks.8.convs1.0.bias', 'vocoder.generator.resblocks.8.convs1.0.weight', 'vocoder.generator.resblocks.8.convs1.1.bias', 'vocoder.generator.resblocks.8.convs1.1.weight', 'vocoder.generator.resblocks.8.convs1.2.bias', 'vocoder.generator.resblocks.8.convs1.2.weight', 'vocoder.generator.resblocks.8.convs2.0.bias', 'vocoder.generator.resblocks.8.convs2.0.weight', 'vocoder.generator.resblocks.8.convs2.1.bias', 'vocoder.generator.resblocks.8.convs2.1.weight', 'vocoder.generator.resblocks.8.convs2.2.bias', 'vocoder.generator.resblocks.8.convs2.2.weight', 'vocoder.generator.resblocks.9.convs1.0.bias', 'vocoder.generator.resblocks.9.convs1.0.weight', 'vocoder.generator.resblocks.9.convs1.1.bias', 'vocoder.generator.resblocks.9.convs1.1.weight', 'vocoder.generator.resblocks.9.convs1.2.bias', 'vocoder.generator.resblocks.9.convs1.2.weight', 'vocoder.generator.resblocks.9.convs2.0.bias', 'vocoder.generator.resblocks.9.convs2.0.weight', 'vocoder.generator.resblocks.9.convs2.1.bias', 'vocoder.generator.resblocks.9.convs2.1.weight', 'vocoder.generator.resblocks.9.convs2.2.bias', 'vocoder.generator.resblocks.9.convs2.2.weight', 'vocoder.generator.resblocks.10.convs1.0.bias', 'vocoder.generator.resblocks.10.convs1.0.weight', 'vocoder.generator.resblocks.10.convs1.1.bias', 'vocoder.generator.resblocks.10.convs1.1.weight', 'vocoder.generator.resblocks.10.convs1.2.bias', 'vocoder.generator.resblocks.10.convs1.2.weight', 'vocoder.generator.resblocks.10.convs2.0.bias', 'vocoder.generator.resblocks.10.convs2.0.weight', 'vocoder.generator.resblocks.10.convs2.1.bias', 'vocoder.generator.resblocks.10.convs2.1.weight', 'vocoder.generator.resblocks.10.convs2.2.bias', 'vocoder.generator.resblocks.10.convs2.2.weight', 'vocoder.generator.resblocks.11.convs1.0.bias', 'vocoder.generator.resblocks.11.convs1.0.weight', 
'vocoder.generator.resblocks.11.convs1.1.bias', 'vocoder.generator.resblocks.11.convs1.1.weight', 'vocoder.generator.resblocks.11.convs1.2.bias', 'vocoder.generator.resblocks.11.convs1.2.weight', 'vocoder.generator.resblocks.11.convs2.0.bias', 'vocoder.generator.resblocks.11.convs2.0.weight', 'vocoder.generator.resblocks.11.convs2.1.bias', 'vocoder.generator.resblocks.11.convs2.1.weight', 'vocoder.generator.resblocks.11.convs2.2.bias', 'vocoder.generator.resblocks.11.convs2.2.weight', 'vocoder.generator.conv_post.bias', 'vocoder.generator.conv_post.weight', 'vocoder.generator.spkr_linear.0.weight', 'vocoder.generator.spkr_linear.0.bias', 'vocoder.generator.spkr_linear.2.weight', 'vocoder.generator.spkr_linear.2.bias'])
Removing weight norm...
Removing weight norm...
Inferencing batch 1, total 41 baches.
Traceback (most recent call last):
File "/home/miniconda3/envs/mqtts/lib/python3.9/runpy.py", line 197, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/home/miniconda3/envs/mqtts/lib/python3.9/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/home/.vscode/extensions/ms-python.python-2023.4.1/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/main.py", line 39, in
cli.main()
File "/home/.vscode/extensions/ms-python.python-2023.4.1/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 430, in main
run()
File "/home/.vscode/extensions/ms-python.python-2023.4.1/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 284, in run_file
runpy.run_path(target, run_name="main")
File "/home/.vscode/extensions/ms-python.python-2023.4.1/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 321, in run_path
return _run_module_code(code, init_globals, run_name,
File "/home/.vscode/extensions/ms-python.python-2023.4.1/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code
_run_code(code, mod_globals, init_globals,
File "/home/.vscode/extensions/ms-python.python-2023.4.1/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code
exec(code, run_globals)
File "infer.py", line 81, in
synthetic = model(i_wavs, i_phones)
File "/home/miniconda3/envs/mqtts/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/home/MQTTS/tester.py", line 71, in forward
synthetic = self.TTSdecoder.inference_topkp_sampling_batch(phone_features, speaker_embedding, phone_masks, prior=prior)
File "/home/MQTTS/modules/wildttstransformer.py", line 88, in inference_topkp_sampling_batch
inp = torch.cat([spkr, inp], 1)
RuntimeError: Tensors must have same number of dimensions: got 4 and 3
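
For context (this is not the repository's fix, just an illustration of the error class): torch.cat can only concatenate tensors that have the same number of dimensions, so a 4-D speaker embedding cannot be concatenated with a 3-D input. A minimal sketch with assumed, illustrative shapes:

import torch

# Illustrative shapes only, not the actual shapes inside wildttstransformer.py.
spkr = torch.randn(2, 1, 1, 256)   # 4-D: (batch, 1, 1, hidden)
inp = torch.randn(2, 10, 256)      # 3-D: (batch, time, hidden)

# torch.cat([spkr, inp], 1) raises:
# RuntimeError: Tensors must have same number of dimensions: got 4 and 3

spkr = spkr.squeeze(1)             # drop the extra singleton dim -> (batch, 1, hidden)
out = torch.cat([spkr, inp], dim=1)
print(out.shape)                   # torch.Size([2, 11, 256])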

How does the model infer during eval?

Since inference requires an input sample, how does the model run inference during evaluation while training? Since it's a multi-speaker dataset, it must be generating according to the speaker being tested. How does that happen, or does it not happen at all?

Question about batching

Hello, I'm trying to re-train the model on a GPU and have run into some issues with batching and sampling.

In the training loop, each sample has an effective batch size of 1, so each iteration corresponds to a single sample. Is this the expected behaviour, or am I doing something wrong? Is it connected to the custom bucket sampling used in the code? I tried swapping it for a standard PyTorch sampler, and whenever I set a batch size larger than 1, I kept hitting memory errors. I have 48 GB of memory on an A40 GPU. Profiling showed that hundreds of GB get swapped back and forth over 10 dev-run iterations on GigaSpeech samples.

So, is that a dataset issue, a Lightning issue, or a sampler issue?

Fine-tuning?

Hi! Are there readme instructions for how to fine-tune the model on a single speaker?

Some questions about the quantizer

Is the quantizer language-dependent? Can the quantizer trained on GigaSpeech be used for Mandarin datasets, or does it need to be retrained on Mandarin data?

Conditional transformer

Hi Li-wei,
Very interesting paper. I am interested in finding the MQTTS conditional transformer used in https://arxiv.org/pdf/2401.11053 for predicting the 4 AudioDec codes per timestep. It would be great if you could point me to that specific component in the codebase. Thanks!

Spaces before every input sentence ends?

It mostly misses the last text characters during inference. I've noticed that in the example input JSON there are extra spaces appended to produce proper speech. Are there any other workarounds, like hyperparameters or anything else?
