
Need help for training · styletts (closed, 5 comments)


Comments (5)

yl4579 commented:

You can use the following code for testing (run it in the inference notebook); it is adapted from the validation-loss computation code:

# Setup from earlier in the inference notebook is assumed: `config`,
# `device`, `batch_size`, the stage-1 `model`, and the vocoder
# `generator` must already be loaded.
import numpy as np
import torch
import torch.nn.functional as F
import IPython.display as ipd

from meldataset import build_dataloader
# Helper locations as in the StyleTTS repo; adjust if your checkout differs.
from utils import get_data_path_list, length_to_mask, log_norm
from monotonic_align import mask_from_lens, maximum_path

TMA_CEloss = False  # set to True if the model was trained with the TMA cross-entropy loss

train_path = config.get('train_data', None)
val_path = config.get('val_data', None)
train_list, val_list = get_data_path_list(train_path, val_path)

train_dataloader = build_dataloader(train_list,
                                    batch_size=batch_size,
                                    num_workers=8,
                                    dataset_config={},
                                    device=device)

val_dataloader = build_dataloader(val_list,
                                  batch_size=batch_size,
                                  validation=True,
                                  num_workers=2,
                                  device=device,
                                  dataset_config={})

# Grab one batch (swap in val_dataloader to test on validation data).
batch = next(iter(train_dataloader))
batch = [b.to(device) for b in batch]
texts, input_lengths, mels, mel_input_length = batch

with torch.no_grad():
    # Run the text aligner to get the text-to-mel attention.
    mask = length_to_mask(mel_input_length // (2 ** model.text_aligner.n_down)).to(device)
    m = length_to_mask(input_lengths).to(device)
    ppgs, s2s_pred, s2s_attn_feat = model.text_aligner(mels, mask, texts)

    # Drop the first (BOS) position of the attention.
    s2s_attn_feat = s2s_attn_feat.transpose(-1, -2)
    s2s_attn_feat = s2s_attn_feat[..., 1:]
    s2s_attn_feat = s2s_attn_feat.transpose(-1, -2)

    # Build a joint text/mel padding mask and apply it to the attention logits.
    text_mask = length_to_mask(input_lengths).to(texts.device)
    attn_mask = (~mask).unsqueeze(-1).expand(mask.shape[0], mask.shape[1], text_mask.shape[-1]).float().transpose(-1, -2)
    attn_mask = attn_mask * (~text_mask).unsqueeze(-1).expand(text_mask.shape[0], text_mask.shape[1], mask.shape[-1]).float()
    attn_mask = (attn_mask < 1)
    s2s_attn_feat.masked_fill_(attn_mask, -float("inf"))

    if TMA_CEloss:
        s2s_attn = F.softmax(s2s_attn_feat, dim=1)   # along the mel dimension
    else:
        s2s_attn = F.softmax(s2s_attn_feat, dim=-1)  # along the text dimension

    # Get the monotonic version of the alignment.
    mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length // (2 ** model.text_aligner.n_down))
    s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
    s2s_attn = torch.nan_to_num(s2s_attn)

    # Encode the text and expand it to the mel frame rate.
    t_en = model.text_encoder(texts, input_lengths, m)
    asr = t_en @ s2s_attn_mono

    # Cut a random clip of the same length from every sample in the batch.
    mel_len = int(mel_input_length.min().item() / 2 - 1)
    en, gt = [], []
    for bib in range(len(mel_input_length)):
        mel_length = int(mel_input_length[bib].item() / 2)
        random_start = np.random.randint(0, mel_length - mel_len)
        en.append(asr[bib, :, random_start:random_start + mel_len])
        gt.append(mels[bib, :, (random_start * 2):((random_start + mel_len) * 2)])
    en = torch.stack(en)
    gt = torch.stack(gt).detach()

    # Extract F0, then reconstruct the mel spectrogram from the aligned
    # text encoding, pitch, energy, and style vector.
    F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
    F0 = F0.reshape(F0.shape[0], F0.shape[1] * 2, F0.shape[2], 1).squeeze()

    s = model.style_encoder(gt.unsqueeze(1))
    real_norm = log_norm(gt.unsqueeze(1)).squeeze(1)
    mel_rec = model.decoder(en, F0_real, real_norm, s)
    mel_rec = mel_rec[..., :gt.shape[-1]]

    # Run the vocoder on each reconstructed mel to get waveforms.
    synthesized = []
    for idx in range(mel_rec.size(0)):
        c = mel_rec[idx].squeeze()
        y_g_hat = generator(c.unsqueeze(0))
        synthesized.append(y_g_hat.squeeze().cpu().numpy())

for wave in synthesized:
    display(ipd.Audio(wave, rate=24000))

As for the IPA error, it seems like only a few characters are not in the dictionary, and they are automatically ignored by the text cleaner. You can do print(char) instead of print(index) to see which specific characters are missing.
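For concreteness, here is a minimal sketch of that change, modeled on the TextCleaner in meldataset.py (the dictionary name and structure are assumptions; adapt them to your checkout):

# Hypothetical sketch: a text cleaner that reports the missing character
# itself, rather than an index, whenever a lookup raises KeyError.
class TextCleaner:
    def __init__(self, word_index_dictionary):
        self.word_index_dictionary = word_index_dictionary

    def __call__(self, text):
        indexes = []
        for char in text:
            try:
                indexes.append(self.word_index_dictionary[char])
            except KeyError:
                print(char)  # shows exactly which IPA symbol is not in the dictionary
        return indexes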


yl4579 commented:

@christopherohit If these characters are not essential for reconstruction, then it doesn't matter: the text aligner is eventually fine-tuned on your new dataset, and characters unseen by the pretrained model are learned during TMA training. But if these characters are important for reconstruction (like pauses, breaths, laughs, etc.), then you need to retrain.


yl4579 commented:

I believe the problem lies in the duration loss; it somehow fluctuates between 1 and 0.6. The text aligner itself is probably fine. Could you check whether the first stage of the model sounds good?


nhanhttrong commented:

Can you provide code to test stage 1? I also have another question: while preprocessing the data I use phonemizer to convert the text to IPA, but most of the IPA characters it produces don't exist in your code. So during training I get the status below; it prints each mismatched text line, yet the loss keeps working, and I wonder how that happens.

[screenshot omitted]

And does the KeyError raised in this code affect training?

[screenshot omitted]

Thanks for all your replies, hope you have a good day


christopherohit commented:

So does the IPA error affect training?
If yes, do I need to retrain the text aligner with a more extensive IPA set?

Thanks for your reply

