import jax.numpy as jnp
from flax import nnx
[docs]
class ActLayer(nnx.Module):
"""
ActLayer implementation based on:
"Deep Learning Alternatives of the Kolmogorov Superposition Theorem"
by Leonardo Ferreira Guilhoto and Paris Perdikaris (arXiv:2410.01990)
Forward pass: ActLayer(x) = S(Λ ⊙ (β @ B(x)))
Where:
- B(x) is the basis expansion matrix with B(x)_{ij} = b_i(x_j), shape (N, d)
- β ∈ R^{m × N} are the basis coefficients
- Λ ∈ R^{m × d} are the mixing weights
- S is the row-sum function
- ⊙ is the Hadamard (element-wise) product
The k-th output is: (ActLayer(x))_k = Σ_i λ_{ki} Σ_j β_{kj} b_j(x_i)
Attributes:
rngs (nnx.Rngs):
Random number generator state.
beta (nnx.Param):
Basis coefficients, shape (n_out, N).
Lambda (nnx.Param):
Mixing weights, shape (n_out, n_in).
omega (Union[nnx.Param, jnp.array]):
Frequency parameters for basis functions (trainable or fixed).
phase (Union[nnx.Param, jnp.array]):
Phase parameters for basis functions (trainable or fixed).
"""
[docs]
def __init__(self,
n_in: int = 3, n_out: int = 4, N: int = 5, train_basis: bool = True, seed: int = 42
):
"""
Initializes an ActLayer instance.
Args:
n_in (int):
Number of layer's incoming nodes.
n_out (int):
Number of layer's outgoing nodes.
N (int):
Number of basis functions (paper recommends N=4).
train_basis (bool):
Whether the basis function parameters (omega and phase) are trainable.
seed (int):
Random key selection for initializations wherever necessary.
Example:
>>> layer = ActLayer(n_in=2, n_out=5, N=4, train_basis=True, seed=42)
"""
# Setup nnx rngs
self.rngs = nnx.Rngs(seed)
# Initialize betas - shape (n_out, N)
# Paper: std = 1/sqrt(N) for balanced initialization
std_beta = jnp.sqrt(1.0 / N)
self.beta = nnx.Param(nnx.initializers.normal(stddev=std_beta)(
self.rngs.params(), (n_out, N), jnp.float32))
# Initialize Lambdas - shape (n_out, n_in)
# Paper: std = 1/sqrt(d) for balanced initialization
std_lambda = jnp.sqrt(1.0 / n_in)
self.Lambda = nnx.Param(nnx.initializers.normal(stddev=std_lambda)(
self.rngs.params(), (n_out, n_in), jnp.float32))
# Initialize omegas (frequencies) - shape (N,)
# Paper: ω_i ~ N(0, 1), initialized from standard normal
omega_init = nnx.initializers.normal(stddev=1.0)(
self.rngs.params(), (N,), jnp.float32)
# Initialize phases - shape (N,)
# Paper: p_i initialized at 0
phase_init = jnp.zeros((N,))
# Set omega and phase as trainable (Param) or fixed (Variable) based on train_basis
if train_basis:
self.omega = nnx.Param(omega_init)
self.phase = nnx.Param(phase_init)
else:
self.omega = omega_init
self.phase = phase_init
[docs]
def basis(self, x):
"""
Compute normalized sinusoidal basis functions.
Paper Equation 11:
b_i(t) = (sin(ω_i * t + p_i) - μ(ω_i, p_i)) / σ(ω_i, p_i)
Where μ and σ are computed assuming x ~ N(0, 1):
μ(ω, p) = exp(-ω²/2) * sin(p)
σ(ω, p) = sqrt(1/2 - 1/2 * exp(-2ω²) * cos(2p) - μ²)
Args:
x (jnp.array):
Inputs, shape (batch, n_in).
Returns:
B (jnp.array):
Basis expansion matrix, shape (batch, N, n_in),
where B[b, j, i] = b_j(x_{b,i}).
Example:
>>> layer = ActLayer(n_in=2, n_out=5, N=4, seed=42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> B = layer.basis(x_batch)
"""
# x shape: (batch, n_in)
# We need B(x)_{ji} = b_j(x_i), so output shape should be (batch, N, n_in)
# Expand dimensions for broadcasting
x_expanded = x[:, None, :] # (batch, 1, n_in)
omega = self.omega[None, :, None] # (1, N, 1)
phase = self.phase[None, :, None] # (1, N, 1)
# Compute sin(ω_j * x_i + p_j) for all combinations
wx = omega * x_expanded # (batch, N, n_in)
s = jnp.sin(wx + phase) # (batch, N, n_in)
# Compute mean: μ(ω, p) = exp(-ω²/2) * sin(p)
mu = jnp.exp(-0.5 * (omega ** 2)) * jnp.sin(phase) # (1, N, 1)
# Compute std: σ(ω, p) = sqrt(1/2 - 1/2 * exp(-2ω²) * cos(2p) - μ²)
var = 0.5 * (1.0 - jnp.exp(-2.0 * (omega ** 2)) * jnp.cos(2.0 * phase)) - mu ** 2
std = jnp.sqrt(jnp.maximum(var, 1e-8)) # (1, N, 1)
# Normalize basis
eps = 1e-8
B = (s - mu) / (std + eps) # (batch, N, n_in)
return B
[docs]
def __call__(self, x):
"""
The layer's forward pass.
Paper Equation 6: ActLayer(x) = S(Λ ⊙ (β @ B(x)))
Paper Equation 9: (ActLayer(x))_k = Σ_i λ_{ki} Σ_j β_{kj} b_j(x_i)
Args:
x (jnp.array):
Inputs, shape (batch, n_in).
Returns:
y (jnp.array):
Output of the forward pass, shape (batch, n_out).
Example:
>>> layer = ActLayer(n_in=2, n_out=5, N=4, seed=42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> output = layer(x_batch)
"""
# Calculate basis: B shape (batch, N, n_in) where B[b,j,i] = b_j(x_{b,i})
B = self.basis(x)
# Get parameters
beta = self.beta # (n_out, N)
Lambda = self.Lambda # (n_out, n_in)
# Compute inner function expansion: Φ(x) = β @ B(x)
# Φ[b, k, i] = Σ_j β[k,j] * B[b,j,i]
# This gives the inner functions φ_k(x_i) for each output k
Phi = jnp.einsum('kj,bji->bki', beta, B) # (batch, n_out, n_in)
# Apply Lambda weights and sum: y_k = Σ_i λ_{ki} * Φ[k,i]
# This is S(Λ ⊙ Φ) where S is row-sum
y = jnp.einsum('bki,ki->bk', Phi, Lambda) # (batch, n_out)
return y
[docs]
class ActNet(nnx.Module):
"""
ActNet architecture based on:
"Deep Learning Alternatives of the Kolmogorov Superposition Theorem"
by Leonardo Ferreira Guilhoto and Paris Perdikaris (arXiv:2410.01990)
Attributes:
rngs (nnx.Rngs):
Random number generator state.
add_bias (bool):
Whether to add learnable bias after each ActLayer.
omega0 (float):
Frequency multiplier for input.
use_projections (bool):
Whether to use input/output linear projections.
input_proj (nnx.Linear, optional):
Input projection layer if use_projections is True.
layers (nnx.List):
List of ActLayer instances.
output_proj (nnx.Linear, optional):
Output projection layer if use_projections is True.
biases (nnx.List, optional):
List of bias parameters if add_bias is True.
"""
[docs]
def __init__(self,
layer_dims,
N: int = 4,
add_bias: bool = True,
omega0: float = 1.0,
use_projections: bool = False,
train_basis: bool = True,
seed: int = 42):
"""
Initializes an ActNet model.
Args:
layer_dims (List[int]):
Defines the network in terms of nodes. E.g. [2,5,1] is a network with 2 layers.
N (int):
Number of basis functions per ActLayer (paper recommends N=4).
add_bias (bool):
Whether to add learnable bias after each ActLayer.
omega0 (float):
Frequency multiplier for input (paper's Appendix D.1).
use_projections (bool):
Whether to use input/output linear projections.
train_basis (bool):
Whether the basis function parameters (omega and phase) are trainable.
seed (int):
Random key selection for initializations wherever necessary.
Example:
>>> model = ActNet(layer_dims=[2, 5, 1], N=4, add_bias=True, train_basis=True, seed=42)
"""
# Setup nnx rngs
self.rngs = nnx.Rngs(seed)
self.add_bias = add_bias
self.omega0 = omega0
self.use_projections = use_projections
input_dim = layer_dims[0]
output_dim = layer_dims[-1]
# If using projections, ActLayers operate on hidden dimensions only
if use_projections and len(layer_dims) > 2:
hidden_dim = layer_dims[1]
# Input projection: input_dim -> hidden_dim
self.input_proj = nnx.Linear(input_dim, hidden_dim, rngs=self.rngs)
# ActLayers operate on hidden dimensions
self.layers = nnx.List([
ActLayer(
n_in=layer_dims[i],
n_out=layer_dims[i + 1],
N=N,
train_basis=train_basis,
seed=seed
)
for i in range(1, len(layer_dims) - 2)
])
# Output projection: hidden_dim -> output_dim
self.output_proj = nnx.Linear(layer_dims[-2], output_dim, rngs=self.rngs)
else:
# Standard mode: ActLayers for all layer transitions
self.layers = nnx.List([
ActLayer(
n_in=layer_dims[i],
n_out=layer_dims[i + 1],
N=N,
train_basis=train_basis,
seed=seed
)
for i in range(len(layer_dims) - 1)
])
if self.add_bias:
if use_projections and len(layer_dims) > 2:
# Biases for hidden layers only
self.biases = nnx.List([
nnx.Param(jnp.zeros((layer_dims[i+1],)))
for i in range(1, len(layer_dims) - 2)
])
else:
self.biases = nnx.List([
nnx.Param(jnp.zeros((dim,))) for dim in layer_dims[1:]
])
[docs]
def __call__(self, x):
"""
Equivalent to the network's forward pass.
Args:
x (jnp.array):
Inputs for the first layer, shape (batch, layer_dims[0]).
Returns:
x (jnp.array):
Network output, shape (batch, layer_dims[-1]).
Example:
>>> model = ActNet(layer_dims=[2, 5, 1], N=4, add_bias=True, seed=42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> output = model(x_batch)
"""
# Apply omega0 frequency scaling (Appendix D.1)
if self.omega0 != 1.0:
x = self.omega0 * x
# Input projection if enabled
if self.use_projections and hasattr(self, 'input_proj'):
x = self.input_proj(x)
# Pass through each ActLayer
for i, layer in enumerate(self.layers):
x = layer(x)
if self.add_bias and i < len(self.biases):
x += self.biases[i]
# Output projection if enabled
if self.use_projections and hasattr(self, 'output_proj'):
x = self.output_proj(x)
return x