Source code for jaxkan.layers.Dense

import jax.numpy as jnp

from flax import nnx

from typing import Union


class DenseLayer(nnx.Module):
    """
    Dense layer with random weight factorization (RWF) for use in MLP architectures.

    The effective kernel is never stored directly. It is factorized column-wise as
    ``kernel = g * v`` and rebuilt on every forward pass, so ``g`` and ``v`` are the
    trainable quantities.

    Note:
        This is not a KAN layer, but a standard MLP building block used in advanced
        KAN architectures like KKAN (see jaxkan.models module).

    Attributes:
        g (nnx.Param):
            Scale factor vector of shape (n_out,) from the RWF reparameterization.
        v (nnx.Param):
            Direction matrix of shape (n_in, n_out) from the RWF reparameterization.
        b (nnx.Param or None):
            Bias vector of shape (n_out,), or None if add_bias is False.
        activation (callable or None):
            Activation function applied after the linear transformation, or None.
    """

    def __init__(self, n_in: int, n_out: int, activation = None,
                 RWF: Union[dict, None] = None,
                 add_bias: bool = True, seed: int = 42):
        """
        Initializes a Dense layer with RWF.

        Args:
            n_in (int):
                Number of input features.
            n_out (int):
                Number of output features.
            activation (callable, optional):
                Activation function applied after the linear transformation.
                Defaults to None.
            RWF (dict, optional):
                Dictionary with keys ``'mean'`` and ``'std'`` controlling the
                log-normal scale of the RWF reparameterization. When None (the
                default), ``{"mean": 1.0, "std": 0.1}`` is used.
            add_bias (bool, optional):
                Whether to include a learnable bias term. Defaults to True.
            seed (int, optional):
                Random seed for parameter initialization. Defaults to 42.

        Example:
            >>> layer = DenseLayer(n_in=64, n_out=32, add_bias=True, seed=42)
        """
        # Build the default per call: a dict literal in the signature would be
        # a single object shared by every DenseLayer ever constructed.
        if RWF is None:
            RWF = {"mean": 1.0, "std": 0.1}

        # Setup nnx rngs
        rngs = nnx.Rngs(seed)

        # Parameters of the normal distribution whose exponential gives g
        mu, sigma = RWF["mean"], RWF["std"]

        # Glorot Initialization
        stddev = jnp.sqrt(2.0/(n_in + n_out))

        # Weight matrix with shape (n_in, n_out).
        # NOTE: w must be drawn before g so the key stream order is unchanged.
        w = nnx.initializers.normal(stddev=stddev)(
            rngs.params(), (n_in, n_out), jnp.float32
        )

        # Reparameterization towards g, v: g = exp(N(mu, sigma)) is log-normal
        g = nnx.initializers.normal(stddev=sigma)(
            rngs.params(), (n_out,), jnp.float32
        )
        g += mu
        g = jnp.exp(g)  # shape (n_out,)

        # Dividing by g makes g * v reproduce w exactly at initialization
        v = w/g  # shape (n_in, n_out); g broadcasts over the n_in axis

        self.g = nnx.Param(g)
        self.v = nnx.Param(v)

        # Initialize bias - shape (n_out,)
        if add_bias:
            self.b = nnx.Param(jnp.zeros((n_out,)))
        else:
            self.b = None

        self.activation = activation

    def __call__(self, x):
        """
        Applies the dense layer to the input.

        Computes ``activation(x @ (g * v) + b)``, skipping the bias and/or the
        activation when they were not configured.

        Args:
            x (jnp.ndarray):
                Input array of shape (batch, n_in).

        Returns:
            jnp.ndarray:
                Output array of shape (batch, n_out).

        Example:
            >>> layer = DenseLayer(n_in=4, n_out=2)
            >>> x = jnp.ones((3, 4))
            >>> y = layer(x)  # shape: (3, 2)
        """
        # Reconstruct kernel: g (n_out,) scales each column of v (n_in, n_out)
        g, v = self.g[...], self.v[...]
        kernel = g * v

        # Apply kernel and bias
        y = jnp.dot(x, kernel)

        if self.b is not None:
            y = y + self.b[...]

        if self.activation is not None:
            y = self.activation(y)

        return y