The bottleneck is in `SMPLX_Util.get_body_vertices_sequence`, which reloads the pretrained SMPL-X weights every time it is called. For example, the action `walk` has 1319 examples and `k == 10`, so the weights are loaded 1319 × 10 = 13190 times, which makes the I/O time extremely long. My suggestion: instantiate a single SMPL-X model with `batch_size=max_motion_len`, run inference on the padded parameters, and then keep only the outputs for the unmasked frames.
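A minimal sketch of that suggestion, assuming each sample's per-frame parameters are `(T, D)` tensors with `T <= max_motion_len`. `pad_to_length` and `body_vertices_for_sequence` are hypothetical helpers, not functions from the repository, and the mask convention (True marks a real frame) is an assumption. The model is meant to be created once, e.g. with `smplx.create(..., batch_size=max_motion_len)`, and reused for every example and every one of the `k` samples; the SMPL-X forward output exposes the meshes as `.vertices`.

```python
from typing import Dict, Tuple

import torch


def pad_to_length(params: Dict[str, torch.Tensor],
                  max_len: int) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
    """Zero-pad per-frame parameters from (T, D) to (max_len, D).

    Hypothetical helper. Returns the padded dict and a boolean mask that is
    True for the T real frames and False for the padding.
    """
    lengths = {value.shape[0] for value in params.values()}
    if len(lengths) != 1:
        raise ValueError(f"inconsistent sequence lengths: {lengths}")
    seq_len = lengths.pop()
    if seq_len > max_len:
        raise ValueError(f"sequence length {seq_len} exceeds max_len {max_len}")

    padded = {}
    for name, value in params.items():
        # Same dtype/device as the input; zero rows fill the remaining frames.
        filler = value.new_zeros((max_len - seq_len, *value.shape[1:]))
        padded[name] = torch.cat([value, filler], dim=0)

    device = next(iter(params.values())).device
    valid_mask = torch.zeros(max_len, dtype=torch.bool, device=device)
    valid_mask[:seq_len] = True
    return padded, valid_mask


def body_vertices_for_sequence(model, padded_params: Dict[str, torch.Tensor],
                               valid_mask: torch.Tensor) -> torch.Tensor:
    """Run a model built ONCE with batch_size == max_len, keep real frames.

    Assumes `model` is an SMPL-X-style module whose output has `.vertices`
    of shape (max_len, n_vertices, 3).
    """
    with torch.no_grad():
        output = model(return_verts=True, **padded_params)
    return output.vertices[valid_mask]
```

Zero padding is harmless here: zero shape and pose parameters describe the neutral template, so the padded frames yield valid meshes that the mask simply discards.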
Here is the script I used to measure the time cost of the three modes:
```python
import torch
import smplx
import mmengine as me

test_mode = 'cuda'  # one of: 'cpu', 'cuda', 'cuda_static'
device = 'cpu'
if test_mode in ['cuda', 'cuda_static']:
    device = 'cuda'
seq_len = 60

# Random SMPL-X parameters for a single sequence of seq_len frames.
torch_param = dict()
torch_param['body_pose'] = torch.randn(seq_len, 63).to(device)
torch_param['betas'] = torch.randn(seq_len, 10).to(device)
torch_param['transl'] = torch.randn(seq_len, 3).to(device)
torch_param['global_orient'] = torch.randn(seq_len, 3).to(device)
torch_param['left_hand_pose'] = torch.randn(seq_len, 45).to(device)
torch_param['right_hand_pose'] = torch.randn(seq_len, 45).to(device)

# Built once, before the loop: the weights are read from disk a single time.
static_model = smplx.create(model_path='data/models_smplx_v1_1/models',
                            model_type='smplx',
                            gender='neutral',
                            num_betas=10,
                            use_pca=False,
                            batch_size=seq_len,
                            ext='npz')
static_model = static_model.to(device)

for i in me.track_iter_progress(range(100)):
    if test_mode in ['cpu', 'cuda']:
        # Current behavior: rebuild the model (reloading the weights)
        # on every call, then run the forward pass.
        model = smplx.create(model_path='data/models_smplx_v1_1/models',
                             model_type='smplx',
                             gender='neutral',
                             num_betas=10,
                             use_pca=False,
                             batch_size=seq_len,
                             ext='npz').to(device)
        output = model(return_verts=True, **torch_param)
    elif test_mode == 'cuda_static':
        # Proposed behavior: reuse the model built above; only the
        # forward pass runs inside the loop.
        output = static_model(return_verts=True, **torch_param)
```
With `test_mode = 'cpu'`:

```
[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 100/100, 1.7 task/s, elapsed: 58s, ETA: 0s
```

With `test_mode = 'cuda'`:

```
[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 100/100, 1.7 task/s, elapsed: 60s, ETA: 0s
```

With `test_mode = 'cuda_static'`:

```
[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 100/100, 40.2 task/s, elapsed: 2s, ETA: 0s
```
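The rates in these logs can be turned into per-iteration costs and a rough estimate of what the repeated loading costs over the whole `walk` workload. This is a back-of-the-envelope sketch: `speedup_summary` is a hypothetical helper, and it assumes the per-load overhead measured at `seq_len=60` in `cuda` mode applies unchanged to all 1319 × 10 samples, which is only an approximation.

```python
def speedup_summary(reload_rate: float, static_rate: float,
                    n_examples: int, k: int) -> dict:
    """Derive speedup and total loading overhead from task/s rates.

    reload_rate: iterations/s when the model is rebuilt every call.
    static_rate: iterations/s when a single model is reused.
    """
    reload_s = 1.0 / reload_rate          # load + forward per iteration
    static_s = 1.0 / static_rate          # forward only per iteration
    overhead_s = reload_s - static_s      # approximate cost of one load
    n_loads = n_examples * k
    return {
        "speedup": static_rate / reload_rate,
        "n_loads": n_loads,
        "load_overhead_s": overhead_s,
        "total_overhead_h": n_loads * overhead_s / 3600.0,
    }


summary = speedup_summary(reload_rate=1.7, static_rate=40.2,
                          n_examples=1319, k=10)
print(f"speedup ~{summary['speedup']:.1f}x")  # speedup ~23.6x
print(f"{summary['n_loads']} loads, ~{summary['total_overhead_h']:.1f} h of loading overhead")
# 13190 loads, ~2.1 h of loading overhead
```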
Reusing one model is roughly 24x faster (40.2 vs. 1.7 task/s, about 2 s vs. 60 s in total) with `seq_len=60`, which is half of `max_motion_len`.
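The numbers above come from the progress bar's wall-clock rate. Because CUDA kernels are launched asynchronously, a per-call GPU timing is more reliable when the device is synchronized before the clock starts and stops. A minimal sketch of such a timer, assuming only standard `time` and `torch.cuda` calls; `seconds_per_call` is a hypothetical helper, and the warm-up calls (which absorb one-off CUDA initialization) are a design choice, not part of the original benchmark.

```python
import time
from typing import Callable

import torch


def seconds_per_call(fn: Callable[[], object], n_iters: int = 100,
                     n_warmup: int = 3) -> float:
    """Average wall-clock seconds per call of `fn`.

    When CUDA is available, the device is synchronized before starting and
    before stopping the clock so queued kernels are included in the timing.
    """
    if n_iters <= 0:
        raise ValueError("n_iters must be positive")
    for _ in range(n_warmup):
        fn()
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(n_iters):
        fn()
    if use_cuda:
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / n_iters
```

For example, `seconds_per_call(lambda: static_model(return_verts=True, **torch_param))` would time the reused-model forward pass from the script above; wrapping the call in `torch.no_grad()` also skips autograd bookkeeping, which inference does not need.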