Comments (8)
For the second part, ONNX does support exporting for loops; I'll sort this out when I have time. You can search for "torch for loop to onnx", or align with the parallel CIF implementation and submit a PR.
Thanks! I'll try the ONNX for-loop export first and see whether it solves the problem;
But could the unstable inference latency be a problem in the CIF part?
It should be; all the other modules are transformer-like, so their inference should be stable.
from wenet.
Please take a look at the ONNX conversion of the CIF part.
Yes, it is indeed mainly the CIF part;
I made some attempts, and I can now successfully convert an onnx-gpu model that supports dynamic input, but some problems remain.
- wenet-main/examples/aishell/paraformer/wenet/utils/mask.py
# Calling .item() here leaves the converted ONNX model unable to support dynamic
# dimensions; inspecting the network with netron shows this becomes a fixed value.
# max_len = max_len if max_len > 0 else lengths.max().item()
max_len = lengths.max()
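The freezing effect is easy to reproduce outside wenet. A minimal sketch (hypothetical function, not the actual mask.py code) using torch.jit.trace, which is what torch.onnx.export relies on for non-scripted modules:

```python
import torch

def make_range(lengths: torch.Tensor) -> torch.Tensor:
    # .item() turns the tensor into a Python int, so tracing records
    # it as a constant baked into the graph (with a TracerWarning)
    max_len = int(lengths.max().item())
    return torch.zeros(max_len)

traced = torch.jit.trace(make_range, (torch.tensor([2, 3]),))

# A longer input still produces the old, frozen size:
out = traced(torch.tensor([2, 7]))
print(out.shape[0])  # 3, not 7: the dynamic dimension was lost
```

Keeping `lengths.max()` as a tensor avoids handing tracing a Python number, which is exactly the change above.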
- wenet-main/examples/aishell/paraformer/wenet/paraformer/cif.py
class Cif(nn.Module):
    def forward(self, ...):
        if target_length is None and self.tail_threshold > 0.0:
            # This part also seems problematic: it raises an error about
            # int32 being incompatible with int64.
            # token_num_int = torch.max(token_num).type(torch.int32).item()
            token_num_int = torch.max(token_num).type(torch.int64)
            acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
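Since ONNX represents shapes and slice bounds as int64, casting the token count to int64 while keeping it a tensor sidesteps both the dtype clash and the frozen-constant issue. A small self-contained sketch with stand-in values (not the real CIF output):

```python
import torch

token_num = torch.tensor([5.4, 3.2, 8.6])             # stand-in CIF token counts
token_num_int = torch.max(token_num).to(torch.int64)  # 8, as a 0-dim int64 tensor
acoustic_embeds = torch.randn(2, 10, 4)
# Slicing with a 0-dim tensor keeps the operation in the exported graph,
# and int64 matches what the ONNX Slice operator expects for its bounds.
trimmed = acoustic_embeds[:, :token_num_int, :]
print(trimmed.shape)  # torch.Size([2, 8, 4])
```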
Attempts
- The body of the CIF function here is a for loop; exporting it directly freezes the number of iterations to whatever it was at conversion time.
- Adding @torch.jit.script does give dynamic dimensions, but the resulting model is severely slow.
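On the maintainer's point that ONNX can export for loops: under torch.jit.script a data-dependent loop is kept as a real control-flow node rather than unrolled, so the exporter can lower it to the ONNX Loop operator. A toy sketch (not the CIF code) of the pattern:

```python
import torch

class LoopSum(torch.nn.Module):
    """Toy data-dependent loop: under torch.jit.script the loop stays a
    control-flow node, which torch.onnx.export can lower to the ONNX Loop
    operator; plain tracing would unroll it a fixed number of times."""

    def forward(self, x: torch.Tensor, n: torch.Tensor) -> torch.Tensor:
        acc = torch.zeros_like(x)
        for _ in range(int(n)):
            acc = acc + x
        return acc

scripted = torch.jit.script(LoopSum())
print(scripted(torch.ones(3), torch.tensor(4)))  # tensor([4., 4., 4.])
# torch.onnx.export(scripted, (torch.ones(3), torch.tensor(4)), "loop.onnx")
# would then emit a Loop node whose trip count follows the runtime input.
```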
Solution
- Replacing it with the parallel implementation from https://github.com/George0828Zhang/torch_cif supports dynamic input.
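The core idea behind the parallel rewrite is to replace frame-by-frame weight accumulation with a cumulative sum and detect integer threshold crossings. A heavily simplified sketch of just the firing rule, in the spirit of torch_cif (it omits the weighted averaging of acoustic embeddings that the full CIF performs):

```python
import torch

def cif_fire_positions(alphas: torch.Tensor, threshold: float = 1.0) -> torch.Tensor:
    # Cumulative integral of the per-frame weights; a frame "fires"
    # wherever the integer part of (integral / threshold) increases.
    csum = alphas.cumsum(dim=-1)
    fired = (csum / threshold).floor()
    return fired - torch.nn.functional.pad(fired[..., :-1], (1, 0))

alphas = torch.tensor([[0.4, 0.4, 0.4, 0.4, 0.4]])
print(cif_fire_positions(alphas))  # fires at frames 2 and 4
```

Everything here is elementwise or a prefix scan, so it exports with dynamic time dimensions and runs without a Python loop.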
Current issues
- Inference latency is unstable: for audio files padded to the same length (60 s), testing in a loop gives latencies ranging from 150 ms to 2000 ms;
- There is some loss in accuracy.
For the first part, I will refactor this function; it also affects things like torch.compile.
@whisper-yu please help try #2515 for the mask 🙏
I ran some tests, but the results are not quite what I originally guessed:
- The slow part is actually the decoder (and only intermittently, whereas the encoder and predictor latencies are very stable);
- I don't have much experience here; I considered whether resources were constrained, but that does not seem to fit either.
- Accuracy degrades somewhat: on our own dataset the character error rate rose from 30% (measured directly through the funasr API) to 36%. I suspect the CIF is not aligned, since that is the main part I changed.
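One way to tell one-off warm-up costs (kernel selection, memory-arena growth on GPU) apart from genuinely unstable per-call latency is to time each session over many runs and compare the median with the worst case. A stdlib-only harness (hypothetical helper, not part of wenet) that can wrap each session.run call:

```python
import statistics
import time

def profile_stage(fn, args, warmup: int = 3, runs: int = 20):
    """Time a callable after discarding warm-up runs.

    GPU inference sessions often pay one-off costs on the first calls,
    which can masquerade as "unstable" latency; the gap between the
    median and the max over many runs separates that from genuine
    per-call variance.
    """
    for _ in range(warmup):
        fn(*args)
    times_ms = []
    for _ in range(runs):
        t0 = time.perf_counter()
        fn(*args)
        times_ms.append((time.perf_counter() - t0) * 1000)
    return statistics.median(times_ms), max(times_ms)

# e.g. profile_stage(lambda: decoder_session.run(None, decoder_inputs), ())
```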
Test procedure
I exported the encoder, predictor, and decoder separately as onnx-gpu models, then timed each of them individually;
1. Test results [different test audio]
2. Test code:
import time

import torch
import torchaudio

def infer_onnx(wav_path, model, tokenizer):
    # pre-processing
    start_0 = time.time()
    wav, sr = torchaudio.load(wav_path)
    # pad every utterance to 60 s
    padding_length = int(60 * sr - wav.shape[1])
    padding = torch.zeros(1, padding_length) + 0.00001
    wav = torch.cat([wav, padding], dim=1)
    data = wav.squeeze()
    data = [data]
    speech, speech_lengths = extract_fbank(data)
    # LFR is not folded into the encoder because it contains
    # operators that are not supported by ONNX export!
    lfr = LFR()
    speech, speech_lengths = lfr(speech, speech_lengths)
    end_0 = time.time()
    total_0 = int((end_0 - start_0) * 1000)
    # encoder
    start_1 = time.time()
    encoder_inputs = {
        "speech": to_numpy(speech),
        "speech_lengths": to_numpy(speech_lengths),
    }
    encoder_out, encoder_out_mask = encoder_session.run(None, encoder_inputs)
    end_1 = time.time()
    total_1 = int((end_1 - start_1) * 1000)
    # predictor
    start_2 = time.time()
    predictor_inputs = {
        "encoder_out": encoder_out,
        "encoder_out_mask": encoder_out_mask,
    }
    acoustic_embed, token_num, tp_alphas = predictor_session.run(None, predictor_inputs)
    end_2 = time.time()
    total_2 = int((end_2 - start_2) * 1000)
    # decoder
    start_3 = time.time()
    decoder_inputs = {
        "encoder_out": encoder_out,
        "encoder_out_mask": encoder_out_mask,
        "acoustic_embed": acoustic_embed,
        "token_num": token_num,
    }
    decoder_out = decoder_session.run(None, decoder_inputs)
    decoder_out = decoder_out[0]
    end_3 = time.time()
    total_3 = int((end_3 - start_3) * 1000)
    # post-processing
    start_4 = time.time()
    decoder_out = torch.tensor(decoder_out, dtype=torch.float32)
    decoder_out_lens = torch.tensor(token_num, dtype=torch.int32)
    tp_alphas = torch.tensor(tp_alphas, dtype=torch.float32)
    peaks = model.forward_cif_peaks(tp_alphas, decoder_out_lens)
    paraformer_greedy_result = paraformer_greedy_search(
        decoder_out, decoder_out_lens, peaks)
    results = {
        "paraformer_greedy_result": paraformer_greedy_result
    }
    for i in range(len(data)):
        for mode, hyps in results.items():
            tokens = hyps[i].tokens
            line = '{}'.format(tokenizer.detokenize(tokens)[0])
    end_4 = time.time()
    total_4 = int((end_4 - start_4) * 1000)
    print(f"[pre]-{total_0} ||[encoder]-{total_1} ||[predictor]-{total_2} "
          f"||[decoder]-{total_3} ||[post]-{total_4}")
    return line
3. Decoder ONNX export code
# forward part
class Paraformer(ASRModel):
    # DECODER
    @torch.jit.export
    def forward_decoder(
        self,
        encoder_out: torch.Tensor,
        encoder_out_mask: torch.Tensor,
        acoustic_embed: torch.Tensor,
        token_num: torch.Tensor
    ) -> torch.Tensor:
        # decoder
        decoder_out, _, _ = self.decoder(encoder_out, encoder_out_mask,
                                         acoustic_embed, token_num)
        decoder_out = decoder_out.log_softmax(dim=-1)
        return decoder_out

##############################################################
# decoder ONNX model export part
if not os.path.exists(decoder_path):
    print("\n\n[export decoder]")
    model.forward = model.forward_decoder
    torch.onnx.export(
        model,
        (encoder_out, encoder_out_mask, acoustic_embed, token_num),
        decoder_path,
        export_params=True,
        opset_version=13,
        do_constant_folding=True,
        input_names=["encoder_out", "encoder_out_mask", "acoustic_embed", "token_num"],
        output_names=[
            "decoder_out",
        ],
        dynamic_axes={
            "encoder_out": {
                0: "B",
                1: "T_E"
            },
            "encoder_out_mask": {
                0: "B",
                2: "T_E"
            },
            "acoustic_embed": {
                0: "B",
                1: "T_P"
            },
            "token_num": {
                0: "B"
            },
            "decoder_out": {
                0: "B",
                1: "T_P"
            },
        },
        verbose=True,
    )
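Given the CER regression, it may be worth checking numerical parity between the eager module and its scripted/exported counterpart before decoding; mismatches introduced at export time usually show up here first. A self-contained sketch with a tiny stand-in decoder (hypothetical, not the real Paraformer decoder):

```python
import torch

class TinyDecoder(torch.nn.Module):  # hypothetical stand-in for the real decoder
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 16)

    def forward(self, acoustic_embed: torch.Tensor) -> torch.Tensor:
        return self.proj(acoustic_embed).log_softmax(dim=-1)

model = TinyDecoder().eval()
scripted = torch.jit.script(model)
x = torch.randn(2, 5, 8)
with torch.no_grad():
    # raises with a detailed diff if the two paths diverge numerically
    torch.testing.assert_close(model(x), scripted(x))
print("parity ok")
```

The same pattern applies to an onnxruntime session: run the eager model and session.run on identical inputs and compare the outputs with a tight tolerance.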
Is this issue fixed?