import copy
from typing import Union, List

import jax.numpy as jnp
from flax import nnx

from ..layers import get_layer
from ..layers.Dense import DenseLayer
from ..layers.Chebyshev import Cb
from .utils import get_activation
class ChebyshevEmbedding(nnx.Module):
    """
    Chebyshev polynomial embedding layer with trainable coefficients.

    For an input x, computes: [c_0 * T_0(x), c_1 * T_1(x), ..., c_{D_e} * T_{D_e}(x)]
    where T_n are Chebyshev polynomials of the first kind and c_n are trainable parameters.

    Attributes:
        D_e (int):
            Degree of Chebyshev polynomial expansion.
        use_exact (bool):
            Whether the polynomials are taken from the Cb dictionary (True) or
            generated with the three-term recurrence (False).
        C (nnx.Param):
            Trainable coefficients for the Chebyshev polynomials, shape (D_e + 1,).
    """

    def __init__(self, D_e: int):
        """
        Initializes a ChebyshevEmbedding layer.

        Args:
            D_e (int):
                Degree of Chebyshev polynomial expansion. Must be non-negative.

        Raises:
            ValueError: If D_e is negative.

        Example:
            >>> embedding = ChebyshevEmbedding(D_e=5)
        """
        # A negative degree would yield an empty coefficient vector and only fail
        # later, at call time, when stacking zero polynomials. Fail early instead.
        if D_e < 0:
            raise ValueError(f"D_e must be a non-negative integer, got {D_e}.")

        self.D_e = D_e

        # The Cb dictionary only covers degrees up to its largest key; anything
        # beyond that is generated with the recurrence in __call__.
        self.use_exact = D_e <= max(Cb.keys())

        # Initialize trainable coefficients C_n
        # Motivated by code accompanying original paper: c_i = 1/(i+1) for i = 0, 1, ..., D_e
        C_init = jnp.array([1.0 / (i + 1) for i in range(D_e + 1)], dtype=jnp.float32)
        self.C = nnx.Param(C_init)

    def __call__(self, x):
        """
        Applies Chebyshev embedding to input.

        Args:
            x (jnp.array):
                Input tensor, shape (batch, n_features) or (batch,).

        Returns:
            embedded (jnp.array):
                Chebyshev embedded tensor, shape (batch, n_features * (D_e + 1)).

        Example:
            >>> embedding = ChebyshevEmbedding(D_e=5)
            >>> x = jax.random.uniform(jax.random.key(0), (100, 1))
            >>> y = embedding(x)  # shape: (100, 6)
        """
        # Handle 1D input (batch,) -> (batch, 1)
        if x.ndim == 1:
            x = x[:, None]

        batch = x.shape[0]

        if self.use_exact:
            # Pre-defined polynomials T_0 .. T_{D_e} from the Cb dictionary.
            polys = [Cb[i](x) for i in range(self.D_e + 1)]
        else:
            # Three-term recurrence for degrees beyond the Cb dictionary:
            # T_0(x) = 1, T_1(x) = x, T_n(x) = 2*x*T_{n-1}(x) - T_{n-2}(x).
            # Terms are collected in a list and stacked once, rather than being
            # scattered one by one into a pre-allocated array (each scatter
            # copies the whole array outside of jit).
            polys = [jnp.ones_like(x), x]
            for _ in range(2, self.D_e + 1):
                polys.append(2 * x * polys[-1] - polys[-2])
            # Guard for D_e == 0, where only T_0 is wanted.
            polys = polys[: self.D_e + 1]

        # Polynomial index becomes the trailing axis: (batch, n_features, D_e + 1)
        cheb = jnp.stack(polys, axis=-1)

        # Apply trainable coefficients: c_n * T_n(x)
        # C shape: (D_e + 1,), broadcast over the leading axes
        weighted_cheb = cheb * self.C

        # Flatten everything after the batch axis:
        # (batch, n_features * (D_e + 1))
        embedded = weighted_cheb.reshape(batch, -1)
        return embedded
