import jax.numpy as jnp
from flax import nnx
from ..layers.Sine import SineLayer
from ..layers.Chebyshev import ChebyshevLayer as Layer
from .utils import PeriodEmbedder, RFFEmbedder
from typing import Union
class RGABlock(nnx.Module):
    """
    Residual-Gated Adaptive Block for RGAKAN architecture.

    Two Chebyshev layers are applied in sequence. After each one, its output z
    mixes the externally supplied gate signals as ``z * u + (1 - z) * v``, and
    the result is blended with the block input through a trainable residual
    weight (``beta`` after the first layer, ``alpha`` after the second).

    Attributes:
        InputLayer (Layer):
            First Chebyshev layer in the block (n_in -> n_hidden).
        OutputLayer (Layer):
            Second Chebyshev layer in the block (n_hidden -> n_out).
        alpha (nnx.Param):
            Trainable residual connection weight for the output.
        beta (nnx.Param):
            Trainable residual connection weight for the hidden state.
    """

    def __init__(self, n_in: int, n_out: int, n_hidden: int, D: int = 5, flavor: str = 'exact',
                 init_scheme: Union[dict, None] = None, alpha: float = 0.0, beta: float = 1.0,
                 seed: int = 42):
        """
        Initializes an RGABlock.

        Args:
            n_in (int):
                Input dimension.
            n_out (int):
                Output dimension.
            n_hidden (int):
                Hidden layer dimension.
            D (int):
                Degree of Chebyshev polynomials.
            flavor (str):
                Type of Chebyshev layer ('exact' or other variants).
            init_scheme (dict, optional):
                Initialization scheme for layer weights.
            alpha (float):
                Initial value for output residual connection weight.
            beta (float):
                Initial value for hidden residual connection weight.
            seed (int):
                Random seed passed to both Chebyshev layers.

        Raises:
            ValueError:
                If n_in, n_hidden and n_out are not all equal. The residual
                connections add the block input (width n_in) to tensors of
                width n_hidden and n_out, and the second gate multiplies the
                n_out-wide layer output by the n_hidden-wide gate signals, so
                unequal widths either fail at call time or broadcast silently
                when one of them is 1.

        Example:
            >>> block = RGABlock(n_in=64, n_out=64, n_hidden=64, D=5, flavor='exact', seed=42)
        """
        if not (n_in == n_hidden == n_out):
            raise ValueError(
                "RGABlock requires n_in == n_hidden == n_out, got "
                f"n_in={n_in}, n_hidden={n_hidden}, n_out={n_out}."
            )
        # Define the 2 layers
        self.InputLayer = Layer(n_in=n_in, n_out=n_hidden, D=D, flavor=flavor,
                                residual=None, external_weights=False, init_scheme=init_scheme,
                                add_bias=True, seed=seed)
        self.OutputLayer = Layer(n_in=n_hidden, n_out=n_out, D=D, flavor=flavor,
                                 residual=None, external_weights=False, init_scheme=init_scheme,
                                 add_bias=True, seed=seed)
        # Trainable residual mixing weights, stored as float32 scalars
        self.alpha = nnx.Param(jnp.array(alpha, dtype=jnp.float32))
        self.beta = nnx.Param(jnp.array(beta, dtype=jnp.float32))

    def __call__(self, x, u, v):
        """
        Forward pass through the RGA block.

        Args:
            x (jnp.array):
                Input array, shape (batch, n_in).
            u (jnp.array):
                First gating signal, shape (batch, n_hidden).
            v (jnp.array):
                Second gating signal, shape (batch, n_hidden).

        Returns:
            x (jnp.array):
                Output after applying gated attention and residual connections, shape (batch, n_out).

        Example:
            >>> block = RGABlock(n_in=64, n_out=64, n_hidden=64, D=5, seed=42)
            >>> x = jnp.ones((32, 64))
            >>> u = jnp.ones((32, 64))
            >>> v = jnp.zeros((32, 64))
            >>> output = block(x, u, v) # Shape: (32, 64)
        """
        identity = x
        # First layer, gate, then blend with the block input via beta
        h = self.InputLayer(x)
        h = h * u + (1 - h) * v
        beta = self.beta.value
        h = beta * h + (1 - beta) * identity
        # Second layer, gate, then blend with the block input via alpha
        h = self.OutputLayer(h)
        h = h * u + (1 - h) * v
        alpha = self.alpha.value
        return alpha * h + (1 - alpha) * identity
