import torch


# rms_norm_fn is assumed to be the fused RMSNorm kernel defined elsewhere in this
# module; the call below passes the optional residual, dropout probability, and
# prenorm flag through to it.
class RMSNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        if dropout_p > 0.0:
            self.drop = torch.nn.Dropout(dropout_p)
        else:
            self.drop = None
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.ones_(self.weight)

    def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
        return rms_norm_fn(
            x,
            self.weight,
            self.bias,
            residual=residual,
            eps=self.eps,
            dropout_p=self.drop.p if self.drop is not None and self.training else 0.0,
            prenorm=prenorm,
            residual_in_fp32=residual_in_fp32,
        )
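

# A minimal, non-fused reference sketch of the core math rms_norm_fn is expected to
# compute when no residual, dropout, or prenorm is used; _rms_norm_reference is an
# illustrative name and not part of the module above.
def _rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Normalize each row by its root mean square over the last dimension, then scale.
    rstd = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x * rstd * weight


# The Flax NNX RMSNorm below assumes its surrounding module provides Module, nnx,
# jnp, jax, tp, rnglib, initializers, Dtype, Axes, Initializer, _compute_stats and
# _normalize, as in Flax NNX's normalization module.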
class RMSNorm(Module):
    """RMS Layer normalization (https://arxiv.org/abs/1910.07467).

    RMSNorm normalizes the activations of the layer for each given example in a
    batch independently, rather than across a batch like Batch Normalization.
    Unlike LayerNorm, which re-centers the mean to be 0 and normalizes by the
    standard deviation of the activations, RMSNorm does not re-center at all
    and instead normalizes by the root mean square of the activations.

    Example usage::

      >>> from flax import nnx
      >>> import jax

      >>> x = jax.random.normal(jax.random.key(0), (5, 6))
      >>> layer = nnx.RMSNorm(num_features=6, rngs=nnx.Rngs(0))
      >>> nnx.state(layer)
      State({
        'scale': VariableState(
          type=Param,
          value=Array([1., 1., 1., 1., 1., 1.], dtype=float32)
        )
      })
      >>> y = layer(x)

    Attributes:
      num_features: the number of input features.
      epsilon: a small float added to the variance to avoid dividing by zero.
      dtype: the dtype of the result (default: infer from input and params).
      param_dtype: the dtype passed to parameter initializers (default: float32).
      use_scale: if True, multiply by a learned scale (gamma). When the next layer
        is linear (also e.g. ``nn.relu``), this can be disabled since the scaling
        will be done by the next layer.
      scale_init: initializer for the scale, by default ones.
      reduction_axes: axes for computing the normalization statistics.
      feature_axes: feature axes for the learned scaling.
      axis_name: the axis name used to combine batch statistics from multiple
        devices. See ``jax.pmap`` for a description of axis names (default: None).
        This is only needed if the model is subdivided across devices, i.e. the
        array being normalized is sharded across devices within a pmap.
      axis_index_groups: groups of axis indices within that named axis
        representing subsets of devices to reduce over (default: None). For
        example, ``[[0, 1], [2, 3]]`` would independently normalize over the
        examples on the first two and last two devices. See ``jax.lax.psum``
        for more details.
      use_fast_variance: if True, use a faster, but less numerically stable,
        calculation for the variance.
      rngs: rng key.
    """
    def __init__(
        self,
        num_features: int,
        *,
        epsilon: float = 1e-6,
        dtype: tp.Optional[Dtype] = None,
        param_dtype: Dtype = jnp.float32,
        use_scale: bool = True,
        scale_init: Initializer = initializers.ones,
        reduction_axes: Axes = -1,
        feature_axes: Axes = -1,
        axis_name: tp.Optional[str] = None,
        axis_index_groups: tp.Any = None,
        use_fast_variance: bool = True,
        rngs: rnglib.Rngs,
    ):
        feature_shape = (num_features,)
        if use_scale:
            key = rngs.params()
            self.scale = nnx.Param(scale_init(key, feature_shape, param_dtype))
        else:
            self.scale = nnx.Param(None)

        self.num_features = num_features
        self.epsilon = epsilon
        self.dtype = dtype
        self.param_dtype = param_dtype
        self.use_scale = use_scale
        self.scale_init = scale_init
        self.reduction_axes = reduction_axes
        self.feature_axes = feature_axes
        self.axis_name = axis_name
        self.axis_index_groups = axis_index_groups
        self.use_fast_variance = use_fast_variance
    def __call__(self, x, mask: tp.Optional[jax.Array] = None):
        """Applies RMS layer normalization on the input.

        Args:
          x: the inputs.
          mask: optional binary array, broadcastable to ``x``, marking the
            positions used when computing the normalization statistics.

        Returns:
          Normalized inputs (the same shape as inputs).
        """
        mean, var = _compute_stats(
            x,
            self.reduction_axes,
            self.dtype,
            self.axis_name,
            self.axis_index_groups,
            use_mean=False,
            use_fast_variance=self.use_fast_variance,
            mask=mask,
        )

        return _normalize(
            x,
            mean,
            var,
            self.scale.value,
            None,
            self.reduction_axes,
            self.feature_axes,
            self.dtype,
            self.epsilon,
        )
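

# A minimal usage sketch (assumptions: default epsilon=1e-6 and use_scale=True, and
# that the output matches the plain formula y = x / sqrt(mean(x**2) + eps) * scale
# up to numerical tolerance; shapes and values below are illustrative).
import jax
import jax.numpy as jnp
from flax import nnx

x = jax.random.normal(jax.random.key(0), (5, 6))
layer = nnx.RMSNorm(num_features=6, rngs=nnx.Rngs(0))
y = layer(x)

# Recompute the normalization by hand over the last axis and compare.
manual = x / jnp.sqrt(jnp.mean(x**2, axis=-1, keepdims=True) + 1e-6) * layer.scale.value
assert jnp.allclose(y, manual, atol=1e-5)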