Comments (4)
Thank you very much, @m43.
from tapnet.
To select a specific GPU, you can specify the device when using jax.jit. Here's an example:
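import jax
gpus = jax.devices('gpu')  # list of the GPUs visible to JAX
device = gpus[1]  # If you want to use "cuda:1"
# Note: newer JAX releases may emit a deprecation warning for the device argument.
model_apply = jax.jit(model.apply, device=device)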
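An alternative to passing device to jax.jit is to put the inputs on the chosen device and let the jitted computation follow the data. Below is a minimal sketch of that approach, not the tapnet authors' method: pick_device and scale are hypothetical names used only for illustration, and the CPU fallback is there only so the snippet also runs on a machine without a GPU.

```python
import jax
import jax.numpy as jnp


def pick_device(index=1):
    """Return the GPU at `index`, or a CPU device if no GPU backend exists."""
    try:
        devices = jax.devices("gpu")
    except RuntimeError:
        # Assumption: no GPU backend is available, so fall back to CPU.
        devices = jax.devices("cpu")
    # Clamp the index so a single-device machine still works.
    return devices[min(index, len(devices) - 1)]


@jax.jit
def scale(x):
    # Stand-in for model.apply; any jitted function behaves the same way.
    return x * 2.0


device = pick_device(1)
x = jax.device_put(jnp.arange(4.0), device)  # commit the input to that device
y = scale(x)  # the computation runs where its committed input lives
print(y.devices())  # shows which device holds the result
```

Printing the devices of the result is a quick in-process check that complements nvidia-smi.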
To confirm that you are utilizing the GPU, run nvidia-smi in the terminal. It displays GPU usage statistics and the associated processes. For real-time monitoring, consider nvtop, which provides a "top-like" interface.
As for avoiding recompilation, JAX compiles the function and caches the XLA code when a jax.jit function is called for the first time. From what I understand, subsequent calls use the cached version unless the shapes or dtypes of the inputs change, or the values of arguments marked as static in the call to jax.jit change.
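One way to see the caching behaviour is to count how often the Python body of a jitted function runs, since it executes only while JAX traces the function. This is a small sketch with hypothetical names (double, trace_count), not code from tapnet.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # incremented only when JAX traces (and compiles) the function


@jax.jit
def double(x):
    global trace_count
    trace_count += 1  # Python side effect: runs during tracing, not on cached calls
    return x * 2


double(jnp.ones(3))   # first call: traced and compiled
after_first = trace_count

double(jnp.zeros(3))  # same shape and dtype: cached version is reused
after_second = trace_count

double(jnp.ones(4))   # new input shape: traced and compiled again
after_third = trace_count

print(after_first, after_second, after_third)
```

Keeping input shapes and dtypes fixed between calls (for example, by padding videos to a common size) is what keeps the cached version in use.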
Lastly, to prevent JAX from preallocating substantial GPU memory, consider setting this environment variable:
export XLA_PYTHON_CLIENT_PREALLOCATE=false
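If you prefer to set this from Python instead of the shell, the variable has to be in place before JAX initializes its GPU backend. A minimal sketch, assuming the script has not imported JAX yet:

```python
import os

# Must be set before JAX initializes its GPU backend; doing it before the
# import is the simplest way to guarantee that.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax  # imported only after the environment variable is set

print(jax.devices())  # the backend starts up with preallocation disabled
```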
Hope this helps!
Thank you very much for the amazing answer.
Let me follow up on your suggestion to set the environment variable: how much will it hurt the performance (speed) of the algorithm if JAX is prevented from preallocating most of the GPU memory?
Thanks.
Allocating memory on-the-fly might not have a significant impact on speed compared to preallocating, though this isn't explicitly detailed in the documentation. Disabling preallocation can reduce overall memory usage, especially if JAX doesn't fully utilize the preallocated memory. This becomes particularly relevant if TensorFlow or PyTorch models run concurrently on the same GPU. However, without preallocation, there's an increased risk of memory fragmentation, which could lead to OOM issues for JAX programs that need to consume most of the GPU memory. You can find more details on this topic here. If you're not facing any OOM issues, the default memory settings might be well-suited for your needs.
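A middle ground when sharing a GPU with other frameworks is to keep preallocation but cap how much JAX takes, using the XLA_PYTHON_CLIENT_MEM_FRACTION variable. The sketch below is only an illustration; the value 0.5 is an arbitrary assumption and should be tuned for your workload.

```python
import os

# Keep preallocation enabled but limit it to a fraction of total GPU memory.
# "0.5" is an illustrative value, not a recommendation.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"

import jax  # import after setting the variable so the backend picks it up

print(jax.devices())
```

This leaves the rest of the GPU memory free for a TensorFlow or PyTorch process while keeping the allocation behaviour of the default preallocating allocator.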