
Predicting LoRA weights (lorax issue, closed)

PuR3Luck commented on August 20, 2024
Predicting LoRA weights

from lorax.

Comments (34)

davisyoshida commented on August 20, 2024

Here's some example code showing how I'd handle it. The rough outline of the strategy is:

  1. Initialize weights for both your model and the hypernetwork (using nn.compact lets the hypernetwork do input-shape-dependent initialization)
  2. Delete the unneeded weights for the outer model (if this is costing too much GPU memory, you can do some workarounds to avoid actually initializing these weights that get created then deleted)
  3. Define a joint_call function which re-populates the weights for the model using a combination of weight sharing and calls to the hypernetwork, then calls the model using these parameters. This function will be compatible with JAX transformations like grad, vmap, and jit.
import flax.linen as nn
import lorax
import jax
import jax.numpy as jnp


class MyFlaxNetwork(nn.Module):
    """Example network"""
    def setup(self):
        self.blocks = [{'a': nn.Dense(64), 'b': nn.Dense(64)} for _ in range(5)]
        self.out_proj = nn.Dense(1)

    def __call__(self, x):
        for block in self.blocks:
            x = jax.nn.relu(block['a'](x) + block['b'](x))

        return self.out_proj(x)

class LoraMakerNetwork(nn.Module):
    @nn.compact # have to use this so our parameters can depend on the input shape
    def __call__(self, W, *whatever_other_inputs_you_want):
        M, N = W.shape
        k = 4
        a = jax.random.normal(jax.random.PRNGKey(0), (k, N))
        b = jax.random.normal(jax.random.PRNGKey(1), (M, k))

        some_param = self.param('some_param', lambda rng_key: jnp.ones(()))
        return a, b

def main():
    model = MyFlaxNetwork()
    params = model.init(jax.random.PRNGKey(0), jnp.ones(64))

    for i in range(1, 5):
        # Delete the extra params we don't care about
        # This step might need to change quite a bit depending on
        # what exactly you want to share or keep separate
        del params['params'][f'blocks_{i}_a']['kernel']
        del params['params'][f'blocks_{i}_b']['kernel']

    lora_model = LoraMakerNetwork()
    lora_model_params = lora_model.init(jax.random.PRNGKey(0), params['params']['blocks_0_a']['kernel'])

    @jax.jit
    def joint_call(params, lora_model_params, input_data):
        # Copy the params tree so we can mutate it
        # This doesn't actually copy the data on the GPU, it just copies the pytree at tracing time
        modified_params = jax.tree_util.tree_map(lambda x: x, params)

        # Step 1: Overwrite the original param tree with LoraWeight instances
        for i in range(5):
            for k in ['a', 'b']:
                shared_param_name = f'blocks_0_{k}'
                write_params_name = f'blocks_{i}_{k}'
                w_shared = params['params'][shared_param_name]['kernel']

                # This will do the same thing for every layer, but presumably you'll be passing some other inputs
                a, b = lora_model.apply(lora_model_params, w_shared)
                lora_weight = lorax.LoraWeight(w=w_shared, a=a, b=b)
                modified_params['params'][write_params_name]['kernel'] = lora_weight

        # Step 2: Run model using modified params tree
        wrapped_model = lorax.lora(model.apply)
        return wrapped_model(modified_params, input_data)

    inp = jax.random.normal(jax.random.PRNGKey(0), (64,))

    print(joint_call(params, lora_model_params, inp))

if __name__ == '__main__':
    main()

Feel free to let me know if you have any questions or if this won't work.

PuR3Luck commented on August 20, 2024

I am experimenting with using a single transformer block and swapping in different LoRAs for each application of that block, so I would like to train both neural networks at the same time.

davisyoshida commented on August 20, 2024

@PuR3Luck The LoraWeight class is a pytree, so you can return it from jitted functions. In this case even that shouldn't be necessary; I think the easiest way would be something like this:

import lorax

def the_lora_weights_model(params, inputs):
    # do whatever you're doing here to produce A (shape (k, N)) and B (shape (M, k));
    # placeholder: read them straight out of params
    A, B = params['A'], params['B']
    return A, B

def original_model(W, inputs):
    # actual model logic; placeholder: a single matmul
    output = inputs @ W
    return output

lora_model = lorax.lora(original_model) # this makes it so that the model can handle LoraWeight arguments

def combined_model(W, inner_params, inputs):
    A, B = the_lora_weights_model(inner_params, inputs) # invoke this however it's supposed to be called

    lora_weight = lorax.LoraWeight(w=W, a=A, b=B)
    output = lora_model(lora_weight, inputs) # invoke the model like normal except pass the lora weight
    return output

If you give me some more details or code for your setup I can give you something more customized.
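
To make the "LoraWeight is a pytree" point concrete without depending on lorax itself, here is a minimal pure-JAX sketch. `ToyLoraWeight` is a hypothetical stand-in, not lorax's class; it assumes the same factor layout as the example code earlier in the thread (W of shape (M, N), a of shape (k, N), b of shape (M, k), effective weight W + b @ a). Because it is registered as a pytree it can be returned from a jitted function, and `lora_matmul` applies it without materializing the M x N update.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToyLoraWeight:
    """Hypothetical stand-in for lorax.LoraWeight: effective weight is w + b @ a."""

    def __init__(self, w, a, b):
        self.w, self.a, self.b = w, a, b  # w: (M, N), a: (k, N), b: (M, k)

    def tree_flatten(self):
        # All three arrays are children, so jit/grad/vmap can see through the object
        return (self.w, self.a, self.b), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    def materialize(self):
        return self.w + self.b @ self.a


@jax.jit
def make_lora_weight(w, a, b):
    # Returning a registered pytree from a jitted function is allowed
    return ToyLoraWeight(w, a, b)


def lora_matmul(x, lw):
    # x @ (W + B A) computed as x @ W + (x @ B) @ A, never forming the (M, N) update
    return x @ lw.w + (x @ lw.b) @ lw.a


M, N, k = 8, 6, 2
kw, ka, kb, kx = jax.random.split(jax.random.PRNGKey(0), 4)
lw = make_lora_weight(
    jax.random.normal(kw, (M, N)),
    jax.random.normal(ka, (k, N)),
    jax.random.normal(kb, (M, k)),
)
x = jax.random.normal(kx, (3, M))
y = lora_matmul(x, lw)
```

Computing `(x @ b) @ a` costs roughly k(M + N) multiply-adds per input row instead of M·N for the merged update, which is why low-rank updates are usually applied in this factored form.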

PuR3Luck commented on August 20, 2024

Thanks for the help!

What I am planning to do is replace all of the transformer blocks of an autoregressive transformer written in Flax with a single transformer block, which is adapted with LoRAs parameterised by a neural hypernetwork. The hypernetwork is given a signal for the depth of the layer and possibly the most recently generated token. This may let me trade quality for throughput at inference time.

I think this might be viable because One Wide Feedforward is All You Need (https://arxiv.org/pdf/2309.01826.pdf) showed that replacing all FFNs with one shared FFN can work. I hypothesized that combining this approach (a shared FFN and maybe shared self-attention) with LoRAs might regain the lost accuracy while potentially overfitting less due to the lower parameter count.

It might even be faster than larger models, since the smaller parameter count would make it more compute-bound and less memory-bound.
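
A rough pure-JAX sketch of that architecture, under assumptions that are not in the post: a single shared weight matrix stands in for the shared transformer block, the hypernetwork is a learned depth embedding followed by two linear maps that produce the LoRA factors, and each layer is a residual GELU layer. All names and sizes are hypothetical. Passing a smaller `depth` at inference is where a quality/throughput trade-off would come from.

```python
import jax
import jax.numpy as jnp

D, K, MAX_DEPTH, EMB = 16, 2, 4, 8  # hypothetical sizes: width, LoRA rank, layers, embedding


def init_params(key):
    k1, k2, k3, k4 = jax.random.split(key, 4)
    shared = {"W": jax.random.normal(k1, (D, D)) / jnp.sqrt(D)}  # one block reused at every depth
    hyper = {
        "depth_emb": jax.random.normal(k2, (MAX_DEPTH, EMB)),  # depth signal
        "to_a": jax.random.normal(k3, (EMB, K * D)) * 0.01,
        "to_b": jax.random.normal(k4, (EMB, D * K)) * 0.01,
    }
    return shared, hyper


def hypernetwork(hyper, depth):
    e = hyper["depth_emb"][depth]
    a = (e @ hyper["to_a"]).reshape(K, D)  # (k, N)
    b = (e @ hyper["to_b"]).reshape(D, K)  # (M, k)
    return a, b


def forward(shared, hyper, x, depth=MAX_DEPTH):
    for i in range(depth):
        a, b = hypernetwork(hyper, i)
        h = x @ shared["W"] + (x @ b) @ a  # shared weight plus a depth-specific low-rank update
        x = x + jax.nn.gelu(h)             # residual connection
    return x


shared, hyper = init_params(jax.random.PRNGKey(0))
x = jnp.ones((3, D))
out_full = forward(shared, hyper, x)
out_short = forward(shared, hyper, x, depth=2)  # cheaper, possibly lower-quality pass


def loss(s, h):
    return jnp.mean(forward(s, h, x) ** 2)


# One loss trains both the shared block and the hypernetwork
g_shared, g_hyper = jax.grad(loss, argnums=(0, 1))(shared, hyper)
```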

PuR3Luck commented on August 20, 2024

On a separate note, if I want to train both the transformer block and the hypernetwork at the same time, I shouldn't use the wrap_optimizer() function, right?

davisyoshida commented on August 20, 2024

I tried the "shared weights but different LoRAs" approach a few months ago. I found that dense weights were more parameter-efficient (I had to raise the LoRA rank so high that it wasn't actually reducing the number of parameters). Your hypernetwork idea is interesting; maybe it will help.
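
The parameter-efficiency observation can be checked with some arithmetic. For L layers that each have an M x N weight, separate dense weights cost L*M*N parameters, while one shared weight plus a rank-k LoRA per layer costs M*N + L*k*(M + N). The sketch below uses hypothetical helper names and ignores biases and any hypernetwork parameters; it computes both counts and the largest rank that still saves parameters. For 12 layers of 512 x 512, that rank is 234, so any rank above it costs more than dense weights.

```python
def param_counts(M, N, L, k):
    """Return (dense, shared_plus_lora) parameter counts for L layers of M x N weights."""
    dense = L * M * N                           # one independent matrix per layer
    shared_plus_lora = M * N + L * k * (M + N)  # one shared matrix + per-layer b (M, k) and a (k, N)
    return dense, shared_plus_lora


def break_even_rank(M, N, L):
    """Largest rank k for which shared + LoRA uses strictly fewer parameters than dense."""
    # Need L*k*(M + N) < (L - 1)*M*N
    return ((L - 1) * M * N - 1) // (L * (M + N))


dense, shared = param_counts(512, 512, 12, 8)
print(dense, shared)                  # 3145728 360448
print(break_even_rank(512, 512, 12))  # 234
```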

I should not use the wrap_optimizer() function right?

In your case it sounds like all the parameters are trainable, so it should be unnecessary.

PuR3Luck commented on August 20, 2024

In your case, did you pretrain from scratch or factorise LoRA matrices out of a pretrained model?

davisyoshida commented on August 20, 2024

I trained from scratch, doing language modeling on Penn Treebank. I definitely can't say I did a thorough enough experiment to be sure that it doesn't work, since I only spent a couple of hours on it.

PuR3Luck commented on August 20, 2024

Would it still be possible to do this if I want to write all my networks in Flax (I see you are manually passing in all the parameters)? Also, do you see an elegant way to use the pytree so that my hypernetwork knows the dimensions of its outputs at initialisation, so that I can easily pass its outputs (the As and Bs for all parts of the main network) into LoraWeight?

PuR3Luck commented on August 20, 2024

Thanks for the prompt reply!

Do you think there is a way to wrap the main function in a Flax module? Wrapping it in a Flax module would make testing and initialising the model more convenient. My only worry is that inside a Flax module I don't call init myself, so I am not sure how to del the parameters.

PuR3Luck commented on August 20, 2024

I am leaving a simplified version of my model code below (I am not fully certain how to implement the LoRATransformer class without the init method):

import flax.linen as nn

class TransformerBlock(nn.Module):
  d_model: int
  num_heads: int

  @nn.compact
  def __call__(self, x, training: bool = False):
    attn = nn.SelfAttention(num_heads=self.num_heads, name="self_attn")(x)
    x = nn.LayerNorm()(x + attn)
    # Feedforward sub-block; nn.gelu is the GELU activation function
    h = nn.Dense(features=4 * self.d_model, name="ffn_1")(x)
    h = nn.gelu(h)
    h = nn.Dense(features=self.d_model, name="ffn_2")(h)
    h = nn.Dropout(rate=0.1, deterministic=not training)(h)
    # Residual connection around the feedforward sub-block
    x = nn.LayerNorm()(x + h)
    return x

class Transformer(nn.Module):
  depth: int
  d_model: int
  num_heads: int

  def setup(self):
    self.blocks = [TransformerBlock(d_model=self.d_model, num_heads=self.num_heads) for _ in range(self.depth)]
    self.output = nn.Dense(features=1)

  def __call__(self, x, training: bool = False):
    for block in self.blocks:
      x = block(x, training=training)
    x = self.output(x)
    return x

class LoRATransformer(nn.Module):
  depth: int
  lora_rank: int
  d_model: int
  num_heads: int

  def setup(self):
    self.network = Transformer(depth=self.depth, d_model=self.d_model, num_heads=self.num_heads)
    # LoRA_Hypernetwork is defined elsewhere (not shown)
    self.lora_hypernetwork = LoRA_Hypernetwork(lora_rank=self.lora_rank)
    # Unfinished: replace the kernels of blocks 1..depth-1 with LoraWeights
    # built from the shared block and the hypernetwork's outputs

  def __call__(self, x, training: bool = False):
    raise NotImplementedError

davisyoshida commented on August 20, 2024

So I think you probably don't want the LoRATransformer to be a Module, since you want to be able to manipulate the parameters of the Transformer module. Inside the context of a flax module, the parameters are hidden from you. That's why I implemented the joint_call function.
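
Keeping the parameters outside the module also helps with the earlier question about the hypernetwork knowing its output dimensions: the shapes can be read off the parameter tree before the hypernetwork is built. A minimal sketch, assuming Flax-style nested dicts whose weight matrices sit under a 'kernel' key and the (a, b) layout from the example code (a: (k, N), b: (M, k)); `lora_factor_shapes` is a hypothetical helper, not part of lorax.

```python
import jax
import jax.numpy as jnp


def lora_factor_shapes(params, rank):
    """Map the path of every 2-D 'kernel' leaf to the (a_shape, b_shape) a rank-`rank` LoRA needs."""
    shapes = {}
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(params)
    for path, leaf in leaves_with_paths:
        # Dict entries in a path are DictKey objects exposing the key as `.key`
        if getattr(path[-1], "key", None) == "kernel" and leaf.ndim == 2:
            M, N = leaf.shape
            name = "/".join(str(getattr(e, "key", getattr(e, "idx", e))) for e in path[:-1])
            shapes[name] = ((rank, N), (M, rank))
    return shapes


params = {
    "params": {
        "blocks_0_a": {"kernel": jnp.zeros((64, 64)), "bias": jnp.zeros(64)},
        "out_proj": {"kernel": jnp.zeros((64, 1)), "bias": jnp.zeros(1)},
    }
}
shapes = lora_factor_shapes(params, rank=4)
print(shapes)
```

The resulting dict could then be used to size the hypernetwork's output heads, one per entry or one per distinct shape.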

PuR3Luck commented on August 20, 2024

So, am I right to interpret your response as saying that I cannot implement LoRATransformer as a Flax module?

davisyoshida commented on August 20, 2024

You probably can get it working that way, but it will be harder.

davisyoshida commented on August 20, 2024

Oh whoops didn't mean to close.

PuR3Luck commented on August 20, 2024

@davisyoshida Just to confirm: I use the joint_call function to perform the training, right? And applying a loss through joint_call would optimise both the transformer block and the hypernetwork parameters?

PuR3Luck commented on August 20, 2024

Also, do I have to initialize a separate hypernetwork for each matrix that has a unique shape? I presume so.

PuR3Luck commented on August 20, 2024

And also, how should I modify your code if I want to add dropout, especially the line

wrapped_model = lorax.lora(model.apply) 

davisyoshida commented on August 20, 2024

I use the joint call method to perform the training right

Yes. To differentiate your loss function with respect to both sets of parameters, you can either pass them together in a tuple, or use the argnums param to jax.grad.
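
Both options, sketched in pure JAX with a hypothetical `joint_loss` standing in for joint_call plus a loss (a shared weight W with a low-rank update from "hypernetwork" parameters a and b). The two give identical gradients; the tuple form is convenient when a single optax optimizer should update everything.

```python
import jax
import jax.numpy as jnp


def joint_loss(base_params, hyper_params, x, y):
    # Hypothetical stand-in for joint_call followed by a mean squared error
    pred = x @ base_params["W"] + (x @ hyper_params["b"]) @ hyper_params["a"]
    return jnp.mean((pred - y) ** 2)


keys = jax.random.split(jax.random.PRNGKey(0), 5)
base = {"W": jax.random.normal(keys[0], (4, 3))}
hyper = {"a": jax.random.normal(keys[1], (2, 3)), "b": jax.random.normal(keys[2], (4, 2))}
x = jax.random.normal(keys[3], (5, 4))
y = jax.random.normal(keys[4], (5, 3))

# Option 1: differentiate with respect to the first two positional arguments
g_base, g_hyper = jax.grad(joint_loss, argnums=(0, 1))(base, hyper, x, y)


# Option 2: pack both parameter sets into one tuple and differentiate that
def packed_loss(all_params, x, y):
    return joint_loss(all_params[0], all_params[1], x, y)


g_packed_base, g_packed_hyper = jax.grad(packed_loss)((base, hyper), x, y)
```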

Also do I have to initialize a separate hypernetwork for each matrix that has a unique shape? I presume so

I think this depends on your architecture, not the particular implementation. If you have some architecture that can handle multiple output shapes, it should be fine to use that.
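
As one illustration of an architecture that handles several output shapes, here is a sketch (an assumption, not something from lorax or the code in this thread) of a hypernetwork with a shared trunk and one small output head per distinct kernel shape. Layers with the same kernel shape share a head and are distinguished only by the conditioning signal.

```python
import jax
import jax.numpy as jnp

RANK, EMB = 4, 8  # hypothetical LoRA rank and conditioning-signal size


def init_hyper(key, kernel_shapes):
    """Shared trunk plus one (to_a, to_b) head per distinct (M, N) kernel shape."""
    distinct = sorted(set(kernel_shapes))
    keys = jax.random.split(key, 1 + 2 * len(distinct))
    params = {"trunk": jax.random.normal(keys[0], (EMB, EMB)) * 0.1, "heads": {}}
    for i, (M, N) in enumerate(distinct):
        params["heads"][f"{M}x{N}"] = {
            "to_a": jax.random.normal(keys[1 + 2 * i], (EMB, RANK * N)) * 0.01,
            "to_b": jax.random.normal(keys[2 + 2 * i], (EMB, M * RANK)) * 0.01,
        }
    return params


def hyper_apply(params, signal, shape):
    """Produce LoRA factors a (RANK, N) and b (M, RANK) for a kernel of the given shape."""
    M, N = shape
    h = jnp.tanh(signal @ params["trunk"])  # features shared by every shape
    head = params["heads"][f"{M}x{N}"]
    a = (h @ head["to_a"]).reshape(RANK, N)
    b = (h @ head["to_b"]).reshape(M, RANK)
    return a, b


hyper = init_hyper(jax.random.PRNGKey(0), [(64, 64), (64, 256), (256, 64), (64, 64)])
signal = jnp.ones(EMB)  # e.g. an embedding of layer depth
a, b = hyper_apply(hyper, signal, (64, 256))
```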

how should I modify your code if I want to add dropout

Usually dropout goes in the model code, you shouldn't need to add anything extra to the code I supplied. If you have something else in mind give me a few more details on how you want to apply dropout and I can give some advice.

PuR3Luck commented on August 20, 2024

Following the dropout guide in the Flax Linen docs, dropout is enabled by passing a training boolean, together with a PRNG key for dropout, to the apply method. So I was wondering how to modify the model.apply call, since the lorax.lora function wraps model.apply. I suspect I should change the code to wrapped_model = lorax.lora(model), follow the dropout guide, and use wrapped_model.apply with the training flag and PRNG key instead.

https://flax.readthedocs.io/en/latest/guides/training_techniques/dropout.html

davisyoshida commented on August 20, 2024

@PuR3Luck The lora() transform knows about functions, not models, so you won't be able to access attributes like .apply after using it. If it doesn't work to call wrapped_model(modified_params, input_data, training=True), then you could solve that with a partial function application:

from functools import partial

apply_with_dropout = partial(model.apply, training=True)
wrapped_model = lorax.lora(apply_with_dropout)
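
One detail worth adding to the partial approach: the dropout key should be an argument of the jitted function rather than a fixed constant, otherwise every step reuses the same dropout mask. A pure-JAX sketch; `apply` is a hypothetical stand-in that mimics the calling convention of Flax's `model.apply(params, x, training=..., rngs={'dropout': key})`, with inverted dropout (keep probability 0.9) written by hand.

```python
from functools import partial

import jax
import jax.numpy as jnp

KEEP = 0.9  # keep probability, i.e. dropout rate 0.1


def apply(params, x, training=False, rngs=None):
    """Hypothetical stand-in for a Flax model.apply with a dropout layer."""
    h = x @ params["W"]
    if training:
        keep = jax.random.bernoulli(rngs["dropout"], p=KEEP, shape=h.shape)
        h = jnp.where(keep, h / KEEP, 0.0)  # inverted dropout: rescale the kept units
    return h


apply_train = partial(apply, training=True)


@jax.jit
def train_forward(params, x, dropout_key):
    # The key is traced as an argument, so each call can use a different mask
    return apply_train(params, x, rngs={"dropout": dropout_key})


params = {"W": jnp.ones((4, 4))}
x = jnp.ones((2, 4))
key = jax.random.PRNGKey(0)
for step in range(3):
    key, dropout_key = jax.random.split(key)  # fresh key every step
    out = train_forward(params, x, dropout_key)
eval_out = apply(params, x)  # training=False: no dropout
```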

PuR3Luck commented on August 20, 2024

Just checking, my final code should look something like this, right?

import flax.linen as nn
import lorax
import jax
import jax.numpy as jnp
from functools import partial
import optax


class MyFlaxNetwork(nn.Module):
    """Example network"""
    def setup(self):
        self.blocks = [{'a': nn.Dense(64), 'b': nn.Dense(64)} for _ in range(5)]
        self.out_proj = nn.Dense(1)

    def __call__(self, x):
        for block in self.blocks:
            x = jax.nn.relu(block['a'](x) + block['b'](x))

        return self.out_proj(x)

class LoraMakerNetwork(nn.Module):
    @nn.compact # have to use this so our parameters can depend on the input shape
    def __call__(self, W, *whatever_other_inputs_you_want):
        M, N = W.shape
        k = 4
        a = jax.random.normal(jax.random.PRNGKey(0), (k, N))
        b = jax.random.normal(jax.random.PRNGKey(1), (M, k))

        some_param = self.param('some_param', lambda rng_key: jnp.ones(()))
        return a, b

def main():
    model = MyFlaxNetwork()
    params = model.init(jax.random.PRNGKey(0), jnp.ones(64))

    for i in range(1, 5):
        # Delete the extra params we don't care about
        # This step might need to change quite a bit depending on
        # what exactly you want to share or keep separate
        del params['params'][f'blocks_{i}_a']['kernel']
        del params['params'][f'blocks_{i}_b']['kernel']

    lora_model = LoraMakerNetwork()
    lora_model_params = lora_model.init(jax.random.PRNGKey(0), params['params']['blocks_0_a']['kernel'])

    @jax.jit
    def joint_call(params, lora_model_params, input_data):
        # Copy the params tree so we can mutate it
        # This doesn't actually copy the data on the GPU, it just copies the pytree at tracing time
        modified_params = jax.tree_util.tree_map(lambda x: x, params)

        # Step 1: Overwrite the original param tree with LoraWeight instances
        for i in range(5):
            for k in ['a', 'b']:
                shared_param_name = f'blocks_0_{k}'
                write_params_name = f'blocks_{i}_{k}'
                w_shared = params['params'][shared_param_name]['kernel']

                # This will do the same thing for every layer, but presumably you'll be passing some other inputs
                a, b = lora_model.apply(lora_model_params, w_shared)
                lora_weight = lorax.LoraWeight(w=w_shared, a=a, b=b)
                modified_params['params'][write_params_name]['kernel'] = lora_weight
       
        apply_fn = partial(model.apply, training=True, rngs={"dropout": jax.random.PRNGKey(0)})

        # Step 2: Run model using modified params tree
        wrapped_model = lorax.lora(apply_fn)
        return wrapped_model(modified_params, input_data)
    
    all_params = (params_0_a, params_0_b, lora_model_params)

    optimiser = optax.adam(1e-3)

    opt_state = optimiser.init(all_params)
    
    @jax.jit
    def update_fn(params, opt_state, joint_call, x):
        grad_fn = jax.value_and_grad(joint_call)
        loss, grad = grad_fn(params, x)
        updates, new_opt_state = optimiser.update(grad, opt_state, params=params)
        updated_params = optax.apply_updates(params, updates)
        return loss, new_opt_state, updated_params

    for i in range(epochs):
        loss, opt_state, params = update_fn(loss, opt_state, params)


if __name__ == '__main__':
    main()

davisyoshida commented on August 20, 2024

Well this doesn't actually have a loss function or targets to use for joint_call, but other than that the layout seems roughly correct. You don't need to pass joint_call as an argument to the update function, it's already accessible because they're in the same scope.

PuR3Luck commented on August 20, 2024

I thought that functions decorated with jax.jit must not use global variables and functions, though?

davisyoshida commented on August 20, 2024

You only need to avoid that if their values will change. For stuff like a function you're only going to assign once there's no issue.
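
A small pure-JAX sketch of why that distinction matters: `jax.jit` reads a global (or closed-over) value once, while tracing, and the compiled function keeps using that value for later calls with the same input shapes and dtypes. Referring to something assigned once, such as joint_call or a model object, is therefore safe; mutating a global that a jitted function reads is not.

```python
import jax
import jax.numpy as jnp

scale = 2.0


@jax.jit
def scaled(x):
    return scale * x  # `scale` is read from the global scope at trace time


x = jnp.ones(())
before = scaled(x)  # traces and compiles with scale == 2.0

scale = 10.0
after = scaled(x)   # same shape and dtype, so the cached trace is reused: still 2.0

retraced = jax.jit(lambda v: scale * v)(x)  # a new trace sees the new value: 10.0
```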

PuR3Luck commented on August 20, 2024

I didn't know that. Thanks for sharing!

PuR3Luck commented on August 20, 2024

I resolved the missing loss function and joint_call targets, so the code should look like this, right?

import flax.linen as nn
import lorax
import jax
import jax.numpy as jnp
from functools import partial
import optax


class MyFlaxNetwork(nn.Module):
    """Example network"""
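    def setup(self):
        self.blocks = [{'a': nn.Dense(64), 'b': nn.Dense(64)} for _ in range(5)]
        self.dropout = nn.Dropout(rate=0.1)
        self.out_proj = nn.Dense(1)

    # model.apply(..., training=True) only works if __call__ accepts `training`
    def __call__(self, x, training: bool = False):
        for block in self.blocks:
            x = jax.nn.relu(block['a'](x) + block['b'](x))
            x = self.dropout(x, deterministic=not training)

        return self.out_proj(x)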

class LoraMakerNetwork(nn.Module):
    @nn.compact # have to use this so our parameters can depend on the input shape
    def __call__(self, W, *whatever_other_inputs_you_want):
        M, N = W.shape
        k = 4
        a = jax.random.normal(jax.random.PRNGKey(0), (k, N))
        b = jax.random.normal(jax.random.PRNGKey(1), (M, k))

        some_param = self.param('some_param', lambda rng_key: jnp.ones(()))
        return a, b

def main():
    model = MyFlaxNetwork()
    params = model.init(jax.random.PRNGKey(0), jnp.ones(64))

    for i in range(1, 5):
        # Delete the extra params we don't care about
        # This step might need to change quite a bit depending on
        # what exactly you want to share or keep separate
        del params['params'][f'blocks_{i}_a']['kernel']
        del params['params'][f'blocks_{i}_b']['kernel']

    lora_model = LoraMakerNetwork()
    lora_model_params = lora_model.init(jax.random.PRNGKey(0), params['params']['blocks_0_a']['kernel'])

    @jax.jit
    def joint_call(params, lora_model_params, input_data, dropout_key):
        # Copy the params tree so we can mutate it
        # This doesn't actually copy the data on the GPU, it just copies the pytree at tracing time
        modified_params = jax.tree_util.tree_map(lambda x: x, params)

        # Step 1: Overwrite the original param tree with LoraWeight instances
        for i in range(5):
            for k in ['a', 'b']:
                shared_param_name = f'blocks_0_{k}'
                write_params_name = f'blocks_{i}_{k}'
                w_shared = params['params'][shared_param_name]['kernel']

                # This will do the same thing for every layer, but presumably you'll be passing some other inputs
                a, b = lora_model.apply(lora_model_params, w_shared)
                lora_weight = lorax.LoraWeight(w=w_shared, a=a, b=b)
                modified_params['params'][write_params_name]['kernel'] = lora_weight

        # Take the dropout key as an argument so each training step gets a fresh mask
        apply_fn = partial(model.apply, training=True, rngs={"dropout": dropout_key})

        # Step 2: Run model using modified params tree
        wrapped_model = lorax.lora(apply_fn)
        return wrapped_model(modified_params, input_data)

    all_params = (params, lora_model_params)

    optimiser = optax.adam(1e-3)

    opt_state = optimiser.init(all_params)

    def loss_fn(all_params, batch, dropout_key):
        inputs, targets = batch
        output = joint_call(all_params[0], all_params[1], inputs, dropout_key)
        # Calculate loss (mean squared error as an example)
        return jnp.mean((output - targets) ** 2)

    @jax.jit
    def update_fn(all_params, opt_state, batch, dropout_key):
        grad_fn = jax.value_and_grad(loss_fn)
        loss, grad = grad_fn(all_params, batch, dropout_key)
        updates, new_opt_state = optimiser.update(grad, opt_state, params=all_params)
        updated_params = optax.apply_updates(all_params, updates)
        return loss, new_opt_state, updated_params

    # Toy data and settings; replace with your real dataset and schedule
    data = (jax.random.normal(jax.random.PRNGKey(1), (64,)), jnp.ones((1,)))
    epochs = 10
    key = jax.random.PRNGKey(2)

    for i in range(epochs):
        key, dropout_key = jax.random.split(key)
        loss, opt_state, all_params = update_fn(all_params, opt_state, data, dropout_key)
        print(f"Loss: {loss}")
        


if __name__ == '__main__':
    main()

davisyoshida commented on August 20, 2024

Looks like it should work yeah

davisyoshida commented on August 20, 2024

Closing this, feel free to let me know if there are any more problems though.

PuR3Luck commented on August 20, 2024

I am currently playing around with training the Transformer. I am also now wondering: how could I apply this to convolutions, given that the convolution kernel has 3 dimensions?

PuR3Luck commented on August 20, 2024

I am applying a 1D convolution over a sequence of shape (seq_len, dimension) with multiple out_features, so the kernel has 3 dimensions. I saw in the lorax code that LoraWeight has a mechanism to handle convolutions. Could you provide some advice?
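
One way a low-rank update can be applied to a 1D-convolution kernel of shape (window, in_features, out_features), sketched in pure JAX as an assumption (this is not necessarily how lorax's LoraWeight handles convolutions): factor the update into a convolution b of shape (window, in_features, k) followed by a 1x1 projection a of shape (k, out_features). Because convolution is linear in the kernel, this equals convolving with W + einsum('wik,ko->wio', b, a), which the sketch computes both ways.

```python
import jax
import jax.numpy as jnp

WINDOW, C_IN, C_OUT, K, SEQ = 3, 8, 5, 2, 10  # hypothetical sizes


def conv1d(x, kernel):
    # x: (batch, seq_len, C_in), kernel: (window, C_in, C_out) -> (batch, seq_len, C_out)
    return jax.lax.conv_general_dilated(
        x, kernel, window_strides=(1,), padding="SAME",
        dimension_numbers=("NWC", "WIO", "NWC"),
    )


kx, kw, ka, kb = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(kx, (1, SEQ, C_IN))
W = jax.random.normal(kw, (WINDOW, C_IN, C_OUT))
b = jax.random.normal(kb, (WINDOW, C_IN, K))  # low-rank "down" convolution
a = jax.random.normal(ka, (K, C_OUT))         # 1x1 "up" projection

# Factored form: never builds the full-size kernel update
y_lowrank = conv1d(x, W) + conv1d(x, b) @ a

# Merged form: fold the update into a single kernel
delta = jnp.einsum("wik,ko->wio", b, a)
y_merged = conv1d(x, W + delta)
```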

PuR3Luck commented on August 20, 2024

@davisyoshida The network very frequently outputs NaN values for "deep" networks of around 16 layers or more. This is weird to me, as I didn't observe it with normal dense models, but the issue also seems to appear when I scale the model width. Do you have any suggestions for fixing it? In general, training seems extremely sensitive to the hyperparameters. I have tried different weight initialisations, but that does not seem to help much.

davisyoshida commented on August 20, 2024

@PuR3Luck This sounds like something more to do with the hypernetwork part than the Lora part. I don't have any advice for you since I haven't used hypernetworks.

PuR3Luck commented on August 20, 2024

Ok thanks for all the help!
