lorax with haiku (CLOSED)

KiAkize commented on August 20, 2024
lorax with haiku

from lorax.

Comments (4)

davisyoshida commented on August 20, 2024

I actually mostly use haiku myself.

It looks like you're trying to directly call the haiku model. Haiku models are called using model.apply, so you want to do something like this:

# Without lora
out = model.apply(params, rng, inputs)

# With lora: call the wrapped function, not model.apply
wrapped_apply = lorax.lora(model.apply)
out = wrapped_apply(params, rng, inputs)

I believe your code will currently crash even if you don't use lorax.lora.

Btw, it looks like n_out is the sort of argument that JAX transforms (like jit, lora, etc.) won't like, so I'd recommend binding it with a partial function application:

from functools import partial
apply_with_n_out = partial(model.apply, n_out=n_out)
lora_model = lorax.lora(apply_with_n_out)
# etc
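
To illustrate why binding the argument first helps, here is a minimal pure-Python sketch. Both `toy_apply` and `array_only_transform` are hypothetical stand-ins invented for this example (they are not Haiku, JAX, or lorax functions): the toy transform rejects any argument that is not a list, loosely mimicking a transform that expects array-like inputs, so a plain integer such as n_out has to be bound beforehand.

```python
from functools import partial


def toy_apply(params, inputs, n_out):
    """Hypothetical stand-in for model.apply; n_out is a plain Python int."""
    dot = sum(p * x for p, x in zip(params, inputs))
    return [dot] * n_out


def array_only_transform(fn):
    """Hypothetical transform that only accepts list ("array-like") arguments."""
    def wrapped(*args):
        for arg in args:
            if not isinstance(arg, list):
                raise TypeError(
                    f"expected an array-like argument, got {type(arg).__name__}")
        return fn(*args)
    return wrapped


# Bind the static setting first, then transform the resulting function.
apply_with_n_out = partial(toy_apply, n_out=3)
transformed = array_only_transform(apply_with_n_out)
result = transformed([1.0, 2.0], [3.0, 4.0])
```

Passing n_out through the transformed function instead (for example `array_only_transform(toy_apply)([1.0], [2.0], 3)`) raises the TypeError in this toy setup, which is the kind of failure the partial application avoids.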


KiAkize commented on August 20, 2024

Sorry for missing the part of the code related to the linear module. The updated code is as follows:

import haiku as hk
import jax.numpy as jnp  # needed for jnp.einsum in Linear.__call__
import numbers
from typing import Union, Sequence
import numpy as np

TRUNCATED_NORMAL_STDDEV_FACTOR = np.asarray(.87962566103423978,
                                            dtype=np.float32)


def get_initializer_scale(initializer_name, input_shape):
  """Get Initializer for weights and scale to multiply activations by."""

  if initializer_name == 'zeros':
    w_init = hk.initializers.Constant(0.0)
  else:
    # fan-in scaling
    scale = 1.
    for channel_dim in input_shape:
      scale /= channel_dim
    if initializer_name == 'relu':
      scale *= 2

    noise_scale = scale

    stddev = np.sqrt(noise_scale)
    # Adjust stddev for truncation.
    stddev = stddev / TRUNCATED_NORMAL_STDDEV_FACTOR
    w_init = hk.initializers.TruncatedNormal(mean=0.0, stddev=stddev)

  return w_init



class Linear(hk.Module):
  """Protein folding specific Linear module.

  This differs from the standard Haiku Linear in a few ways:
    * It supports inputs and outputs of arbitrary rank
    * Initializers are specified by strings
  """

  def __init__(self,
               num_output: Union[int, Sequence[int]],
               initializer: str = 'linear',
               num_input_dims: int = 1,
               use_bias: bool = True,
               bias_init: float = 0.,
               precision = None,
               name: str = 'linear'):
    """Constructs Linear Module.

    Args:
      num_output: Number of output channels. Can be tuple when outputting
          multiple dimensions.
      initializer: What initializer to use, should be one of {'linear', 'relu',
        'zeros'}
      num_input_dims: Number of dimensions from the end to project.
      use_bias: Whether to include trainable bias
      bias_init: Value used to initialize bias.
      precision: What precision to use for matrix multiplication, defaults
        to None.
      name: Name of module, used for name scopes.
    """
    super().__init__(name=name)
    if isinstance(num_output, numbers.Integral):
      self.output_shape = (num_output,)
    else:
      self.output_shape = tuple(num_output)
    self.initializer = initializer
    self.use_bias = use_bias
    self.bias_init = bias_init
    self.num_input_dims = num_input_dims
    self.num_output_dims = len(self.output_shape)
    self.precision = precision

  def __call__(self, inputs):
    """Connects Module.

    Args:
      inputs: Tensor with at least num_input_dims dimensions.

    Returns:
      output of shape [...] + num_output.
    """

    num_input_dims = self.num_input_dims

    if self.num_input_dims > 0:
      in_shape = inputs.shape[-self.num_input_dims:]
    else:
      in_shape = ()

    weight_init = get_initializer_scale(self.initializer, in_shape)

    in_letters = 'abcde'[:self.num_input_dims]
    out_letters = 'hijkl'[:self.num_output_dims]

    weight_shape = in_shape + self.output_shape
    weights = hk.get_parameter('weights', weight_shape, inputs.dtype,
                               weight_init)

    equation = f'...{in_letters}, {in_letters}{out_letters}->...{out_letters}'

    output = jnp.einsum(equation, inputs, weights, precision=self.precision)

    if self.use_bias:
      bias = hk.get_parameter('bias', self.output_shape, inputs.dtype,
                              hk.initializers.Constant(self.bias_init))
      output += bias

    return output
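
As a worked example of what the module above computes, the sketch below reproduces the fan-in standard deviation from get_initializer_scale and the einsum equation built in __call__, using only the standard library. The helper names `fan_in_stddev` and `einsum_equation` are hypothetical, introduced here for illustration only, and the 'zeros' initializer is left out because it uses a constant rather than a standard deviation.

```python
import math

TRUNCATED_NORMAL_STDDEV_FACTOR = 0.87962566103423978


def fan_in_stddev(initializer_name, input_shape):
    """Stddev handed to TruncatedNormal for the 'linear' and 'relu' cases."""
    scale = 1.0
    for channel_dim in input_shape:  # fan-in scaling
        scale /= channel_dim
    if initializer_name == 'relu':
        scale *= 2
    # Adjust stddev for truncation, as in the posted code.
    return math.sqrt(scale) / TRUNCATED_NORMAL_STDDEV_FACTOR


def einsum_equation(num_input_dims, num_output_dims):
    """Builds the same equation string that Linear.__call__ passes to einsum."""
    in_letters = 'abcde'[:num_input_dims]
    out_letters = 'hijkl'[:num_output_dims]
    return f'...{in_letters}, {in_letters}{out_letters}->...{out_letters}'


# One projected input dimension of size 256, two output dimensions.
stddev = fan_in_stddev('linear', (256,))
equation = einsum_equation(num_input_dims=1, num_output_dims=2)
print(stddev, equation)
```

With one input dimension and two output dimensions, the equation contracts the last axis of the input against the first axis of a rank-3 weight tensor and keeps all leading batch axes.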



davisyoshida commented on August 20, 2024

Was it still not working after you fixed the problem I indicated?


KiAkize commented on August 20, 2024

> Was it still not working after you fixed the problem I indicated?

Thanks a lot. The code works now :)
