multi-band-wavernn's Issues
Hello, a question about the multi-band input and output of Multi-band-WaveRNN
Hello! I would like to ask how Multi-band-WaveRNN handles its multi-band input and output.
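I am using the fatchord version of WaveRNN; the code is in my repository. However, it does not generate any meaningful audio, and I suspect that the mel conditioning goes wrong when it is fed into the model. I would like to ask you the following:
My setup is: the four-band audio x has shape (batch, T, 4), the mel is (batch, T, 80), and the upsampled residual (aux) branch is (batch, T, 32). I concatenate them and pass the result into the network. The differences from the original network are therefore that the input audio x changes from (batch, T, 1) to (batch, T, 4), and that the output uses four fully connected layers to generate the four sub-bands; the four per-band losses are summed and backpropagated.
Looking at the figure in the DURIAN paper, the four sub-band audio inputs also seem to be expanded along the feature dimension, as (batch, T, 4), rather than along the time axis, as (batch, 4T, 1). Could you confirm this? I would be glad to discuss it. Thank you!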
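To make the two candidate layouts concrete, here is a minimal pure-Python sketch for a single utterance. The helper names `channel_layout` and `time_layout` are hypothetical and only illustrate the shapes under discussion; they are not taken from the repository or from the paper.

```python
from typing import List


def channel_layout(bands: List[List[float]]) -> List[List[float]]:
    """Stack sub-bands along the feature axis: 4 lists of length T -> T rows of 4."""
    return [list(step) for step in zip(*bands)]


def time_layout(bands: List[List[float]]) -> List[List[float]]:
    """Interleave sub-bands along the time axis: 4 lists of length T -> 4T rows of 1."""
    return [[sample] for step in zip(*bands) for sample in step]


# Four toy sub-bands with T = 3 samples each (value = 10 * band + time step).
bands = [[10 * b + t for t in range(3)] for b in range(4)]

wide = channel_layout(bands)   # shape (T, 4): one step carries all four bands
serial = time_layout(bands)    # shape (4T, 1): one step carries a single band

print(len(wide), len(wide[0]))      # 3 4
print(len(serial), len(serial[0]))  # 12 1
```

With the channel layout a recurrent network runs T steps per utterance, whereas the time layout needs 4T steps for the same audio.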
Training code:
for i, (x, y, m) in enumerate(train_set, 1):
    x, m, y = x.to(device), m.to(device), y.to(device)  # x/y: (Batch, sub_bands, T)
    ######################### MultiBand-WaveRNN #########################
    if hp.voc_multiband:
        y0 = y[:, 0, :].squeeze(0).unsqueeze(-1)  # y0/y1/y2/y3: (Batch, T, 1)
        y1 = y[:, 1, :].squeeze(0).unsqueeze(-1)
        y2 = y[:, 2, :].squeeze(0).unsqueeze(-1)
        y3 = y[:, 3, :].squeeze(0).unsqueeze(-1)
        y_hat = model(x, m)  # (Batch, T, num_classes, sub_bands)
        if model.mode == 'RAW':
            y_hat0 = y_hat[:, :, :, 0].transpose(1, 2).unsqueeze(-1)  # (Batch, num_classes, T, 1)
            y_hat1 = y_hat[:, :, :, 1].transpose(1, 2).unsqueeze(-1)
            y_hat2 = y_hat[:, :, :, 2].transpose(1, 2).unsqueeze(-1)
            y_hat3 = y_hat[:, :, :, 3].transpose(1, 2).unsqueeze(-1)
        elif model.mode == 'MOL':
            y0 = y0.float()
            y1 = y1.float()
            y2 = y2.float()
            y3 = y3.float()
        loss = loss_func(y_hat0, y0) + loss_func(y_hat1, y1) + loss_func(y_hat2, y2) + loss_func(y_hat3, y3)
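The summed loss can be checked by hand on a toy case. The sketch below is pure Python and assumes that `loss_func` in 'RAW' mode is a mean cross-entropy over time steps; that is an assumption, since `loss_func` is not shown here. The function names are hypothetical.

```python
import math
from typing import List


def cross_entropy(logits: List[float], target: int) -> float:
    """Cross-entropy of one time step: -log softmax(logits)[target]."""
    peak = max(logits)  # subtract the max for numerical stability
    log_norm = peak + math.log(sum(math.exp(v - peak) for v in logits))
    return log_norm - logits[target]


def band_loss(logits: List[List[float]], targets: List[int]) -> float:
    """Mean cross-entropy over the T time steps of a single sub-band."""
    total = sum(cross_entropy(l, t) for l, t in zip(logits, targets))
    return total / len(targets)


def multiband_loss(logits: List[List[List[float]]], targets: List[List[int]]) -> float:
    """Sum of the per-band losses; logits[b][t] holds the class scores."""
    return sum(band_loss(l, t) for l, t in zip(logits, targets))


# Toy case: 4 sub-bands, T = 2 steps, 3 classes, all-zero (uniform) logits.
logits = [[[0.0, 0.0, 0.0] for _ in range(2)] for _ in range(4)]
targets = [[0, 1], [2, 0], [1, 1], [0, 2]]
loss = multiband_loss(logits, targets)
print(loss)
```

Under that assumption, uniform predictions give log(num_classes) per band, so the summed four-band loss starts near 4 * log(num_classes). This can serve as a sanity check on the first training steps.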
Model structure:
def forward(self, x, mels):  # x: (Batch, Subband, T)
    device = next(self.parameters()).device  # use same device as parameters

    # Although we `_flatten_parameters()` on init, when using DataParallel
    # the model gets replicated, making it no longer guaranteed that the
    # weights are contiguous in GPU memory. Hence, we must call it again
    self._flatten_parameters()

    self.step += 1
    bsize = x.size(0)
    h1 = torch.zeros(1, bsize, self.rnn_dims, device=device)
    h2 = torch.zeros(1, bsize, self.rnn_dims, device=device)
    mels, aux = self.upsample(mels)

    aux_idx = [self.aux_dims * i for i in range(5)]
    a1 = aux[:, :, aux_idx[0]:aux_idx[1]]
    a2 = aux[:, :, aux_idx[1]:aux_idx[2]]
    a3 = aux[:, :, aux_idx[2]:aux_idx[3]]
    a4 = aux[:, :, aux_idx[3]:aux_idx[4]]

    # x = torch.cat([x.unsqueeze(-1), mels, a1], dim=2)
    x = torch.cat([x.transpose(1, 2), mels, a1], dim=2)  # (batch, T, 4) (batch, T, 80) (batch, T, 32)
    x = self.I(x)  # (batch, T, 116) -> (batch, T, 512)
    res = x
    # x, _ = self.rnn1(x, h1)
    x, _ = self.rnn1(x)  # initial hidden state h1 not passed in -Begee; (batch, T, 512) -> (batch, T, 512)

    x = x + res  # (batch, T, 512)
    res = x
    x = torch.cat([x, a2], dim=2)  # (batch, T, 512) -> (batch, T, 512+32)
    # x, _ = self.rnn2(x, h2)
    x, _ = self.rnn2(x)  # initial hidden state h2 not passed in -Begee; (batch, T, 512+32) -> (batch, T, 512)

    x = x + res
    x = torch.cat([x, a3], dim=2)  # (batch, T, 512+32)
    x = F.relu(self.fc1(x))  # (batch, T, 512+32) -> (batch, T, 512)

    x = torch.cat([x, a4], dim=2)  # (batch, T, 512+32)
    x = F.relu(self.fc2(x))  # (batch, T, 512+32) -> (batch, T, 512)

    out0 = self.fc30(x).unsqueeze(-1)  # (batch, T, 512) -> (batch, T, num_classes, 1)
    out1 = self.fc31(x).unsqueeze(-1)
    out2 = self.fc32(x).unsqueeze(-1)
    out3 = self.fc33(x).unsqueeze(-1)
    out = torch.cat([out0, out1, out2, out3], dim=3)  # (B, T, num_classes, sub_band)
    return out
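As a quick check on the layer widths in this forward pass, the sketch below recomputes the last-axis size at each concatenation. It assumes `aux_dims = 32` (from the (batch, T, 32) annotation on `a1`), 80 mel channels, and 512 for both the RNN and fully connected widths; `forward_widths` is a hypothetical helper, not part of the model.

```python
from typing import Dict


def forward_widths(sub_bands: int = 4, n_mels: int = 80, aux_dims: int = 32,
                   rnn_dims: int = 512, fc_dims: int = 512) -> Dict[str, int]:
    """Return the last-axis width seen by each layer, in forward order."""
    return {
        "I_in": sub_bands + n_mels + aux_dims,  # cat([x, mels, a1])
        "rnn1_in": rnn_dims,                    # output of self.I
        "rnn2_in": rnn_dims + aux_dims,         # cat([x, a2])
        "fc1_in": rnn_dims + aux_dims,          # cat([x, a3])
        "fc2_in": fc_dims + aux_dims,           # cat([x, a4]), x is the fc1 output
        "fc3x_in": fc_dims,                     # input shared by the four output heads
    }


widths = forward_widths()
for name, width in widths.items():
    print(f"{name}: {width}")
```

Under these assumptions the input layer sees 4 + 80 + 32 = 116 features, and every later concatenation with an aux slice is 512 + 32 = 544 wide.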
Thank you for your help!