OOM error about tapnet (open, 7 comments)

google-deepmind commented on July 28, 2024
OOM error

Comments (7)

yangyi02 commented on July 28, 2024

Running a huge chunk of frames requires a lot of memory: for example, the ResNet computes dense feature maps for every frame and stores them in GPU memory. If you are tracking multiple points simultaneously, memory also has to be allocated for each point during the iterative updates.

There are some memory-saving tricks. For simplicity, you could chunk your 300 frames into smaller overlapping clips (e.g. frames 0-125, frames 75-225, ...), predict point tracks for each clip, and throw away the predictions around the window boundaries.

Note that the model only sees 24 frames during training. I observed that the 1D temporal convolutions give the model an effective context window of around 30 frames; beyond that there is not much difference, especially when points are predicted as visible.
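
As a minimal sketch of this chunking idea (the window, overlap, and margin values are illustrative, and predict_tracks is a hypothetical per-clip predictor, not part of the tapnet API):

import numpy as np

def predict_in_chunks(frames, predict_tracks, window=150, overlap=60):
  """Run a per-clip predictor over overlapping windows and stitch the results.

  predict_tracks is a hypothetical callable mapping a clip of frames
  [num_frames, height, width, 3] to predictions [num_points, num_frames, ...].
  Only the inner part of each window, away from its boundaries, is kept.
  """
  num_frames = frames.shape[0]
  margin = overlap // 2             # frames near a window edge get poor temporal context
  stride = window - overlap
  outputs = [None] * num_frames
  for start in range(0, num_frames, stride):
    end = min(start + window, num_frames)
    preds = predict_tracks(frames[start:end])
    # Trust boundary frames only at the very start/end of the video.
    keep_lo = start if start == 0 else start + margin
    keep_hi = end if end == num_frames else end - margin
    for t in range(keep_lo, keep_hi):
      outputs[t] = preds[:, t - start]
    if end == num_frames:
      break
  return np.stack(outputs, axis=1)  # [num_points, num_frames, ...]

With a 300-frame video and these default values, this runs three overlapping windows and keeps only the central frames of each, in the spirit of the advice above.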

yangyi02 commented on July 28, 2024

You also need to precompute your query_features beforehand, and then run the model prediction with the precomputed query_features. Code for doing that is provided below (not thoroughly checked):

import functools

import haiku as hk
import jax

from tapnet import tapir_model


def build_model_init(frames, query_points, *, exp_config):
  """Compute query features for the given query points (run once per video)."""
  kwargs = exp_config.shared_modules.tapir_model_kwargs
  kwargs.bilinear_interp_with_depthwise_conv = False
  model = tapir_model.TAPIR(**kwargs)
  feature_grids = model.get_feature_grids(frames, is_training=False)
  query_features = model.get_query_features(
      frames,
      is_training=False,
      query_points=query_points,
      feature_grids=feature_grids,
  )
  return query_features


def build_model_predict(frames, query_features, *, exp_config):
  """Compute point tracks and occlusions from frames and precomputed query features."""
  kwargs = exp_config.shared_modules.tapir_model_kwargs
  kwargs.bilinear_interp_with_depthwise_conv = False
  model = tapir_model.TAPIR(**kwargs)
  feature_grids = model.get_feature_grids(frames, is_training=False)
  trajectories = model.estimate_trajectories(
      frames.shape[-3:-1],
      is_training=False,
      feature_grids=feature_grids,
      query_features=query_features,
      query_points_in_video=None,
      query_chunk_size=64,
  )
  # Keep only the prediction from the final refinement iteration.
  return {k: v[-1] for k, v in trajectories.items()}


# exp_config is assumed to be the loaded experiment config, as in the tapnet codebase.
model_init_fn = functools.partial(build_model_init, exp_config=exp_config)
init = hk.transform_with_state(model_init_fn)
init_apply = jax.jit(init.apply)

model_predict_fn = functools.partial(build_model_predict, exp_config=exp_config)
predict = hk.transform_with_state(model_predict_fn)
predict_apply = jax.jit(predict.apply)
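
A hedged usage sketch of the two transformed functions, assuming params and state are restored from a TAPIR checkpoint (the checkpoint path and the output key names are assumptions and should be checked against the tapnet version you are running):

import numpy as np

# Hypothetical checkpoint path; the tapnet demos store checkpoints as a dict
# with 'params' and 'state' entries.
ckpt = np.load('checkpoints/tapir_checkpoint.npy', allow_pickle=True).item()
params, state = ckpt['params'], ckpt['state']

# frames and query_points are assumed to already be preprocessed
# (batched, resized, and normalized as in the tapnet demos).
rng = jax.random.PRNGKey(42)

# Extract query features once, then reuse them for prediction.
query_features, _ = init_apply(params, state, rng, frames, query_points)
outputs, _ = predict_apply(params, state, rng, frames, query_features)

# Assumed output keys; check what estimate_trajectories returns in your version.
tracks = outputs['tracks']        # point coordinates per frame
occlusion = outputs['occlusion']  # occlusion logits per frame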

bhack commented on July 28, 2024

@yangyi02 Can this trick also be used for training?

I'd also like to ask if you have a rough estimate of the memory footprint of the build_model_init feature extractor/backbone vs. the tracking network component.
Is the feature extractor backbone's memory consumption dominant?

cdoersch commented on July 28, 2024

Yes, this trick can also be used for training if you call the relevant functions. I don't see what would prevent you from doing that.

Regarding memory, we haven't done a careful analysis, so this is not an easy question to answer; it really depends on the number of query points. For inference, peak memory usage will probably come from the hidden layers, roughly (2048 channels) * (number of points) * (number of frames). At training time, of course, this is multiplied by the number of layers, which is something like (4 PIPs iterations) * (12 blocks) * (4 layers per block). In my experience, beyond about 32 points this starts to become prohibitive for training (hence why I chunk the points and only backprop through a subset). The feature backbone is relatively small (similar to ResNet-18), but it's applied across the whole video, so it can become problematic for long videos. Sorry I can't be more specific.
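
As a rough back-of-the-envelope illustration of those numbers (the constants are the ones quoted above; float32 activations are my assumption, and real usage depends on the implementation):

channels = 2048       # hidden width quoted above
num_points = 32       # roughly where training becomes prohibitive
num_frames = 24       # training clip length
bytes_per_act = 4     # assuming float32 activations

per_layer = channels * num_points * num_frames * bytes_per_act   # ~6 MiB
num_layers = 4 * 12 * 4   # PIPs iterations * blocks * layers per block

print(f"one hidden layer: {per_layer / 2**20:.1f} MiB")
print(f"kept for backprop: {per_layer * num_layers / 2**30:.2f} GiB")   # ~1.1 GiB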

bhack commented on July 28, 2024

The feature backbone is relatively small (similar to ResNet-18), but it's applied across the whole video, so it can become problematic for long videos. Sorry I can't be more specific.

I meant: keeping the number of points/frames fixed but scaling up the input resolution a lot at training time (e.g. re-rendering the synthetic dataset), could that make the feature backbone the memory-dominant component?

khiem2105 commented on July 28, 2024

Running a huge chunk of frames requires a lot of memory: for example, the ResNet computes dense feature maps for every frame and stores them in GPU memory. If you are tracking multiple points simultaneously, memory also has to be allocated for each point during the iterative updates.

There are some memory-saving tricks. For simplicity, you could chunk your 300 frames into smaller overlapping clips (e.g. frames 0-125, frames 75-225, ...), predict point tracks for each clip, and throw away the predictions around the window boundaries.

Note that the model only sees 24 frames during training. I observed that the 1D temporal convolutions give the model an effective context window of around 30 frames; beyond that there is not much difference, especially when points are predicted as visible.

Hi, thank you for the great work. I'm using TAPIR to do inference on long videos as well. Can you be more specific about what you mean by "throw away the predictions around the window boundary"? If I understand correctly, you mean throwing away the predictions in the frames where two chunks overlap? Can you explain why we need to do that? Thank you in advance! I'm sorry if the question is too obvious.

yangyi02 commented on July 28, 2024

For each frame, TAPIR needs to see enough temporal context to make the best possible prediction. However, due to the memory limitation, the boundary frames (e.g. frame 300 in your case) will not see their future context (i.e. frames 301-315), so the model will not work well around there. But you can still trust the model predictions between frames 0 and 270.

Now if you want better predictions after frame 270, you can extract frames 240 to 540 and run the model prediction again. Ideally the model will give you reliable predictions for frames 270 to 510.

Hope that explains it.