class InnerBlock(nnx.Module):
    """
    Inner Block for KKAN architecture.

    A single input component x_p is pushed through the following pipeline:
        1. Chebyshev embedding of x_p             -> (D_e + 1) features
        2. Input Dense layer + activation         -> H features
        3. L hidden Dense layers, each + activation -> H features
        4. Chebyshev embedding of the hidden state -> H * (D_e + 1) features
        5. Final Dense layer (no activation)      -> m features

    Attributes:
        activation (callable):
            Activation function returned by get_activation.
        input_embedding (ChebyshevEmbedding):
            Chebyshev embedding applied to the raw input.
        input_layer (DenseLayer):
            Dense layer that follows the input embedding.
        hidden_layers (nnx.List):
            The L hidden Dense layers.
        output_embedding (ChebyshevEmbedding):
            Chebyshev embedding applied to the last hidden state.
        output_layer (DenseLayer):
            Final Dense layer.
    """

    def __init__(self,
                 D_e: int = 7,
                 H: int = 32,
                 L: int = 4,
                 m: int = 32,
                 activation: str = 'tanh',
                 seed: int = 42
                 ):
        """
        Initializes an InnerBlock.

        Args:
            D_e (int):
                Degree of Chebyshev polynomial expansion.
            H (int):
                Hidden dimension for MLP layers.
            L (int):
                Number of hidden layers.
            m (int):
                Output dimension.
            activation (str):
                Name of the activation function, resolved via get_activation.
            seed (int):
                Base random seed. The input and output Dense layers both use
                `seed`; hidden layer number k (1-based) uses `seed + k`.

        Example:
            >>> inner_block = InnerBlock(D_e=5, H=32, L=2, m=10, activation='tanh', seed=42)
        """
        # Width produced by one Chebyshev embedding per scalar feature.
        embed_width = D_e + 1

        self.activation = get_activation(activation)

        # x_p -> (D_e + 1) -> H
        self.input_embedding = ChebyshevEmbedding(D_e=D_e)
        self.input_layer = DenseLayer(n_in=embed_width, n_out=H, seed=seed)

        # H -> H, repeated L times with seeds seed + 1, ..., seed + L
        hidden_stack = []
        for depth in range(1, L + 1):
            hidden_stack.append(DenseLayer(n_in=H, n_out=H, seed=seed + depth))
        self.hidden_layers = nnx.List(hidden_stack)

        # H -> H * (D_e + 1) -> m
        self.output_embedding = ChebyshevEmbedding(D_e=D_e)
        self.output_layer = DenseLayer(n_in=H * embed_width, n_out=m, seed=seed)

    def __call__(self, x_p):
        """
        Forward pass through the inner block for a single input component.

        Args:
            x_p (jnp.array):
                Single input component, shape (batch, 1).

        Returns:
            out (jnp.array):
                Output tensor, shape (batch, m).

        Example:
            >>> inner_block = InnerBlock(D_e=5, H=32, L=2, m=10, seed=42)
            >>> x_p = jax.random.uniform(jax.random.key(0), (100, 1))
            >>> y = inner_block(x_p)  # shape: (100, 10)
        """
        act = self.activation

        # Steps 1-2: embed the scalar input, project to H, activate.
        # (batch, 1) -> (batch, D_e + 1) -> (batch, H)
        state = act(self.input_layer(self.input_embedding(x_p)))

        # Step 3: hidden Dense layers, each followed by the activation.
        # (batch, H) -> (batch, H)
        for dense in self.hidden_layers:
            state = act(dense(state))

        # Steps 4-5: re-embed the hidden state and project linearly to m.
        # (batch, H) -> (batch, H * (D_e + 1)) -> (batch, m)
        return self.output_layer(self.output_embedding(state))
class OuterBlock(nnx.Module):
    """
    Outer Block for KKAN architecture.

    This is a thin wrapper around a KAN layer obtained from get_layer.
    It applies the selected layer to map from m dimensions to n_out dimensions.

    Attributes:
        layer (nnx.Module):
            The underlying KAN layer instance.
    """

    def __init__(self,
                 m: int,
                 n_out: int,
                 layer_type: str = 'sine',
                 layer_params: Union[dict, None] = {'D': 7, 'init_scheme': {'type': 'glorot_fine'}},
                 seed: int = 42
                 ):
        """
        Initializes an OuterBlock.

        Args:
            m (int):
                Input dimension.
            n_out (int):
                Output dimension.
            layer_type (str):
                Type of KAN layer ('chebyshev', 'legendre', 'rbf', 'sine', 'fourier', etc.).
            layer_params (dict, optional):
                Additional keyword arguments for the KAN layer (e.g., D, flavor, kernel).
                Passing None forwards no extra arguments. The dictionary is
                deep-copied before use, so neither the caller's object nor the
                shared default is ever handed to the layer.
            seed (int):
                Random seed.

        Example:
            >>> outer_block = OuterBlock(m=10, n_out=1, layer_type='chebyshev',
            ...                          layer_params={'D': 5}, seed=42)
        """
        # Get the layer class
        LayerClass = get_layer(layer_type)

        # The signature default is a single dict object shared by every call
        # (and it nests another dict). Work on a private deep copy so that a
        # layer mutating its keyword arguments cannot leak state into later
        # instances or into the caller's dictionary.
        if layer_params is None:
            layer_params = {}
        else:
            layer_params = copy.deepcopy(layer_params)

        # Create the KAN layer
        self.layer = LayerClass(
            n_in=m,
            n_out=n_out,
            seed=seed,
            **layer_params
        )

    def __call__(self, xi):
        """
        Forward pass through the outer block.

        Args:
            xi (jnp.array):
                Input from combination layer, shape (batch, m).

        Returns:
            y (jnp.array):
                Output tensor, shape (batch, n_out).

        Example:
            >>> outer_block = OuterBlock(m=10, n_out=1, layer_type='chebyshev',
            ...                          layer_params={'D': 5}, seed=42)
            >>> xi = jax.random.uniform(jax.random.key(0), (100, 10))
            >>> y = outer_block(xi)  # shape: (100, 1)
        """
        return self.layer(xi)
