Comments (4)

CHYjeremy commented on July 28, 2024

Thank you very much, @m43!

from tapnet.

m43 commented on July 28, 2024

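To select a specific GPU, you can pass the target device to jax.jit. Here's an example (note that recent JAX releases deprecate the device argument of jax.jit in favour of placing the inputs on the device, e.g. with jax.device_put):

import jax

gpus = jax.devices('gpu')
device = gpus[1]  # the second GPU, i.e. what PyTorch calls "cuda:1"
model_apply = jax.jit(model.apply, device=device)  # model is the model object from your own code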
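Because the device argument is deprecated in newer JAX versions, here is a minimal sketch of the alternative: commit the inputs to the chosen device with jax.device_put, and the jitted function then runs on that device. The linear apply function and its params are hypothetical stand-ins for model.apply, and pick_device is a made-up helper that falls back to the first available device so the sketch also runs on a CPU-only install.

```python
import jax
import jax.numpy as jnp


def pick_device(index: int = 1):
    """Return the index-th GPU, falling back to the default backend's devices."""
    try:
        devices = jax.devices("gpu")
    except RuntimeError:
        # No usable GPU backend (e.g. a CPU-only jaxlib): use what JAX has.
        devices = jax.devices()
    # Clamp the index so the sketch still runs on machines with fewer devices.
    return devices[min(index, len(devices) - 1)]


@jax.jit
def apply(params, x):
    # Hypothetical stand-in for model.apply: a single linear layer.
    return x @ params["w"] + params["b"]


device = pick_device(1)  # the second GPU ("cuda:1") when there is one

params = {"w": jnp.ones((3, 2)), "b": jnp.zeros(2)}
x = jnp.arange(6.0).reshape(2, 3)

# Committing the inputs to `device` makes the jitted computation run there.
params = jax.device_put(params, device)
x = jax.device_put(x, device)
y = apply(params, x)
print(y.devices())
```

Two other options: wrapping the calls in `with jax.default_device(device):` changes where newly created (uncommitted) arrays are placed, and launching Python with CUDA_VISIBLE_DEVICES=1 makes only that GPU visible to JAX at all (it then shows up as device 0).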

To confirm that you are using the GPU, run nvidia-smi in a terminal. It displays GPU usage statistics and the associated processes. For real-time monitoring, consider nvtop, which provides a "top-like" interface.
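Besides nvidia-smi, you can check from inside Python which backend JAX picked and where an array actually lives. A small sketch using jax.default_backend(), jax.devices() and Array.devices(); if it reports cpu, the installed jaxlib most likely has no working GPU support.

```python
import jax
import jax.numpy as jnp

# "gpu" means JAX is running on a GPU backend; "cpu" means it fell back to CPU.
print("default backend:", jax.default_backend())

# Every device JAX can use on that backend.
for d in jax.devices():
    print(d.id, d.platform, d.device_kind)

# New arrays land on the first device of the default backend.
x = jnp.ones((1000, 1000))
print("x lives on:", x.devices())
```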

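As for avoiding recompilation: the first time a jax.jit-wrapped function is called, JAX traces it, compiles it with XLA and caches the compiled code. From what I understand, subsequent calls reuse the cached version as long as the non-static arguments keep the same shapes and dtypes and the arguments marked static (static_argnums/static_argnames in jax.jit) keep the same values; changing any of these triggers a retrace and recompilation. The cache lives in memory, so a fresh Python process compiles again.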
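A quick way to see when JAX recompiles is to put a Python side effect inside the jitted function: it only executes while JAX traces the function, which happens once per cache miss. In this sketch, scale and trace_count are made up for illustration; it shows a cache hit for inputs with the same shape, dtype and static value, and recompilation when either the static value or the input shape changes.

```python
import functools

import jax
import jax.numpy as jnp

trace_count = 0


@functools.partial(jax.jit, static_argnums=1)
def scale(x, factor):
    global trace_count
    # Plain Python code like this runs only while JAX traces the function,
    # i.e. on a compilation-cache miss, never on a cached call.
    trace_count += 1
    return x * factor


x = jnp.ones(4)
scale(x, 2)             # first call: trace + compile
scale(x + 1, 2)         # same shape, dtype and static value: cached
scale(x, 3)             # new value of the static argument: recompiles
scale(jnp.ones(8), 3)   # new input shape: recompiles
print(trace_count)      # 3
```

In practice this means keeping input shapes fixed (e.g. by padding) is what avoids repeated compilation, not just leaving static arguments alone.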

Lastly, by default JAX preallocates 75% of the total GPU memory when the first JAX operation runs. To prevent this, set the following environment variable before starting Python:

export XLA_PYTHON_CLIENT_PREALLOCATE=false
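If editing the shell environment is inconvenient, the same variable can be set from Python, as long as this happens before JAX initialises its GPU backend (setting it before the first import of jax is the safe choice). XLA_PYTHON_CLIENT_MEM_FRACTION, shown commented out, is the documented alternative that keeps preallocation but shrinks it. A minimal sketch:

```python
import os

# Read when JAX initialises its GPU backend, so set it before importing jax.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# Documented alternative: keep preallocation but cap it at a fraction of GPU memory.
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50"

import jax.numpy as jnp  # noqa: E402  (import deliberately after the env setup)

x = jnp.ones((1024, 1024))
# On a GPU, memory for x is now allocated on demand rather than preallocated.
total = x.sum()
print(total)
```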

Hope this helps!

CHYjeremy commented on July 28, 2024

Thank you very much for the amazing answer.

But let me follow up on your suggestion to set that environment variable: to what extent does preventing JAX from preallocating GPU memory hurt the performance (speed) of the algorithm?

Thanks.

m43 commented on July 28, 2024

Allocating memory on the fly may not have a significant impact on speed compared with preallocating, though the documentation doesn't state this explicitly. Disabling preallocation can reduce overall memory usage, especially if JAX would not have used all of the preallocated memory anyway. This matters most when TensorFlow or PyTorch models run concurrently on the same GPU. However, without preallocation there is a greater risk of memory fragmentation, which can cause OOM errors in JAX programs that need most of the GPU memory. You can find more details on this topic here. If you're not running into OOM issues, the default memory settings are probably well suited to your needs.
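To judge whether the defaults are fine for a given workload, one option is to look at the allocator statistics JAX exposes per device through Device.memory_stats(). This is a minimal sketch, assuming the GPU backend reports the bytes_in_use, peak_bytes_in_use and bytes_limit keys; other backends (such as CPU) may report nothing, which the code handles. report_memory is a hypothetical helper name.

```python
import jax
import jax.numpy as jnp


def report_memory(device=None):
    """Print allocator statistics for `device` if the backend provides them."""
    if device is None:
        device = jax.devices()[0]
    try:
        stats = device.memory_stats()
    except Exception:  # some backends/versions do not implement memory_stats
        stats = None
    if not stats:
        print(f"{device}: no allocator statistics available")
        return None
    mib = 2**20
    print(
        f"{device}: in use {stats.get('bytes_in_use', 0) / mib:.1f} MiB, "
        f"peak {stats.get('peak_bytes_in_use', 0) / mib:.1f} MiB, "
        f"limit {stats.get('bytes_limit', 0) / mib:.1f} MiB"
    )
    return stats


x = jnp.ones((2048, 2048))
x.block_until_ready()  # make sure the allocation has actually happened
stats = report_memory()
```

Calling it after a representative workload gives a rough sense of headroom: if the peak stays well below the limit, disabling preallocation or lowering XLA_PYTHON_CLIENT_MEM_FRACTION is less likely to hit the fragmentation-related OOMs described above.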
