walln / scratch

3 stars · 1 watcher · 0 forks · 525 KB

High-quality, readable implementations of machine learning architectures and ideas, accelerator kernels, and other things from scratch

Home Page: https://walln.dev/projects/scratch

Python 99.76% Shell 0.24%

scratch's Introduction

Scratch

Implementations of machine learning methods and deep learning models for my own reference.

scratch's People

Contributors

walln

Stargazers

Siliang Qin, Clément POIRET, Roman


scratch's Issues

RMSNorm used in Mamba2 Implementation

Hi @walln, this project looks really great. I was pleasantly surprised when searching for a Mamba2 implementation: yours is much clearer than the others I found. However, when I print the hidden-state outputs of the model, they are all NaNs. After investigating the code, I think this may be due to nnx.RMSNorm taking different arguments (and defaults) than the RMSNorm used in the original mamba2 implementation. Here is the RMSNorm from the original mamba2 code:

import torch
from mamba_ssm.ops.triton.layer_norm import rms_norm_fn  # fused Triton kernel (path in recent mamba_ssm releases)


class RMSNorm(torch.nn.Module):

    def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        if dropout_p > 0.0:
            self.drop = torch.nn.Dropout(dropout_p)
        else:
            self.drop = None
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter("bias", None)  # RMSNorm has no bias term
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.ones_(self.weight)

    def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
        # Delegates to the fused kernel; note the default eps here is 1e-5.
        return rms_norm_fn(
            x,
            self.weight,
            self.bias,
            residual=residual,
            eps=self.eps,
            dropout_p=self.drop.p if self.drop is not None and self.training else 0.0,
            prenorm=prenorm,
            residual_in_fp32=residual_in_fp32,
        )
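
A quick usage sketch (mine, not from the issue) of the module above; it assumes mamba_ssm with its Triton kernels is installed, and the fused kernel needs a CUDA device:

norm = RMSNorm(hidden_size=64, device="cuda")
x = torch.randn(2, 16, 64, device="cuda")
y = norm(x)  # normalized output, same shape as x
assert y.shape == x.shape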

For comparison, here is the nnx.RMSNorm implementation from Flax (_compute_stats and _normalize are Flax-internal helpers):

class RMSNorm(Module):
  """RMS Layer normalization (https://arxiv.org/abs/1910.07467).

  RMSNorm normalizes the activations of the layer for each given example in a
  batch independently, rather than across a batch like Batch Normalization.
  Unlike LayerNorm which re-centers the mean to be 0 and normalizes by the
  standard deviation of the activations, RMSNorm does not re-center at all
  and instead normalizes by the root mean square of the activations.

  Example usage::

    >>> from flax import nnx
    >>> import jax

    >>> x = jax.random.normal(jax.random.key(0), (5, 6))
    >>> layer = nnx.RMSNorm(num_features=6, rngs=nnx.Rngs(0))

    >>> nnx.state(layer)
    State({
      'scale': VariableState(
        type=Param,
        value=Array([1., 1., 1., 1., 1., 1.], dtype=float32)
      )
    })

    >>> y = layer(x)

  Attributes:
    num_features: the number of input features.
    epsilon: A small float added to variance to avoid dividing by zero.
    dtype: the dtype of the result (default: infer from input and params).
    param_dtype: the dtype passed to parameter initializers (default: float32).
    use_scale: If True, multiply by scale (gamma). When the next layer is linear
        (also e.g. nn.relu), this can be disabled since the scaling will be done
        by the next layer.
    scale_init: Initializer for scale, by default, one.
    reduction_axes: Axes for computing normalization statistics.
    feature_axes: Feature axes for learned bias and scaling.
    axis_name: the axis name used to combine batch statistics from multiple
        devices. See ``jax.pmap`` for a description of axis names (default: None).
        This is only needed if the model is subdivided across devices, i.e. the
        array being normalized is sharded across devices within a pmap.
    axis_index_groups: groups of axis indices within that named axis
        representing subsets of devices to reduce over (default: None). For
        example, ``[[0, 1], [2, 3]]`` would independently batch-normalize over
        the examples on the first two and last two devices. See ``jax.lax.psum``
        for more details.
    use_fast_variance: If true, use a faster, but less numerically stable,
        calculation for the variance.
    rngs: rng key.
  """

  def __init__(
    self,
    num_features: int,
    *,
    epsilon: float = 1e-6,
    dtype: tp.Optional[Dtype] = None,
    param_dtype: Dtype = jnp.float32,
    use_scale: bool = True,
    scale_init: Initializer = initializers.ones,
    reduction_axes: Axes = -1,
    feature_axes: Axes = -1,
    axis_name: tp.Optional[str] = None,
    axis_index_groups: tp.Any = None,
    use_fast_variance: bool = True,
    rngs: rnglib.Rngs,
  ):
    feature_shape = (num_features,)

    if use_scale:
      key = rngs.params()
      self.scale = nnx.Param(scale_init(key, feature_shape, param_dtype))
    else:
      self.scale = nnx.Param(None)

    self.num_features = num_features
    self.epsilon = epsilon
    self.dtype = dtype
    self.param_dtype = param_dtype
    self.use_scale = use_scale
    self.scale_init = scale_init
    self.reduction_axes = reduction_axes
    self.feature_axes = feature_axes
    self.axis_name = axis_name
    self.axis_index_groups = axis_index_groups
    self.use_fast_variance = use_fast_variance

  def __call__(self, x, mask: tp.Optional[jax.Array] = None):
    """Applies layer normalization on the input.

    Args:
      x: the inputs

    Returns:
      Normalized inputs (the same shape as inputs).
    """
    mean, var = _compute_stats(
      x,
      self.reduction_axes,
      self.dtype,
      self.axis_name,
      self.axis_index_groups,
      use_mean=False,
      use_fast_variance=self.use_fast_variance,
      mask=mask,
    )

    return _normalize(
      x,
      mean,
      var,
      self.scale.value,
      None,
      self.reduction_axes,
      self.feature_axes,
      self.dtype,
      self.epsilon,
    )

Could you take a look at this?
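
To make the difference concrete, here is a minimal sketch (mine, not from either codebase) of the one default that visibly differs in the two excerpts above: nnx.RMSNorm uses epsilon=1e-6 while the mamba2 module uses eps=1e-5. Both compute x / sqrt(mean(x**2) + eps) * scale, so once the epsilons are aligned the outputs should agree; whether the argument mismatch fully explains the NaNs is the open question here. (The mamba2 version also fuses optional residual, prenorm, and dropout handling, which nnx.RMSNorm does not have.)

import jax
import jax.numpy as jnp
from flax import nnx

def reference_rms_norm(x, weight, eps=1e-5):
    # Plain RMSNorm as in the mamba2 code above: no mean re-centering,
    # eps added to mean(x**2) inside the square root.
    rms = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)
    return x / rms * weight

x = jax.random.normal(jax.random.key(0), (4, 6))

# Override the nnx default epsilon (1e-6) to match mamba2's 1e-5.
layer = nnx.RMSNorm(num_features=6, epsilon=1e-5, rngs=nnx.Rngs(0))

diff = jnp.max(jnp.abs(layer(x) - reference_rms_norm(x, layer.scale.value)))
print(diff)  # ~0 when the epsilons agree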