class KKAN(nnx.Module):
    """
    KKAN architecture based on:
    "KKANs: Kůrková-Kolmogorov-Arnold Networks and Their Learning Dynamics"
    by Juan Diego Toscano, Li-Lian Wang, and George Em Karniadakis

    Attributes:
        n_in (int):
            Input dimension (d).
        inner_blocks (nnx.List):
            List of InnerBlock modules, one for each input dimension.
        outer_block (OuterBlock):
            The outer block that produces the final output.
    """

    def __init__(self,
                 n_in: int,
                 n_out: int,
                 m: int = 32,
                 D_e: int = 7,
                 H: int = 32,
                 L: int = 4,
                 activation: str = 'tanh',
                 outer_layer_type: str = 'sine',
                 outer_layer_params: Union[dict, None] = {'D': 7, 'init_scheme': {'type': 'glorot_fine'}},
                 seed: int = 42):
        """
        Initializes a KKAN model.

        Args:
            n_in (int):
                Input dimension. Must be at least 1.
            n_out (int):
                Output dimension.
            m (int):
                Intermediate dimension.
            D_e (int):
                Degree of Chebyshev expansion in inner blocks.
            H (int):
                Hidden dimension for inner block MLP.
            L (int):
                Number of hidden layers in inner block MLP.
            activation (str):
                Activation function ('tanh', 'relu', 'silu', 'gelu').
            outer_layer_type (str):
                Type of KAN layer for outer block ('chebyshev', 'legendre', 'rbf', 'sine', etc.).
            outer_layer_params (dict, optional):
                Additional parameters for outer block layer (e.g., {'D': 5, 'flavor': 'exact'}).
                Deep-copied before being forwarded to the outer block.
            seed (int):
                Random seed. Inner block p is built with seed + p; the outer
                block is built with seed.

        Raises:
            ValueError: If n_in is smaller than 1.

        Example:
            >>> model = KKAN(n_in=2, n_out=1, m=10, D_e=5, H=32, L=2,
            ...              outer_layer_type='chebyshev', seed=42)
        """
        # With no inner blocks the forward pass has nothing to combine and
        # would only fail later when stacking an empty list.
        if n_in < 1:
            raise ValueError(f"n_in must be at least 1, got {n_in}.")

        self.n_in = n_in

        # Create d inner blocks (one for each input dimension).
        # NOTE(review): InnerBlock uses seed, seed + 1, ..., seed + L internally,
        # so the seed ranges of consecutive blocks (seed + p) overlap. Whether
        # that produces identical initial weights depends on DenseLayer - confirm.
        # Left unchanged here to keep existing initializations reproducible.
        self.inner_blocks = nnx.List([
            InnerBlock(
                D_e=D_e,
                H=H,
                L=L,
                m=m,
                activation=activation,
                seed=seed + p
            )
            for p in range(n_in)
        ])

        # The signature default is one dict object shared across all calls;
        # forward a private deep copy so no layer can mutate shared state.
        if outer_layer_params is not None:
            outer_layer_params = copy.deepcopy(outer_layer_params)

        # Outer block (KAN layer)
        self.outer_block = OuterBlock(
            m=m,
            n_out=n_out,
            layer_type=outer_layer_type,
            layer_params=outer_layer_params,
            seed=seed
        )

    def __call__(self, x):
        """
        Forward pass through the KKAN model.

        Args:
            x (jnp.array):
                Input tensor, shape (batch, n_in).

        Returns:
            y (jnp.array):
                Output tensor, shape (batch, n_out).

        Raises:
            ValueError: If x has fewer than n_in feature columns.

        Example:
            >>> model = KKAN(n_in=2, n_out=1, m=10, D_e=5, H=32, L=2, seed=42)
            >>>
            >>> key = jax.random.key(42)
            >>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
            >>>
            >>> output = model(x_batch)
        """
        # Slicing past the last column silently yields an empty (batch, 0)
        # array, which would only surface as an obscure shape error deeper in
        # the inner block. Report the real problem instead.
        if x.shape[1] < self.n_in:
            raise ValueError(
                f"Expected input with at least {self.n_in} feature columns, "
                f"got shape {x.shape}."
            )

        # Part 1: Inner Blocks
        # Each input column (kept 2-D as (batch, 1)) goes through its own block,
        # producing one (batch, m) tensor per input dimension.
        psi_outputs = [
            block(x[:, p:p + 1])
            for p, block in enumerate(self.inner_blocks)
        ]

        # Part 2: Combination Layer
        # Sum the outputs from all inner blocks
        # ξ_q = Σ_{p=1}^{d} Ψ_p,q(x_p)
        xi = jnp.stack(psi_outputs, axis=0).sum(axis=0)  # (batch, m)

        # Part 3: Outer Block
        # Apply KAN layer to produce final output
        y = self.outer_block(xi)  # (batch, n_out)
        return y