simple-tts's Issues

Mismatched batch dimensions in Unet1D.forward when prefix inpainting is enabled

I'm hitting a puzzling tensor shape error during inference, but only when I set prefix_inpainting_seconds=3.0. With prefix_inpainting_seconds=0 everything works, and I can generate non-speaker-prompted audio just fine:

  File "/root/ml/simple_tts/models/unet.py", line 460, in forward
    x = torch.cat((x, r), dim=1)
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 63 but got size 1 for tensor number 1 in the list.

Just before this line, x and r have shapes (63, 512, 1504) and (1, 512, 1504), respectively. x's batch dimension changes from 1 to 63 because of a broadcast at the line x = x * (scale + 1) + shift in Block.forward. I've traced this back to tokenizer_output in GaussianDiffusion.sample, whose values have shape (63, 256); this gets passed through to text_cond and then to mean_pooled_context in the Unet. I'm not sure if this is the intended behavior.
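To make the shape reasoning concrete, here is a minimal sketch in plain Python that only does shape arithmetic; it does not call the model. The helpers broadcast_shape and cat_shape are hypothetical names that mirror the standard NumPy/PyTorch broadcasting and concatenation rules. The shape (63, 512, 1) for scale is an assumption: I have not printed it, but a batch of 63 in the conditioning is what would explain the jump from 1 to 63.

```python
from itertools import zip_longest


def broadcast_shape(a, b):
    """Shape produced by broadcasting shapes a and b (NumPy/PyTorch rule)."""
    out = []
    # Align from the trailing dimension; missing leading dims count as size 1.
    for da, db in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if da != db and 1 not in (da, db):
            raise ValueError(f"cannot broadcast {a} with {b}")
        # A size-1 dimension stretches to match the other operand.
        out.append(db if da == 1 else da)
    return tuple(reversed(out))


def cat_shape(shapes, dim):
    """Shape produced by concatenating along dim; all other dims must match."""
    first = shapes[0]
    for s in shapes[1:]:
        same_rank = len(s) == len(first)
        if not same_rank or any(
            x != y for i, (x, y) in enumerate(zip(first, s)) if i != dim
        ):
            raise ValueError(f"sizes must match except in dimension {dim}: {first} vs {s}")
    out = list(first)
    out[dim] = sum(s[dim] for s in shapes)
    return tuple(out)


x = (1, 512, 1504)      # shape of x entering Block.forward (from the issue)
r = (1, 512, 1504)      # shape of the residual r (from the issue)
scale = (63, 512, 1)    # ASSUMED shape of scale/shift derived from mean_pooled_context

# x = x * (scale + 1) + shift silently expands the batch dimension.
x_after = broadcast_shape(x, scale)
print(x_after)

# torch.cat((x, r), dim=1) then fails because dimension 0 no longer matches.
try:
    cat_shape([x_after, r], dim=1)
    error = None
except ValueError as exc:
    error = str(exc)
print(error)
```

If the conditioning had batch size 1 like x, the same two steps would go through, which matches what I see with prefix_inpainting_seconds=0.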

Here's my full config:

args = argparse.Namespace(
      dataset_name="mls",
      save_dir="saved_models",
      text_encoder="google/byt5-large",
      output_dir=DIR,
      resume_dir=DIR,
      init_model=None,
      run_name="test/sample16",
      seed=None,
      dim=512,
      dim_mults=(1.0, 1.0, 1.0, 1.5),
      conformer_transformer=False,
      scale_skip_connection=True,
      num_transformer_layers=12,
      dropout=0.0,
      inpainting_embedding=True,
      optimizer="adamw",
      batch_size=16,
      num_train_steps=200000,
      gradient_accumulation_steps=2,
      learning_rate=0.0001,
      clip_grad_norm=1.0,
      lr_schedule="cosine",
      lr_warmup_steps=1000,
      adam_beta1=0.9,
      adam_beta2=0.999,
      adam_weight_decay=0,
      ema_decay=0.9999,
      objective="pred_v",
      parameterization="pred_v",
      loss_type="l1",
      train_schedule="cosine",
      sampling_schedule=None,
      resume=True,
      scale=0.5,
      sampling_timesteps=250,
      unconditional_prob=0.1,
      inpainting_prob=0.5,
      save_and_sample_every=5000,
      num_samples=16,
      sampler="ddpm",
      ddpm_var="large_var",
      prefix_inpainting_seconds=3.0,
      mixed_precision="no",
      eval=False,
      eval_test=True,
      trainable_params=243399440,
      num_devices=8,
      guidance=[5.0],
  )

jj

I train