class RGAKAN(nnx.Module):
    """
    Residual-Gated Adaptive Kolmogorov-Arnold Network (RGAKAN).

    See paper "Training Deep Physics-Informed Kolmogorov-Arnold Networks".
    https://www.sciencedirect.com/science/article/pii/S0045782526000356

    Attributes:
        pi_init (bool):
            Whether physics-informed initialization is enabled.
        n_hidden (int):
            Hidden layer dimension.
        D (int):
            Degree of Chebyshev polynomials.
        PE (Union[PeriodEmbedder, None]):
            Periodic embedder if period_axes is provided.
        FE (Union[RFFEmbedder, None]):
            Random Fourier Features embedder if rff_std is provided.
        SineBasis (Union[SineLayer, None]):
            Sine basis layer if sine_D is provided.
        U (Layer):
            First gating network.
        V (Layer):
            Second gating network.
        blocks (nnx.List):
            List of RGABlock instances.
        OutBasis (Union[nnx.Param, None]):
            Physics-informed output coefficients if pi_init is True, else None.
        OutLayer (Union[Layer, None]):
            Standard output layer if pi_init is False, else None.
    """

    def __init__(self, n_in: int, n_out: int, n_hidden: int, num_blocks: int,
                 flavor: str = 'exact', D: int = 5, init_scheme: Union[dict, None] = None,
                 alpha: float = 0.0, beta: float = 1.0, ref: Union[None, dict] = None,
                 period_axes: Union[None, dict] = None, rff_std: Union[None, float] = None,
                 sine_D: Union[None, int] = None, seed: int = 42):
        """
        Initializes an RGAKAN model.

        Args:
            n_in (int):
                Input dimension (before any embeddings).
            n_out (int):
                Output dimension. Must be 1 when ref is given.
            n_hidden (int):
                Hidden layer dimension.
            num_blocks (int):
                Number of RGA blocks to stack.
            flavor (str):
                Type of Chebyshev layer ('exact' or other variants).
            D (int):
                Degree of Chebyshev polynomials (used by the gates, the blocks
                and the output layer).
            init_scheme (dict, optional):
                Initialization scheme for layer weights.
            alpha (float):
                Initial value for output residual connection weights in blocks.
            beta (float):
                Initial value for hidden residual connection weights in blocks.
            ref (dict, optional):
                Reference data for physics-informed initialization. Must contain
                't', 'x' and 'usol'; 'y' is optional (3D data).
            period_axes (dict, optional):
                Dictionary for periodic embedding: {axis: (period, trainable)}.
            rff_std (float, optional):
                Standard deviation for Random Fourier Features embedding.
            sine_D (int, optional):
                Degree for sine basis layer.
            seed (int):
                Random seed for reproducible initialization.

        Raises:
            ValueError:
                If the embedding stage does not produce n_hidden features, i.e.
                neither rff_std nor sine_D is set and the (possibly
                period-embedded) input width differs from n_hidden. The gates
                and blocks all consume n_hidden features.
            ValueError:
                If ref is given and n_out != 1. The physics-informed output
                stage always produces a single output.

        Example:
            >>> # Standard RGAKAN; the sine basis lifts the 2 inputs to n_hidden features
            >>> model = RGAKAN(n_in=2, n_out=1, n_hidden=64, num_blocks=4, D=5,
            ...                sine_D=5, seed=42)
            >>>
            >>> # RGAKAN with periodic embedding followed by Random Fourier Features
            >>> period_axes = {0: (2.0 * jnp.pi, False)}
            >>> model = RGAKAN(n_in=2, n_out=1, n_hidden=64, num_blocks=4,
            ...                period_axes=period_axes, rff_std=1.0, seed=42)
        """
        self.pi_init = ref is not None
        self.n_hidden = n_hidden
        self.D = D

        if self.pi_init and n_out != 1:
            raise ValueError(
                "Physics-informed initialization (ref given) produces a single "
                f"output, but n_out={n_out}."
            )

        # Periodic embedding: the input width grows by one per periodic axis
        if period_axes:
            self.PE = PeriodEmbedder(period_axes)
            n_in += len(period_axes)
        else:
            self.PE = None

        # Random Fourier Features map the current width to n_hidden
        if rff_std:
            self.FE = RFFEmbedder(std=rff_std, n_in=n_in, embed_dim=n_hidden)
            n_in = n_hidden
        else:
            self.FE = None

        # Sine-basis layer maps the current width to n_hidden
        if sine_D:
            self.SineBasis = SineLayer(n_in=n_in, n_out=n_hidden, D=sine_D, residual=None,
                                       external_weights=False, init_scheme=init_scheme,
                                       add_bias=True, seed=seed)
            n_in = n_hidden
        else:
            self.SineBasis = None

        # Everything downstream (gates, blocks) consumes n_hidden features, so
        # fail here instead of at the first forward pass.
        if n_in != n_hidden:
            raise ValueError(
                f"The embedding stage produces {n_in} features, but the gates and "
                f"blocks expect n_hidden={n_hidden}. Set rff_std or sine_D to lift "
                "the input to n_hidden features, or choose n_hidden equal to the "
                "(embedded) input dimension."
            )

        # Gates producing the two signals shared by every block.
        # NOTE(review): U and V are built with identical arguments, including the
        # same seed. If Layer's initialization is fully determined by its seed,
        # they start out identical, so every gate z*u + (1-z)*v equals u at init.
        # Kept as-is to preserve existing initializations; confirm this is intended.
        self.U = Layer(n_in=n_hidden, n_out=n_hidden, D=D, flavor=flavor,
                       residual=None, external_weights=False, init_scheme=init_scheme,
                       add_bias=True, seed=seed)
        self.V = Layer(n_in=n_hidden, n_out=n_hidden, D=D, flavor=flavor,
                       residual=None, external_weights=False, init_scheme=init_scheme,
                       add_bias=True, seed=seed)

        # Stack of width-preserving RGA blocks
        self.blocks = nnx.List([])
        for _ in range(num_blocks):
            self.blocks.append(
                RGABlock(n_in=n_hidden, n_out=n_hidden, n_hidden=n_hidden, D=D, flavor=flavor,
                         init_scheme=init_scheme, alpha=alpha, beta=beta, seed=seed)
            )

        # Output stage: least-squares basis coefficients or a standard Chebyshev layer
        if self.pi_init:
            C = self._pi_init(ref)
            self.OutBasis = nnx.Param(jnp.array(C))
            self.OutLayer = None
        else:
            self.OutBasis = None
            self.OutLayer = Layer(n_in=n_hidden, n_out=n_out, D=D, flavor=flavor,
                                  residual=None, external_weights=False, init_scheme=init_scheme,
                                  add_bias=True, seed=seed)

    def _features(self, x):
        """
        Runs the embedders, the gates and all RGA blocks (everything except the output stage).

        Args:
            x (jnp.array):
                Raw input array, shape (batch, n_in).

        Returns:
            h (jnp.array):
                Hidden representation fed to the output stage, shape (batch, n_hidden).
        """
        if self.PE is not None:
            x = self.PE(x)
        if self.FE is not None:
            x = self.FE(x)
        if self.SineBasis is not None:
            x = self.SineBasis(x)
        # The gate signals are computed once from the embedded input and shared by all blocks
        u = self.U(x)
        v = self.V(x)
        for block in self.blocks:
            x = block(x, u, v)
        return x

    def _pi_init(self, ref):
        """
        Performs physics-informed initialization for the output layer.

        The initial condition usol[0] is repeated at every (subsampled) time
        step. The output coefficients are then fitted by least squares so that
        the freshly initialized network reproduces it over the space-time
        collocation grid.

        Args:
            ref (dict):
                Reference data dictionary containing:
                - 't' (jnp.array): Temporal coordinates.
                - 'x' (jnp.array): Spatial coordinates.
                - 'y' (jnp.array, optional): Second spatial coordinate (3D data).
                - 'usol' (jnp.array): Solution array where usol[0] is the initial condition.

        Returns:
            C (jnp.array):
                Output basis coefficients, shape (1, *U.basis(h).shape[1:]),
                i.e. (1, n_hidden, D) when the basis has D functions per feature.

        Example:
            >>> ref = {'t': t_array, 'x': x_array, 'usol': u_solution}
            >>> C = model._pi_init(ref)
        """
        # Subsampling stride for the collocation grid (time always; space too for 3D data)
        stride = 10
        t = ref['t'].flatten()[::stride]  # (Nt,)
        if 'y' in ref:
            x = ref['x'].flatten()[::stride]  # (Nx,)
            y = ref['y'].flatten()[::stride]  # (Ny,)
            grids = jnp.meshgrid(t, x, y, indexing="ij")
            # Initial condition on the same subsampled spatial grid
            u_0 = ref['usol'][0, ::stride, ::stride]  # (Nx, Ny)
        else:
            x = ref['x'].flatten()  # (Nx,)
            grids = jnp.meshgrid(t, x, indexing="ij")
            u_0 = ref['usol'][0, :]  # (Nx,)

        # Collocation inputs, shape (batch, 2 or 3); 'ij' indexing makes t the
        # slowest-varying coordinate.
        inputs = jnp.stack([g.flatten() for g in grids], axis=-1)
        # Target at every time slice is the initial condition: repeat u_0 once
        # per t value, matching the t-major ordering of `inputs`.
        Y = jnp.tile(u_0.flatten(), t.shape[0]).reshape(-1, 1)  # (batch, 1)

        # Design matrix: basis expansion of the final hidden features
        Phi = self.U.basis(self._features(inputs))  # (batch, n_hidden, K)
        Phi_flat = Phi.reshape(Phi.shape[0], -1)  # (batch, n_hidden * K)
        coeffs, *_ = jnp.linalg.lstsq(Phi_flat, Y, rcond=None)  # (n_hidden * K, 1)
        # Reshape from Phi's actual trailing shape rather than assuming (n_hidden, D)
        C = coeffs.T.reshape(1, *Phi.shape[1:])
        return C

    def __call__(self, x):
        """
        Forward pass through the RGAKAN model.

        Args:
            x (jnp.array):
                Input array, shape (batch, n_in).

        Returns:
            y (jnp.array):
                Model output, shape (batch, n_out).

        Example:
            >>> model = RGAKAN(n_in=2, n_out=1, n_hidden=64, num_blocks=4, sine_D=5, seed=42)
            >>> x = jnp.ones((32, 2))
            >>> y = model(x) # Shape: (32, 1)
        """
        h = self._features(x)
        if self.pi_init:
            # U.basis is reused to evaluate the basis functions on h; the output
            # is their combination with the least-squares coefficients.
            B = self.U.basis(h)  # (batch, n_hidden, K)
            C = self.OutBasis.value  # (1, n_hidden, K)
            return jnp.einsum('bhk,ohk->bo', B, C)  # (batch, 1)
        return self.OutLayer(h)