
Type precision issue in BoxOSQP (jaxopt, open, 10 comments)

jewillco commented on July 26, 2024
Type precision issue in BoxOSQP

from jaxopt.

Comments (10)

Algue-Rythme commented on July 26, 2024

Can you try converting the params_ineq=(-1, 1) tuple to float32 explicitly instead of relying on the default types? Let me know how it goes.


jewillco commented on July 26, 2024

I tried float32 and float64 there (using the jnp.float32(1) syntax):

float32: TypeError: true_fun and false_fun output must have identical types, got ('DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[], weak_type=True)', ('ShapedArray(float32[30])', ('ShapedArray(float32[30,30])', 'ShapedArray(float32[1,30])', 'ShapedArray(float64[], weak_type=True)', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[], weak_type=True)'), None)).

float64: TypeError: body_fun output and input must have identical types, got ('DIFFERENT ShapedArray(float64[30]) vs. ShapedArray(float32[30])', 'ShapedArray(float64[30])', 'ShapedArray(float64[])', 'ShapedArray(float64[30])', 'ShapedArray(int64[], weak_type=True)').
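The wording of both messages matches JAX's type checks in lax.cond ("true_fun and false_fun") and lax.while_loop ("body_fun"): both branches of a cond, and the carry of a while loop, must keep exactly the same dtype, so a float32 value meeting a float64 one is rejected rather than promoted. A minimal sketch that reproduces both kinds of TypeError outside BoxOSQP, assuming a recent JAX with 64-bit mode switched on; the helpers cond_mixed, loop_mixed and raises_type_error are hypothetical and not part of jaxopt:

```python
import jax
import jax.numpy as jnp
from jax import lax

# Assumption: 64-bit mode is on; without it, float64 requests become float32.
jax.config.update("jax_enable_x64", True)


def cond_mixed(pred, x):
    # The two branches return the same value in different dtypes.
    return lax.cond(
        pred,
        lambda v: v.astype(jnp.float32),
        lambda v: v.astype(jnp.float64),
        x,
    )


def loop_mixed(x0):
    # The body upcasts a float32 carry to float64, so the carry type changes.
    def body(carry):
        i, x = carry
        return i + 1, x.astype(jnp.float64) * 0.5

    return lax.while_loop(lambda carry: carry[0] < 3, body, (0, x0))


def raises_type_error(fn, *args):
    """Trace fn under jit and report whether JAX rejects it with a TypeError."""
    try:
        jax.jit(fn)(*args)
    except TypeError as err:
        print(f"{fn.__name__}: TypeError: {str(err)[:100]}")
        return True
    return False


x32 = jnp.ones((30,), dtype=jnp.float32)
cond_fails = raises_type_error(cond_mixed, True, jnp.float32(1.0))
loop_fails = raises_type_error(loop_mixed, x32)
```

The exact message text differs between JAX versions; what stays the same is that mismatched dtypes inside traced control flow raise a TypeError at trace time.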


Algue-Rythme commented on July 26, 2024

Would you mind sharing your minimal (not) working example in Colab? Thanks in advance.


jewillco commented on July 26, 2024
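import jax.numpy as jnp
import jaxopt

# BoxOSQP solves  min 0.5 x^T Q x + c^T x  subject to  l <= A x <= u.
optimizer = jaxopt.BoxOSQP()
optimizer.run(
    params_obj=(
        jnp.eye(30, dtype=jnp.float32),      # Q
        jnp.ones((30,), dtype=jnp.float32),  # c
    ),
    params_eq=jnp.ones((1, 30), dtype=jnp.float32),  # A
    params_ineq=(jnp.float32(-1), jnp.float32(1)),   # (l, u)
)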


Algue-Rythme commented on July 26, 2024

You did not give me a Colab link, so I copy-pasted the code into Colab, added a few imports, and it works there: no errors. Which versions of jax, jaxopt, and Python are you using? Are you using a GPU?
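Since the answer depends on versions and hardware, a small sketch that gathers those details in one place may help when filing or answering such a report. It uses only importlib.metadata and the public jax.default_backend and jax.devices functions; the function name environment_report is hypothetical:

```python
import platform
from importlib import metadata

import jax


def environment_report():
    """Collect the versions and hardware details a bug report usually needs."""
    report = {"python": platform.python_version()}
    for package in ("jax", "jaxlib", "jaxopt"):
        try:
            report[package] = metadata.version(package)
        except metadata.PackageNotFoundError:
            report[package] = "not installed"
    report["backend"] = jax.default_backend()  # e.g. "cpu", "gpu" or "tpu"
    report["devices"] = [device.device_kind for device in jax.devices()]
    return report


for key, value in environment_report().items():
    print(f"{key}: {value}")
```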


jewillco commented on July 26, 2024

I am using a TPU, and my Colab has a large number of other things in it, so I can't share it. Did you turn on float64 in JAX? That is the one thing that might differ from the snippet I posted.
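A quick way to check whether 64-bit mode is active, as a sketch: when the jax_enable_x64 flag is on (set with jax.config.update("jax_enable_x64", True) or the JAX_ENABLE_X64 environment variable), values created without an explicit dtype default to float64; otherwise JAX falls back to float32. The helper x64_enabled is hypothetical:

```python
import jax.numpy as jnp


def x64_enabled() -> bool:
    """Report whether JAX's 64-bit mode (jax_enable_x64) is active.

    With the flag off, JAX canonicalizes Python floats to float32; with it on,
    they become float64.
    """
    return bool(jnp.asarray(1.0).dtype == jnp.float64)


print("x64 enabled:", x64_enabled())
# Arrays allocated without an explicit dtype follow the same default.
print("default float dtype:", jnp.zeros(3).dtype)
```

This default matters here because any array a library allocates without an explicit dtype follows it, while arrays passed in as float32 stay float32.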


Algue-Rythme commented on July 26, 2024

I did not turn on float64 in my initial test; check for yourself! I tested in float32 on CPU, GPU, and TPU in Colab, and it works.

With float64 enabled, on a TPU, I get: XlaRuntimeError: INVALID_ARGUMENT: 64-bit data types are not yet supported on the TPU driver API. Convert inputs to float32/int32_t before using. This is the expected behavior for TPUs anyway, because they are not intended for float64 arithmetic. Since you don't encounter this error, I wonder whether you enabled the TPU in Colab with jax.tools.colab_tpu.setup_tpu().

The error you gave me arises when mixing float32 objects (in your call) with float64 objects that BoxOSQP allocates by default, on CPU (for example, after failing to enable the TPU). This is also expected behavior, because JAX's policy is to avoid aggressive type promotion. However, if you force everything to be float64, it works! Look here
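A minimal sketch of the "force everything to float64" workaround, assuming 64-bit mode is enabled: cast every floating-point leaf of the solver inputs with jax.tree_util.tree_map before calling run. The helper cast_tree is hypothetical and not a jaxopt API; the input shapes mirror the snippet posted earlier in the thread:

```python
import jax
import jax.numpy as jnp

# Assumption: 64-bit mode is on, as in the failing setup.
jax.config.update("jax_enable_x64", True)


def cast_tree(tree, dtype):
    """Cast every floating-point leaf of a pytree to `dtype`, leaving others alone."""
    def cast(leaf):
        leaf = jnp.asarray(leaf)
        if jnp.issubdtype(leaf.dtype, jnp.floating):
            return leaf.astype(dtype)
        return leaf

    return jax.tree_util.tree_map(cast, tree)


params = dict(
    params_obj=(jnp.eye(30, dtype=jnp.float32), jnp.ones((30,), dtype=jnp.float32)),
    params_eq=jnp.ones((1, 30), dtype=jnp.float32),
    params_ineq=(jnp.float32(-1), jnp.float32(1)),
)
params64 = cast_tree(params, jnp.float64)
print({k: [x.dtype for x in jax.tree_util.tree_leaves(v)] for k, v in params64.items()})
```

The cast dictionary would then be passed as keyword arguments, e.g. optimizer.run(**params64) with optimizer = jaxopt.BoxOSQP(), so that the user inputs and the solver's float64 defaults agree.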

> my Colab has a large number of other things in it so I can't share it

Well, I am not asking for your whole work, just a minimal working example that reproduces the issue.

> That is the one thing that might be different from the snippet I posted.

This is what I meant when I said "share a Colab link": it is not easy to infer what you did in your environment without details, and the code you gave me was clearly insufficient to understand what is really going on. As you can see, on Colab I can trigger different types of errors by juggling types, environments, and initialization at startup, and I consider none of these behaviors a bug.


jakevdp commented on July 26, 2024

Hi, JAX developer here. It looks like you're using Colab TPU; as of this writing (October 2023), Colab only provides very old TPU hardware, which is only compatible with a very old JAX version. I would not recommend running JAX on Colab TPU until this changes (note that Colab CPU and GPU are fine). I believe this issue is fixed on more modern TPU architectures.

If you'd like to use modern TPUs in a free public notebook, I'd suggest taking a look at Kaggle, which provides more up-to-date TPU runtimes.


Algue-Rythme commented on July 26, 2024

Thanks for the heads up.

@jewillco: could you clarify your intent with this code? If my understanding is correct, you need:

  • a TPU for performance
  • float64 precision enabled by default for some reason (do you expect these computations to run on TPU?)
  • but you want the BoxOSQP solver to run in float32 by default anyway?


jewillco commented on July 26, 2024

I want #1 and #2, at least with the option to run other parts of my code in float64 on the TPU (which is semi-supported). I would like BoxOSQP to run in either float32 or float64 depending on what inputs I give it. It turns out that it does work with all float64 inputs to the solver; it still produces NaNs on my problem but that's a different issue.
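As a caller-side sketch of "float32 or float64 depending on the inputs", one could promote all inputs to their common floating-point dtype with jnp.result_type before calling the solver. Both helpers below (common_float_dtype, promote_inputs) are hypothetical. Note the caveat from earlier in the thread: this only makes the inputs consistent with each other; BoxOSQP's internally allocated defaults still follow JAX's default dtype, which is float64 whenever x64 mode is on, so under that flag only all-float64 inputs are known to work.

```python
import jax
import jax.numpy as jnp

# Assumption: 64-bit mode is on, so float64 inputs keep their precision.
jax.config.update("jax_enable_x64", True)


def common_float_dtype(*trees):
    """Promote the dtypes of all floating-point leaves with JAX's rules."""
    dtypes = [
        jnp.asarray(leaf).dtype
        for tree in trees
        for leaf in jax.tree_util.tree_leaves(tree)
        if jnp.issubdtype(jnp.asarray(leaf).dtype, jnp.floating)
    ]
    return jnp.result_type(*dtypes)


def promote_inputs(params_obj, params_eq, params_ineq):
    """Cast all solver inputs to their common floating-point dtype."""
    dtype = common_float_dtype(params_obj, params_eq, params_ineq)

    def cast(tree):
        return jax.tree_util.tree_map(lambda x: jnp.asarray(x, dtype=dtype), tree)

    return dtype, cast(params_obj), cast(params_eq), cast(params_ineq)


# Mixed inputs: one float64 leaf is enough to promote everything to float64.
dtype, obj, eq, ineq = promote_inputs(
    (jnp.eye(3, dtype=jnp.float64), jnp.ones(3, dtype=jnp.float32)),
    jnp.ones((1, 3), dtype=jnp.float32),
    (jnp.float32(-1), jnp.float32(1)),
)
print(dtype)
```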
