When I ran pytorch2onnx.py to convert the model to an ONNX file, it reported an error. The error is shown below:
~/work/ai/MVT$ python tracking/pytorch2onnx.py
test config: {'MODEL': {'PRETRAIN_FILE': 'mobilevit_s.pt', 'EXTRA_MERGER': False, 'RETURN_INTER': False, 'RETURN_STAGES': [], 'BACKBONE': {'TYPE': 'mobilevit_s', 'STRIDE': 16, 'MID_PE': False, 'SEP_SEG': False, 'CAT_MODE': 'direct', 'MERGE_LAYER': 0, 'ADD_CLS_TOKEN': False, 'CLS_TOKEN_USE_MODE': 'ignore'}, 'NECK': {'TYPE': 'BN_FEATURE_FUSOR_LIGHTTRACK', 'NUM_CHANNS_POST_XCORR': 64}, 'HEAD': {'TYPE': 'CENTER', 'NUM_CHANNELS': 256}}, 'TRAIN': {'LR': 0.0004, 'WEIGHT_DECAY': 0.0001, 'EPOCH': 100, 'LR_DROP_EPOCH': 10, 'BATCH_SIZE': 128, 'NUM_WORKER': 10, 'OPTIMIZER': 'ADAMW', 'BACKBONE_MULTIPLIER': 0.1, 'GIOU_WEIGHT': 2.0, 'L1_WEIGHT': 5.0, 'FREEZE_LAYERS': [0], 'PRINT_INTERVAL': 50, 'VAL_EPOCH_INTERVAL': 10, 'GRAD_CLIP_NORM': 0.1, 'AMP': False, 'SCHEDULER': {'TYPE': 'cosine_anneal', 'DECAY_RATE': 0.5}}, 'DATA': {'SAMPLER_MODE': 'causal', 'MEAN': [0.0, 0.0, 0.0], 'STD': [1.0, 1.0, 1.0], 'MAX_SAMPLE_INTERVAL': 200, 'TRAIN': {'DATASETS_NAME': ['GOT10K_train_full'], 'DATASETS_RATIO': [1], 'SAMPLE_PER_EPOCH': 60000}, 'VAL': {'DATASETS_NAME': ['GOT10K_official_val'], 'DATASETS_RATIO': [1], 'SAMPLE_PER_EPOCH': 10000}, 'SEARCH': {'SIZE': 256, 'FACTOR': 4.0, 'CENTER_JITTER': 3, 'SCALE_JITTER': 0.25, 'NUMBER': 1}, 'TEMPLATE': {'NUMBER': 1, 'SIZE': 128, 'FACTOR': 2.0, 'CENTER_JITTER': 0, 'SCALE_JITTER': 0}}, 'TEST': {'DEVICE': 'cpu', 'TEMPLATE_FACTOR': 2.0, 'TEMPLATE_SIZE': 128, 'SEARCH_FACTOR': 4.0, 'SEARCH_SIZE': 256, 'EPOCH': 100}}
Converting tracking model now!
Traceback (most recent call last):
File "tracking/pytorch2onnx.py", line 223, in <module>
convert_tracking_model(network, params.checkpoint)
File "tracking/pytorch2onnx.py", line 188, in convert_tracking_model
opset_version=11, do_constant_folding=True, input_names=['z','x'], output_names=['cls','reg'])
File "/home/999/.local/lib/python3.6/site-packages/torch/onnx/__init__.py", line 320, in export
custom_opsets, enable_onnx_checker, use_external_data_format)
File "/home/999/.local/lib/python3.6/site-packages/torch/onnx/utils.py", line 111, in export
custom_opsets=custom_opsets, use_external_data_format=use_external_data_format)
File "/home/999/.local/lib/python3.6/site-packages/torch/onnx/utils.py", line 729, in _export
dynamic_axes=dynamic_axes)
File "/home/999/.local/lib/python3.6/site-packages/torch/onnx/utils.py", line 493, in _model_to_graph
graph, params, torch_out, module = _create_jit_graph(model, args)
File "/home/999/.local/lib/python3.6/site-packages/torch/onnx/utils.py", line 437, in _create_jit_graph
graph, torch_out = _trace_and_get_graph_from_model(model, args)
File "/home/999/.local/lib/python3.6/site-packages/torch/onnx/utils.py", line 388, in _trace_and_get_graph_from_model
torch.jit._get_trace_graph(model, args, strict=False, _force_outplace=False, _return_inputs_states=True)
File "/home/999/.local/lib/python3.6/site-packages/torch/jit/_trace.py", line 1166, in _get_trace_graph
outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
File "/home/999/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/home/999/.local/lib/python3.6/site-packages/torch/jit/_trace.py", line 132, in forward
self._force_outplace,
File "/home/999/.local/lib/python3.6/site-packages/torch/jit/_trace.py", line 118, in wrapper
outs.append(self.inner(*trace_inputs))
File "/home/999/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/home/999/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1090, in _slow_forward
result = self.forward(*input, **kwargs)
File "tracking/pytorch2onnx.py", line 52, in forward
x, z = self.backbone(x=search, z=template)
File "/home/999/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/home/999/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1090, in _slow_forward
result = self.forward(*input, **kwargs)
File "/home/999/work/ai/MVT/lib/models/mobilevit_track/base_backbone.py", line 93, in forward
x, z = self.forward_features(x, z,)
File "/home/999/work/ai/MVT/lib/models/mobilevit_track/base_backbone.py", line 74, in forward_features
x, z = self._forward_MobileViT_layer(self.layer_3, x, z)
File "/home/999/work/ai/MVT/lib/models/mobilevit_track/base_backbone.py", line 46, in _forward_MobileViT_layer
z = MobilenetV2_block(z)
File "/home/999/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/home/999/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1090, in _slow_forward
result = self.forward(*input, **kwargs)
File "/home/999/work/ai/MVT/lib/models/mobilevit_track/modules/mobilenetv2.py", line 240, in forward
return self.block(x)
File "/home/999/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/home/999/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1090, in _slow_forward
result = self.forward(*input, **kwargs)
File "/home/999/.local/lib/python3.6/site-packages/torch/nn/modules/container.py", line 141, in forward
input = module(input)
File "/home/999/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/home/999/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1090, in _slow_forward
result = self.forward(*input, **kwargs)
File "/home/999/work/ai/MVT/lib/models/mobilevit_track/layers/conv_layer.py", line 236, in forward
return self.block(x)
File "/home/999/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/home/999/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1090, in _slow_forward
result = self.forward(*input, **kwargs)
File "/home/999/.local/lib/python3.6/site-packages/torch/nn/modules/container.py", line 141, in forward
input = module(input)
File "/home/999/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/home/999/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1090, in _slow_forward
result = self.forward(*input, **kwargs)
File "/home/999/.local/lib/python3.6/site-packages/torch/nn/modules/conv.py", line 446, in forward
return self._conv_forward(input, self.weight, self.bias)
File "/home/999/.local/lib/python3.6/site-packages/torch/nn/modules/conv.py", line 443, in _conv_forward
self.padding, self.dilation, self.groups)
RuntimeError: Given groups=1, weight of size [256, 64, 1, 1], expected input[1, 3, 128, 128] to have 64 channels, but got 3 channels instead