Comments (7)
Running a huge chunk of frames requires a lot of memory: for example, the ResNet computes dense feature maps for every frame and stores them in GPU memory. If you are tracking multiple points simultaneously, memory must also be allocated for each point during the iterative updates.
There are some memory-saving tricks. For simplicity, you could chunk your 300 frames into smaller overlapping clips (e.g. frames 0-125, frames 75-225, ...), predict point tracks for each clip, and throw away the predictions around the window boundaries.
Note that the model is trained with only 24-frame clips. I observed that there is a roughly 30-frame context window that the model actually uses through its 1D temporal convolutions; beyond that, there is not much difference, especially when points are predicted as visible.
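The overlapping-clip trick described above can be sketched as follows. The `window` and `stride` values are illustrative choices (125-frame clips starting every 75 frames, roughly matching the example ranges), not values prescribed by TAPIR:

```python
def make_overlapping_clips(num_frames, window=125, stride=75):
    """Split [0, num_frames) into overlapping [start, end) clips.

    With window=125 and stride=75 this yields frames 0-125, 75-200, ...,
    mirroring the chunking suggested above. Predictions near each clip's
    boundary can then be discarded in favor of a neighboring clip.
    """
    starts = list(range(0, max(num_frames - window, 0) + 1, stride))
    # Make sure the final frames are covered by one last clip.
    if starts[-1] + window < num_frames:
        starts.append(num_frames - window)
    return [(s, min(s + window, num_frames)) for s in starts]

print(make_overlapping_clips(300))
# [(0, 125), (75, 200), (150, 275), (175, 300)]
```

Each clip would be fed to the model independently; the per-clip predictions are then stitched together, preferring the clip in which a frame is farthest from a boundary.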
from tapnet.
You also need to precompute your query features beforehand, then run model prediction with those precomputed features. Code for doing that is provided below (not thoroughly checked):
import functools

import haiku as hk
import jax

from tapnet import tapir_model


def build_model_init(frames, query_points, *, exp_config):
  """Extract query features once so they can be reused across chunks."""
  kwargs = exp_config.shared_modules.tapir_model_kwargs
  kwargs.bilinear_interp_with_depthwise_conv = False
  model = tapir_model.TAPIR(**kwargs)
  feature_grids = model.get_feature_grids(frames, is_training=False)
  query_features = model.get_query_features(
      frames,
      is_training=False,
      query_points=query_points,
      feature_grids=feature_grids,
  )
  return query_features


def build_model_predict(frames, query_features, *, exp_config):
  """Compute point tracks and occlusions given frames and precomputed query features."""
  kwargs = exp_config.shared_modules.tapir_model_kwargs
  kwargs.bilinear_interp_with_depthwise_conv = False
  model = tapir_model.TAPIR(**kwargs)
  feature_grids = model.get_feature_grids(frames, is_training=False)
  trajectories = model.estimate_trajectories(
      frames.shape[-3:-1],  # (height, width) of the video frames
      is_training=False,
      feature_grids=feature_grids,
      query_features=query_features,
      query_points_in_video=None,
      query_chunk_size=64,
  )
  # Keep only the final refinement iteration of each output.
  return {k: v[-1] for k, v in trajectories.items()}


# `exp_config` is assumed to be the TAPIR experiment configuration.
model_init_fn = functools.partial(build_model_init, exp_config=exp_config)
init = hk.transform_with_state(model_init_fn)
init_apply = jax.jit(init.apply)

model_predict_fn = functools.partial(build_model_predict, exp_config=exp_config)
predict = hk.transform_with_state(model_predict_fn)
predict_apply = jax.jit(predict.apply)
@yangyi02 Can this trick also be used for training?
I also want to ask if you have a rough estimate of the memory usage of the feature extractor/backbone (as in build_model_init) versus the tracking network component. Is the feature extractor backbone's memory consumption dominant?
Yes, this trick can also be used for training if you call the relevant functions. I don't see what would prevent you from doing that.
Regarding memory, we haven't done a careful analysis, so this is not an easy question to answer; it really depends on the number of query points. For inference, peak memory usage will probably come from the hidden layers, which scale as (2048 channels) * (number of points) * (number of frames). At training time, of course, this is multiplied by the number of layers whose activations must be stored, which is something like (4 PIPs iterations) * (12 blocks) * (4 layers per block). In my experience, beyond about 32 points this starts to become prohibitive for training (hence why I chunk the points and only backprop through a subset). The feature backbone is relatively small (similar to ResNet-18), but it is applied across the whole video, so it can be problematic for long videos. Sorry I can't be more specific.
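A back-of-envelope check of those numbers (all constants are taken from the comment above; fp32 storage is my assumption):

```python
def refinement_activation_bytes(num_points, num_frames,
                                channels=2048, bytes_per_elem=4,
                                training=False,
                                pips_iters=4, blocks=12, layers_per_block=4):
    """Rough peak-activation estimate for the tracking/refinement network.

    Inference: one hidden activation of shape (points, frames, channels).
    Training: multiplied by the number of layers whose activations must be
    kept for backprop, i.e. pips_iters * blocks * layers_per_block.
    """
    per_layer = channels * num_points * num_frames * bytes_per_elem
    if training:
        return per_layer * pips_iters * blocks * layers_per_block
    return per_layer

# 32 points on a 24-frame training clip: ~1.1 GiB of stored activations,
# consistent with ~32 points becoming prohibitive for training.
print(refinement_activation_bytes(32, 24, training=True) / 2**30)  # 1.125
```

This ignores optimizer state and the backbone's own activations, so treat it as a lower bound on actual memory use.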
> The feature backbone is relatively small (similar to ResNet-18), but it is applied across the whole video, so it can be problematic for long videos. Sorry I can't be more specific.
I meant: if you fix the number of points/frames and scale up the input resolution a lot at training time (e.g. by re-rendering the synthetic dataset), that could make the feature backbone the memory-dominant component, right?
> Running a huge chunk of frames requires a lot of memory: for example, the ResNet computes dense feature maps for every frame and stores them in GPU memory. If you are tracking multiple points simultaneously, memory must also be allocated for each point during the iterative updates.
> There are some memory-saving tricks. For simplicity, you could chunk your 300 frames into smaller overlapping clips (e.g. frames 0-125, frames 75-225, ...), predict point tracks for each clip, and throw away the predictions around the window boundaries.
> Note that the model is trained with only 24-frame clips. I observed that there is a roughly 30-frame context window that the model actually uses through its 1D temporal convolutions; beyond that, there is not much difference, especially when points are predicted as visible.
Hi, thank you for the great work. I'm using TAPIR for inference on long videos as well. Can you be more specific about what you mean by "throw away the predictions around the window boundary"? If I understand correctly, you mean throwing away the predictions for the overlapping frames between two chunks? Can you explain why we need to do that? Thank you in advance! I'm sorry if the question is too obvious.
For each frame, TAPIR needs to see enough temporal context to make the best possible prediction. Due to memory limitations, however, the boundary frames (e.g. frame 300 in your case) will not see their future context (i.e. frames 301-315), so the model will not work as well around them. You can still trust the model's predictions between frames 0 and 270.
If you want better predictions after frame 270, you can extract frames 240 to 540 and run model prediction again. Ideally the model will give you reliable predictions for frames 270 to 510.
Hope that explains it.
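That scheme can be written down concretely: run the model on overlapping windows and only trust predictions with enough future context. The 30-frame margin below is my reading of the earlier comment about the effective ~30-frame temporal receptive field:

```python
def trusted_ranges(num_frames, window=300, margin=30):
    """Yield (window_start, window_end, trust_start, trust_end) tuples.

    Each window's last `margin` frames lack future context, so predictions
    there are trusted only if the window reaches the end of the video.
    Windows start `window - 2 * margin` frames apart, so the trusted
    ranges tile the whole video without gaps.
    """
    stride = window - 2 * margin
    start = 0
    while True:
        end = min(start + window, num_frames)
        trust_start = start if start == 0 else start + margin
        trust_end = num_frames if end == num_frames else end - margin
        yield (start, end, trust_start, trust_end)
        if end == num_frames:
            break
        start += stride

for w in trusted_ranges(600):
    print(w)
# (0, 300, 0, 270)
# (240, 540, 270, 510)
# (480, 600, 510, 600)
```

For a 600-frame video this reproduces the example above: the first window is trusted for frames 0-270 and the second (frames 240-540) for frames 270-510.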