
Equinox support? (lorax issue, closed)

patrick-kidger commented on August 20, 2024
Equinox support?


Comments (3)

davisyoshida commented on August 20, 2024

@patrick-kidger It's supposed to work out of the box, but I actually found a bug in my general-case handling of dot_general, so I fixed that. The only remaining breakage, which I haven't decided how to fix yet, is the formatting of pytree paths to parameters when constructing a LoRA spec. You can just build the spec manually, though, and then the following works:

import equinox as eqx
import jax
import jax.numpy as jnp
import optax

import lorax


class Linear(eqx.Module):
    weight: jax.Array
    bias: jax.Array

    def __init__(self, in_size, out_size, key):
        wkey, bkey = jax.random.split(key)
        self.weight = jax.random.normal(wkey, (out_size, in_size))
        self.bias = jax.random.normal(bkey, (out_size,))

    def __call__(self, x):
        return self.weight @ x + self.bias

@lorax.lora
def loss_fn(model, x, y):
    # With @lorax.lora, `model` is passed as a (frozen, tunable) tuple (see split_params_loss)
    pred_y = jax.vmap(model)(x)
    return jnp.mean((y - pred_y) ** 2)

@jax.jit
@jax.value_and_grad
def split_params_loss(tune_params, freeze_params, x, y):
    return loss_fn((freeze_params, tune_params), x, y)

batch_size, in_size, out_size = 32, 128, 64
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
model = Linear(in_size, out_size, key=k1)
x = jax.random.normal(k2, (batch_size, in_size))
y = jax.random.normal(k3, (batch_size, out_size))


# LoRA rank 16 for 2-D weight matrices; train 1-D parameters (the bias) fully
lora_spec = jax.tree_util.tree_map(lambda p: 16 if p.ndim > 1 else lorax.LORA_FULL, model)
freeze_params, tune_params = lorax.init_lora(
    param_tree=model, 
    spec=lora_spec,
    rng=k4
)

opt = optax.adam(learning_rate=1e-1)
opt_state = opt.init(tune_params)

for i in range(100):
    loss, grad = split_params_loss(tune_params, freeze_params, x, y)
    updates, opt_state = opt.update(grad, opt_state)
    tune_params = optax.apply_updates(tune_params, updates)
    print(f'{i}: {loss:.3e}')

Well, it works locally; I'll be pushing the fixed version shortly.


patrick-kidger commented on August 20, 2024

Oh nice -- that's really cool. Thank you for looking into this.

For specifying paths to parameters: I've encountered this issue before. IMO the nicest way to do it is to use a lambda function, e.g. lambda model: model.layers[-1].weight. (This is exactly what's done in equinox.tree_at, e.g. new_model = tree_at(lambda m: m.layers[-1].weight, model, new_weight) modifies the weight at this position. Of course this idea can be used independently of Equinox.)
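The lambda-selector idea can be sketched in plain JAX. The helper `spec_from_selector` below is hypothetical (it is not part of lorax or Equinox): it flattens the parameter tree, calls the selector to pick out leaves, and marks the picked leaves in a spec tree. `FREEZE` is an assumed stand-in for a "keep frozen" sentinel, and the nested NamedTuples stand in for an Equinox model so the sketch runs without Equinox. Equinox's `tree_at` is more careful; this sketch simply matches by object identity, so it would also mark any other position that holds the very same array object.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

FREEZE = 0  # hypothetical stand-in for a "keep frozen" sentinel


class Linear(NamedTuple):
    weight: jax.Array
    bias: jax.Array


class MLP(NamedTuple):
    layers: list


def spec_from_selector(tree, where, rank):
    """Return a spec pytree: `rank` at leaves picked by `where`, FREEZE elsewhere."""
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    picked = where(tree)
    if not isinstance(picked, (tuple, list)):
        picked = (picked,)
    # Match by identity: the selector returns the very leaf objects stored in the tree.
    spec_leaves = [
        rank if any(leaf is p for p in picked) else FREEZE for leaf in leaves
    ]
    return jax.tree_util.tree_unflatten(treedef, spec_leaves)


model = MLP(layers=[
    Linear(jnp.ones((8, 4)), jnp.zeros(8)),
    Linear(jnp.ones((2, 8)), jnp.zeros(2)),
])
spec = spec_from_selector(model, lambda m: m.layers[-1].weight, rank=16)
print(spec)
```

The selector reads exactly like the attribute access a user would write by hand, which is the appeal of this style over string paths.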

There is also the new keypath functionality, but it doesn't play super well with custom pytree nodes -- you have to say things like "the ith leaf", which isn't really that nice.
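To illustrate the keypath point, the sketch below registers the same toy node twice: once with `jax.tree_util.register_pytree_node` (no key information) and once with `jax.tree_util.register_pytree_with_keys`. The class names are hypothetical. Without keys, JAX can only describe children positionally via `FlattenedIndexKey`; with keys, paths read as attribute names.

```python
import jax.numpy as jnp
from jax.tree_util import (
    FlattenedIndexKey,
    GetAttrKey,
    keystr,
    register_pytree_node,
    register_pytree_with_keys,
    tree_flatten_with_path,
)


class PlainLinear:
    """Custom node registered WITHOUT key information."""

    def __init__(self, weight, bias):
        self.weight, self.bias = weight, bias


register_pytree_node(
    PlainLinear,
    lambda m: ((m.weight, m.bias), None),          # flatten: (children, aux)
    lambda _, children: PlainLinear(*children),    # unflatten
)


class KeyedLinear:
    """Same node, but registered WITH named keys."""

    def __init__(self, weight, bias):
        self.weight, self.bias = weight, bias


register_pytree_with_keys(
    KeyedLinear,
    lambda m: (((GetAttrKey("weight"), m.weight), (GetAttrKey("bias"), m.bias)), None),
    lambda _, children: KeyedLinear(*children),
)

w, b = jnp.ones((2, 3)), jnp.zeros(2)
plain_paths = [p for p, _ in tree_flatten_with_path(PlainLinear(w, b))[0]]
keyed_paths = [keystr(p) for p, _ in tree_flatten_with_path(KeyedLinear(w, b))[0]]
print([keystr(p) for p in plain_paths])  # positional "ith leaf" style keys
print(keyed_paths)                       # ['.weight', '.bias']
```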


davisyoshida commented on August 20, 2024

For now I just have it call str on unknown node types (I'm using jax.tree_util.tree_map_with_path), but I agree it's not super pretty, so I may do something else in the future.
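A path-based spec builder in that spirit might look like the following sketch. It uses the real `jax.tree_util.tree_map_with_path`, but `path_to_str`, `make_spec`, the `"kernel"` naming rule and the `FREEZE` sentinel are all hypothetical and not lorax's actual code. Known key types are rendered by their name or index, and anything else falls back to `str`.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import DictKey, GetAttrKey, SequenceKey, tree_map_with_path

FREEZE = 0  # hypothetical stand-in for a "keep frozen" sentinel


def path_to_str(path):
    """Render a key path as 'a/b/0', falling back to str() for unknown key types."""
    parts = []
    for k in path:
        if isinstance(k, DictKey):
            parts.append(str(k.key))
        elif isinstance(k, GetAttrKey):
            parts.append(k.name)
        elif isinstance(k, SequenceKey):
            parts.append(str(k.idx))
        else:
            parts.append(str(k))  # e.g. FlattenedIndexKey from a key-less custom node
    return "/".join(parts)


def make_spec(params, rank=16):
    # Hypothetical rule: LoRA on 2-D leaves whose path ends in "kernel"; freeze the rest.
    return tree_map_with_path(
        lambda path, x: rank if path_to_str(path).endswith("kernel") and x.ndim > 1 else FREEZE,
        params,
    )


params = {"dense": [{"kernel": jnp.ones((4, 3)), "bias": jnp.zeros(3)}]}
spec = make_spec(params)
print(spec)
```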

I fixed the dot bug and the example above should work now, let me know if there are any problems with it.
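For context on why dot_general handling matters: a LoRA weight W + scale·(B A) can be applied to an input without materializing B A, using (W + scale·B A) x = W x + scale·B (A x). The sketch below shows only the matrix-vector case that `Linear.__call__` above hits, written with `jax.lax.dot_general`; the function name `lora_matvec` and the `scale` parameter are illustrative assumptions, and this is not lorax's actual implementation, which has to handle general dimension numbers.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Contract axis 1 of the left operand with axis 0 of the vector; no batch dims.
MATVEC_DN = (((1,), (0,)), ((), ()))


def lora_matvec(w, a, b, x, scale=1.0):
    """Compute (w + scale * b @ a) @ x without materializing b @ a.

    w: (out, in) frozen weight, b: (out, r), a: (r, in), x: (in,).
    """
    base = lax.dot_general(w, x, MATVEC_DN)                                  # (out,)
    low_rank = lax.dot_general(b, lax.dot_general(a, x, MATVEC_DN), MATVEC_DN)  # (out,)
    return base + scale * low_rank


kw, ka, kb, kx = jax.random.split(jax.random.PRNGKey(0), 4)
out_size, in_size, r = 64, 128, 16
w = jax.random.normal(kw, (out_size, in_size))
a = jax.random.normal(ka, (r, in_size))
b = jax.random.normal(kb, (out_size, r))
x = jax.random.normal(kx, (in_size,))

dense = (w + 0.5 * b @ a) @ x
print(jnp.allclose(lora_matvec(w, a, b, x, scale=0.5), dense, rtol=1e-4, atol=1e-3))
```

The factored form costs O(r·(in + out)) for the low-rank part instead of O(in·out), which is the point of keeping A and B separate.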
