API Reference
KAN Layer Classes
- jaxkan.layers.__init__.get_layer(layer_type)[source]
Helper method that creates a mapping between layer type codes and the actual classes.
- Parameters:
layer_type (str) – Code of layer to be used.
- Returns:
A jaxkan.layers layer class (not an instance) to be used as the building block of a KAN.
- Return type:
layer (jaxkan.layers.Layer)
Example
>>> LayerClass = get_layer("base")
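A code-to-class mapping of this kind can be sketched as a plain dictionary lookup. The classes below and every code other than "base" are hypothetical placeholders for illustration, not the actual jaxkan registry.

```python
# Hypothetical stand-ins for layer classes; the real ones live in jaxkan.layers.
class DemoBaseLayer:
    pass


class DemoChebyshevLayer:
    pass


# Assumed registry: only the "base" code appears in the reference above.
_LAYER_REGISTRY = {
    "base": DemoBaseLayer,
    "chebyshev": DemoChebyshevLayer,  # illustrative code, not confirmed
}


def get_layer_sketch(layer_type):
    """Return the class registered under layer_type, or raise on unknown codes."""
    try:
        return _LAYER_REGISTRY[layer_type]
    except KeyError:
        raise ValueError(f"Unknown layer type: {layer_type!r}") from None


LayerClass = get_layer_sketch("base")
layer = LayerClass()  # the caller instantiates the returned class
```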
- class jaxkan.layers.Spline.BaseLayer(n_in=2, n_out=5, k=3, G=3, grid_range=(-1, 1), grid_e=0.05, residual=nnx.silu, external_weights=True, init_scheme=None, add_bias=True, seed=42)[source]
Bases: Module
BaseLayer class. Corresponds to the original spline-based KAN layer introduced in the original KAN paper. Ref: https://arxiv.org/abs/2404.19756
- n_in
Number of layer’s incoming nodes.
- Type:
int
- n_out
Number of layer’s outgoing nodes.
- Type:
int
- k
Order of the spline basis functions.
- Type:
int
- residual
Function that is applied on samples to calculate residual activation.
- Type:
Union[nnx.Module, None]
- rngs
Random number generator state.
- Type:
nnx.Rngs
- c_spl
Spline weights if external_weights is True, else None.
- Type:
Union[nnx.Param, None]
- c_basis
Trainable coefficients for the basis functions.
- Type:
nnx.Param
- c_res
Trainable coefficients for residual activation if residual is not None.
- Type:
Union[nnx.Param, None]
- bias
Bias parameter if add_bias is True, else None.
- Type:
Union[nnx.Param, None]
- __init__(n_in=2, n_out=5, k=3, G=3, grid_range=(-1, 1), grid_e=0.05, residual=nnx.silu, external_weights=True, init_scheme=None, add_bias=True, seed=42)[source]
Initializes a BaseLayer instance.
- Parameters:
n_in (int) – Number of layer’s incoming nodes.
n_out (int) – Number of layer’s outgoing nodes.
k (int) – Order of the spline basis functions.
G (int) – Number of grid intervals.
grid_range (tuple) – An initial range for the grid’s ends, although adaptivity can completely change it.
grid_e (float) – Parameter that defines if the grids are uniform (grid_e = 1.0) or sample-dependent (grid_e = 0.0). Intermediate values correspond to a linear mixing of the two cases.
residual (Union[nnx.Module, None]) – Function that is applied on samples to calculate residual activation.
external_weights (bool) – Boolean that controls if the trainable weights of shape (n_out, n_in) applied to the splines should be used.
init_scheme (Union[dict, None]) – Dictionary that defines how the trainable parameters of the layer are initialized.
add_bias (bool) – Boolean that controls whether bias terms are included in the forward pass.
seed (int) – Random key selection for initializations wherever necessary.
Example
>>> layer = BaseLayer(n_in = 2, n_out = 5, k = 3,
...     G = 3, grid_range = (-1,1), grid_e = 0.05, residual = nnx.silu,
...     external_weights = True, init_scheme = None, add_bias = True,
...     seed = 42)
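The role of grid_e can be illustrated with a small sketch for a single input dimension. It assumes that the sample-dependent grid is built from evenly spaced ranks of the sorted samples and that the mixing is grid_e * uniform + (1 - grid_e) * adaptive; the function name and these details are assumptions, and NumPy is used in place of jax.numpy to keep the sketch dependency-light.

```python
import numpy as np


def mixed_grid(x, G, grid_e):
    """Blend a uniform grid with a sample-dependent grid (one input dimension)."""
    x_sorted = np.sort(x)
    # Sample-dependent knots: G + 1 points at evenly spaced ranks of the sorted samples.
    idx = np.linspace(0, len(x_sorted) - 1, G + 1).astype(int)
    adaptive = x_sorted[idx]
    # Uniform knots: G + 1 equally spaced points over the sample range.
    uniform = np.linspace(x_sorted[0], x_sorted[-1], G + 1)
    # Linear mixing controlled by grid_e (assumed form).
    return grid_e * uniform + (1.0 - grid_e) * adaptive


rng = np.random.default_rng(0)
x = rng.normal(size=200)
g_uniform = mixed_grid(x, G=3, grid_e=1.0)    # equally spaced knots
g_adaptive = mixed_grid(x, G=3, grid_e=0.0)   # knots follow the sample density
g_mixed = mixed_grid(x, G=3, grid_e=0.05)     # mostly sample-dependent
```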
- basis(x)[source]
Uses k and the current grid to calculate the values of spline basis functions on the input.
- Parameters:
x (jnp.array) – Inputs, shape (batch, n_in).
- Returns:
Spline basis functions applied on inputs, shape (n_in*n_out, G+k, batch).
- Return type:
basis_splines (jnp.array)
Example
>>> layer = BaseLayer(n_in = 2, n_out = 5, k = 3,
...     G = 3, grid_range = (-1,1), grid_e = 0.05, residual = nnx.silu,
...     external_weights = True, init_scheme = None, add_bias = True,
...     seed = 42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> output = layer.basis(x_batch)
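To see why a grid with G intervals and order k yields G+k basis functions, the following sketch evaluates B-splines with the standard Cox-de Boor recursion on a uniform grid extended by k knots on each side. It handles one input dimension and returns shape (batch, G+k) rather than the layer's (n_in*n_out, G+k, batch) layout; the function name and knot construction are assumptions, not the library's code.

```python
import numpy as np


def bspline_basis(x, knots, k):
    """Evaluate all order-k B-spline basis functions at x (Cox-de Boor recursion)."""
    x = np.asarray(x, dtype=float)[:, None]        # (batch, 1)
    t = np.asarray(knots, dtype=float)[None, :]    # (1, n_knots)
    # Order 0: indicator of each half-open knot interval.
    B = ((x >= t[:, :-1]) & (x < t[:, 1:])).astype(float)
    for p in range(1, k + 1):
        left = (x - t[:, :-(p + 1)]) / (t[:, p:-1] - t[:, :-(p + 1)])
        right = (t[:, p + 1:] - x) / (t[:, p + 1:] - t[:, 1:-p])
        B = left * B[:, :-1] + right * B[:, 1:]
    return B


G, k = 3, 3
h = 2.0 / G                                        # interval width on [-1, 1]
knots = np.linspace(-1.0 - k * h, 1.0 + k * h, G + 2 * k + 1)
x = np.linspace(-0.9, 0.9, 7)
B = bspline_basis(x, knots, k)                     # (7, G + k)
```

Inside the grid range the basis functions are non-negative and sum to one at every point, which is a quick sanity check for any implementation.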
- update_grid(x, G_new)[source]
Performs a grid update given a new value for G (i.e., G_new) and adapts it to the given data, x. Additionally, re-initializes the c_basis parameters to a better estimate, based on the new grid.
- Parameters:
x (jnp.array) – Inputs, shape (batch, n_in).
G_new (int) – Size of the new grid (in terms of intervals).
Example
>>> layer = BaseLayer(n_in = 2, n_out = 5, k = 3,
...     G = 3, grid_range = (-1,1), grid_e = 0.05, residual = nnx.silu,
...     external_weights = True, init_scheme = None, add_bias = True,
...     seed = 42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> layer.update_grid(x=x_batch, G_new=5)
- __call__(x)[source]
The layer’s forward pass.
- Parameters:
x (jnp.array) – Inputs, shape (batch, n_in).
- Returns:
Output of the forward pass, corresponding to the weighted sum of the B-spline activation and the residual activation, shape (batch, n_out).
- Return type:
y (jnp.array)
Example
>>> layer = BaseLayer(n_in = 2, n_out = 5, k = 3,
...     G = 3, grid_range = (-1,1), grid_e = 0.05, residual = nnx.silu,
...     external_weights = True, init_scheme = None, add_bias = True,
...     seed = 42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> output = layer(x_batch)
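The weighted sum described for the forward pass can be sketched as follows. The array layouts (basis of shape (batch, n_in, n_basis), coefficients of shape (n_out, n_in, n_basis), per-edge weights of shape (n_out, n_in)) and the function names are assumptions chosen for readability; the basis values are random placeholders rather than real spline evaluations.

```python
import numpy as np


def silu(x):
    """SiLU residual function, x * sigmoid(x)."""
    return x / (1.0 + np.exp(-x))


def kan_forward_sketch(x, basis, c_basis, c_spl, c_res, bias):
    # Spline activation per (output, input) edge: sum_j c_basis[o, i, j] * basis[b, i, j].
    spline = np.einsum("bij,oij->boi", basis, c_basis)      # (batch, n_out, n_in)
    residual = silu(x)[:, None, :]                           # (batch, 1, n_in)
    # Weighted sum of both activations on every edge, then sum over inputs.
    edges = c_spl[None] * spline + c_res[None] * residual    # (batch, n_out, n_in)
    return edges.sum(axis=-1) + bias                         # (batch, n_out)


rng = np.random.default_rng(0)
batch, n_in, n_out, n_basis = 100, 2, 5, 6
x = rng.uniform(-1.0, 1.0, size=(batch, n_in))
basis = rng.uniform(size=(batch, n_in, n_basis))   # placeholder basis values
c_basis = rng.normal(size=(n_out, n_in, n_basis))
c_spl = np.ones((n_out, n_in))
c_res = np.ones((n_out, n_in))
bias = np.zeros(n_out)
y = kan_forward_sketch(x, basis, c_basis, c_spl, c_res, bias)
```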
- class jaxkan.layers.Spline.SplineLayer(n_in=2, n_out=5, k=3, G=3, grid_range=(-1, 1), grid_e=0.05, residual=nnx.silu, external_weights=True, init_scheme=None, add_bias=True, seed=42)[source]
Bases: Module
SplineLayer class. Corresponds to the “efficient” version of the spline-based KAN layer. Ref: https://github.com/Blealtan/efficient-kan
- n_in
Number of layer’s incoming nodes.
- Type:
int
- n_out
Number of layer’s outgoing nodes.
- Type:
int
- k
Order of the spline basis functions.
- Type:
int
- residual
Function that is applied on samples to calculate residual activation.
- Type:
Union[nnx.Module, None]
- rngs
Random number generator state.
- Type:
nnx.Rngs
- grid
Grid object for spline basis functions.
- Type:
- c_spl
Spline weights if external_weights is True, else None.
- Type:
Union[nnx.Param, None]
- c_basis
Trainable coefficients for the basis functions.
- Type:
nnx.Param
- c_res
Trainable coefficients for residual activation if residual is not None.
- Type:
Union[nnx.Param, None]
- bias
Bias parameter if add_bias is True, else None.
- Type:
Union[nnx.Param, None]
- __init__(n_in=2, n_out=5, k=3, G=3, grid_range=(-1, 1), grid_e=0.05, residual=nnx.silu, external_weights=True, init_scheme=None, add_bias=True, seed=42)[source]
Initializes a SplineLayer instance.
- Parameters:
n_in (int) – Number of layer’s incoming nodes.
n_out (int) – Number of layer’s outgoing nodes.
k (int) – Order of the spline basis functions.
G (int) – Number of grid intervals.
grid_range (tuple) – An initial range for the grid’s ends, although adaptivity can completely change it.
grid_e (float) – Parameter that defines if the grids are uniform (grid_e = 1.0) or sample-dependent (grid_e = 0.0). Intermediate values correspond to a linear mixing of the two cases.
residual (Union[nnx.Module, None]) – Function that is applied on samples to calculate residual activation.
external_weights (bool) – Boolean that controls if the trainable weights of shape (n_out, n_in) applied to the splines should be used.
init_scheme (Union[dict, None]) – Dictionary that defines how the trainable parameters of the layer are initialized.
add_bias (bool) – Boolean that controls whether bias terms are included in the forward pass.
seed (int) – Random key selection for initializations wherever necessary.
Example
>>> layer = SplineLayer(n_in = 2, n_out = 5, k = 3,
...     G = 3, grid_range = (-1,1), grid_e = 0.05, residual = nnx.silu,
...     external_weights = True, init_scheme = None, add_bias = True,
...     seed = 42)
- basis(x)[source]
Uses k and the current grid to calculate the values of spline basis functions on the input.
- Parameters:
x (jnp.array) – Inputs, shape (batch, n_in).
- Returns:
Spline basis functions applied on inputs, shape (n_in*n_out, G+k, batch).
- Return type:
basis_splines (jnp.array)
Example
>>> layer = SplineLayer(n_in = 2, n_out = 5, k = 3,
...     G = 3, grid_range = (-1,1), grid_e = 0.05, residual = nnx.silu,
...     external_weights = True, init_scheme = None, add_bias = True,
...     seed = 42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> output = layer.basis(x_batch)
- update_grid(x, G_new)[source]
Performs a grid update given a new value for G (i.e., G_new) and adapts it to the given data, x. Additionally, re-initializes the c_basis parameters to a better estimate, based on the new grid.
- Parameters:
x (jnp.array) – Inputs, shape (batch, n_in).
G_new (int) – Size of the new grid (in terms of intervals).
Example
>>> layer = SplineLayer(n_in = 2, n_out = 5, k = 3,
...     G = 3, grid_range = (-1,1), grid_e = 0.05, residual = nnx.silu,
...     external_weights = True, init_scheme = None, add_bias = True,
...     seed = 42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> layer.update_grid(x=x_batch, G_new=5)
- __call__(x)[source]
The layer’s forward pass.
- Parameters:
x (jnp.array) – Inputs, shape (batch, n_in).
- Returns:
Output of the forward pass, corresponding to the weighted sum of the B-spline activation and the residual activation, shape (batch, n_out).
- Return type:
y (jnp.array)
Example
>>> layer = SplineLayer(n_in = 2, n_out = 5, k = 3,
...     G = 3, grid_range = (-1,1), grid_e = 0.05, residual = nnx.silu,
...     external_weights = True, init_scheme = None, add_bias = True,
...     seed = 42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> output = layer(x_batch)
- class jaxkan.layers.Chebyshev.ChebyshevLayer(n_in=2, n_out=5, D=5, flavor=None, residual=None, external_weights=False, init_scheme=None, add_bias=True, seed=42)[source]
Bases: Module
ChebyshevLayer class. Corresponds to the Chebyshev version of KANs and comes in three “flavors”:
- “default”: the version presented in https://arxiv.org/pdf/2405.07200
- “modified”: the version presented in https://www.sciencedirect.com/science/article/pii/S0045782524005462
- “exact”: uses pre-defined functions for higher efficiency, but cannot scale up to arbitrary degrees
- n_in
Number of layer’s incoming nodes.
- Type:
int
- n_out
Number of layer’s outgoing nodes.
- Type:
int
- D
Degree of Chebyshev polynomial (1st kind).
- Type:
int
- flavor
One of “default”, “modified”, or “exact” - chooses basis implementation.
- Type:
Union[str, None]
- residual
Function that is applied on samples to calculate residual activation.
- Type:
Union[nnx.Module, None]
- rngs
Random number generator state.
- Type:
nnx.Rngs
- bias
Bias parameter if add_bias is True, else None.
- Type:
Union[nnx.Param, None]
- c_ext
External weights if external_weights is True, else None.
- Type:
Union[nnx.Param, None]
- c_basis
Trainable coefficients for the basis functions.
- Type:
nnx.Param
- c_res
Trainable coefficients for residual activation if residual is not None.
- Type:
Union[nnx.Param, None]
- __init__(n_in=2, n_out=5, D=5, flavor=None, residual=None, external_weights=False, init_scheme=None, add_bias=True, seed=42)[source]
Initializes a ChebyshevLayer instance.
- Parameters:
n_in (int) – Number of layer’s incoming nodes.
n_out (int) – Number of layer’s outgoing nodes.
D (int) – Degree of Chebyshev polynomial (1st kind).
flavor (Union[str, None]) – One of “default”, “modified”, or “exact” - chooses basis implementation.
residual (Union[nnx.Module, None]) – Function that is applied on samples to calculate residual activation.
external_weights (bool) – Boolean that controls if the trainable weights (n_out, n_in) should be applied to the activations.
init_scheme (Union[dict, None]) – Dictionary that defines how the trainable parameters of the layer are initialized.
add_bias (bool) – Boolean that controls whether bias terms are included in the forward pass.
seed (int) – Random key selection for initializations wherever necessary.
Example
>>> layer = ChebyshevLayer(n_in = 2, n_out = 5, D = 5, flavor = "default",
...     residual = None, external_weights = False, init_scheme = None,
...     add_bias = True, seed = 42)
- basis(x)[source]
Based on the degree and flavor, the values of the Chebyshev basis functions are calculated on the input.
- Parameters:
x (jnp.array) – Inputs, shape (batch, n_in).
- Returns:
Chebyshev basis functions applied on inputs, shape (batch, n_in, D+1).
- Return type:
cheb (jnp.array)
Example
>>> layer = ChebyshevLayer(n_in = 2, n_out = 5, D = 5, flavor = "default",
...     residual = None, external_weights = False, init_scheme = None,
...     add_bias = True, seed = 42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> output = layer.basis(x_batch)
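The (batch, n_in, D+1) output shape follows from evaluating T_0 through T_D on every input. A minimal sketch using the standard three-term recurrence T_{n+1}(x) = 2x T_n(x) - T_{n-1}(x) is shown below; it illustrates the mathematics only and does not claim to reproduce any of the layer's flavors (for instance, any input rescaling they may apply).

```python
import numpy as np


def chebyshev_basis(x, D):
    """First-kind Chebyshev polynomials T_0..T_D evaluated elementwise on x."""
    x = np.asarray(x, dtype=float)
    T = [np.ones_like(x), x]
    for _ in range(2, D + 1):
        T.append(2.0 * x * T[-1] - T[-2])   # three-term recurrence
    return np.stack(T[: D + 1], axis=-1)    # (..., D + 1)


rng = np.random.default_rng(0)
x = rng.uniform(-1.0, 1.0, size=(100, 2))
D = 5
cheb = chebyshev_basis(x, D)                # (100, 2, 6)
```

On [-1, 1] the result agrees with the closed form T_n(x) = cos(n arccos x).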
- update_grid(x, D_new)[source]
For the case of ChebyKANs there is no concept of grid. However, a fine-graining approach can be followed by progressively increasing the degree of the polynomials.
- Parameters:
x (jnp.array) – Inputs, shape (batch, n_in).
D_new (int) – New Chebyshev polynomial degree.
Example
>>> layer = ChebyshevLayer(n_in = 2, n_out = 5, D = 5, flavor = "default",
...     residual = None, external_weights = False, init_scheme = None,
...     add_bias = True, seed = 42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> layer.update_grid(x=x_batch, D_new=8)
- __call__(x)[source]
The layer’s forward pass.
- Parameters:
x (jnp.array) – Inputs, shape (batch, n_in).
- Returns:
Output of the forward pass, shape (batch, n_out).
- Return type:
y (jnp.array)
Example
>>> layer = ChebyshevLayer(n_in = 2, n_out = 5, D = 5, flavor = "default",
...     residual = None, external_weights = False, init_scheme = None,
...     add_bias = True, seed = 42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> output = layer(x_batch)
- class jaxkan.layers.Legendre.LegendreLayer(n_in=2, n_out=5, D=5, flavor=None, residual=None, external_weights=False, init_scheme=None, add_bias=True, seed=42)[source]
Bases: Module
LegendreLayer class. Corresponds to the Legendre version of KANs and comes in two “flavors”:
- “default”: uses the recursion formula to calculate polynomials up to arbitrary degree
- “exact”: uses pre-defined functions for higher efficiency, but cannot scale up to arbitrary degrees
- n_in
Number of layer’s incoming nodes.
- Type:
int
- n_out
Number of layer’s outgoing nodes.
- Type:
int
- D
Degree of Legendre polynomial.
- Type:
int
- flavor
One of “default” or “exact” - chooses basis implementation.
- Type:
str
- residual
Function that is applied on samples to calculate residual activation.
- Type:
Union[nnx.Module, None]
- rngs
Random number generator state.
- Type:
nnx.Rngs
- bias
Bias parameter if add_bias is True, else None.
- Type:
Union[nnx.Param, None]
- c_ext
External weights if external_weights is True, else None.
- Type:
Union[nnx.Param, None]
- c_basis
Trainable coefficients for the basis functions.
- Type:
nnx.Param
- c_res
Trainable coefficients for residual activation if residual is not None.
- Type:
Union[nnx.Param, None]
- __init__(n_in=2, n_out=5, D=5, flavor=None, residual=None, external_weights=False, init_scheme=None, add_bias=True, seed=42)[source]
Initializes a LegendreLayer instance.
- Parameters:
n_in (int) – Number of layer’s incoming nodes.
n_out (int) – Number of layer’s outgoing nodes.
D (int) – Degree of Legendre polynomial.
flavor (Union[str, None]) – One of “default” or “exact” - chooses basis implementation.
residual (Union[nnx.Module, None]) – Function that is applied on samples to calculate residual activation.
external_weights (bool) – Boolean that controls if the trainable weights (n_out, n_in) should be applied to the activations.
init_scheme (Union[dict, None]) – Dictionary that defines how the trainable parameters of the layer are initialized.
add_bias (bool) – Boolean that controls whether bias terms are included in the forward pass.
seed (int) – Random key selection for initializations wherever necessary.
Example
>>> layer = LegendreLayer(n_in = 2, n_out = 5, D = 5, flavor = "default",
...     residual = None, external_weights = False, init_scheme = None,
...     add_bias = True, seed = 42)
- basis(x)[source]
Based on the degree and flavor, the values of the Legendre basis functions are calculated on the input.
- Parameters:
x (jnp.array) – Inputs, shape (batch, n_in).
- Returns:
Legendre basis functions applied on inputs, shape (batch, n_in, D+1).
- Return type:
leg (jnp.array)
Example
>>> layer = LegendreLayer(n_in = 2, n_out = 5, D = 5, flavor = "default",
...     residual = None, external_weights = False, init_scheme = None,
...     add_bias = True, seed = 42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> output = layer.basis(x_batch)
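The recursion mentioned for the “default” flavor is presumably Bonnet's formula, (n+1) P_{n+1}(x) = (2n+1) x P_n(x) - n P_{n-1}(x). The sketch below applies it to produce P_0 through P_D with shape (batch, n_in, D+1); the function name is hypothetical and the code is an illustration, not the library's implementation.

```python
import numpy as np


def legendre_basis(x, D):
    """Legendre polynomials P_0..P_D evaluated elementwise on x (Bonnet's recursion)."""
    x = np.asarray(x, dtype=float)
    P = [np.ones_like(x), x]
    for n in range(1, D):
        P.append(((2 * n + 1) * x * P[n] - n * P[n - 1]) / (n + 1))
    return np.stack(P[: D + 1], axis=-1)    # (..., D + 1)


rng = np.random.default_rng(0)
x = rng.uniform(-1.0, 1.0, size=(100, 2))
leg = legendre_basis(x, D=5)                # (100, 2, 6)
```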
- update_grid(x, D_new)[source]
For the case of LegendreKANs there is no concept of grid. However, a fine-graining approach can be followed by progressively increasing the degree of the polynomials.
- Parameters:
x (jnp.array) – Inputs, shape (batch, n_in).
D_new (int) – New Legendre polynomial degree.
Example
>>> layer = LegendreLayer(n_in = 2, n_out = 5, D = 5, flavor = "default",
...     residual = None, external_weights = False, init_scheme = None,
...     add_bias = True, seed = 42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> layer.update_grid(x=x_batch, D_new=8)
- __call__(x)[source]
The layer’s forward pass.
- Parameters:
x (jnp.array) – Inputs, shape (batch, n_in).
- Returns:
Output of the forward pass, shape (batch, n_out).
- Return type:
y (jnp.array)
Example
>>> layer = LegendreLayer(n_in = 2, n_out = 5, D = 5, flavor = "default",
...     residual = None, external_weights = False, init_scheme = None,
...     add_bias = True, seed = 42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> output = layer(x_batch)
- class jaxkan.layers.Fourier.FourierLayer(n_in=2, n_out=5, D=5, smooth_init=True, add_bias=True, seed=42)[source]
Bases: Module
FourierLayer class. Corresponds to the Fourier-based version of KANs (FourierKAN). Ref: https://github.com/GistNoesis/FourierKAN
- n_in
Number of layer’s incoming nodes.
- Type:
int
- n_out
Number of layer’s outgoing nodes.
- Type:
int
- D
Order of Fourier sum.
- Type:
int
- bias
Bias parameter if add_bias is True, else None.
- Type:
Union[nnx.Param, None]
- c_cos
Trainable cosine coefficients.
- Type:
nnx.Param
- c_sin
Trainable sine coefficients.
- Type:
nnx.Param
- __init__(n_in=2, n_out=5, D=5, smooth_init=True, add_bias=True, seed=42)[source]
Initializes a FourierLayer instance.
- Parameters:
n_in (int) – Number of layer’s incoming nodes.
n_out (int) – Number of layer’s outgoing nodes.
D (int) – Order of Fourier sum.
smooth_init (bool) – Whether to initialize Fourier coefficients with smoothening.
add_bias (bool) – Boolean that controls whether bias terms are included in the forward pass.
seed (int) – Random key selection for initializations wherever necessary.
Example
>>> layer = FourierLayer(n_in = 2, n_out = 5, D = 5, smooth_init = True, add_bias = True, seed = 42)
- basis(x)[source]
Calculates the cos/sin activations on the input x.
- Parameters:
x (jnp.array) – Inputs, shape (batch, n_in).
- Returns:
Cosines and sines applied on inputs, each of shape (batch, n_in, D).
- Return type:
c, s (tuple)
Example
>>> layer = FourierLayer(n_in = 2, n_out = 5, D = 5, smooth_init = True, add_bias = True, seed = 42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> output_1, output_2 = layer.basis(x_batch)
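A pair of cosine and sine feature arrays with the documented shape can be sketched as below. The choice of integer frequencies 1..D is an assumption made for illustration; the layer may use a different frequency convention.

```python
import numpy as np


def fourier_basis(x, D):
    """Return cos(k x) and sin(k x) for k = 1..D, each of shape (..., D)."""
    k = np.arange(1, D + 1)                 # assumed integer frequencies
    arg = np.asarray(x, dtype=float)[..., None] * k
    return np.cos(arg), np.sin(arg)


rng = np.random.default_rng(0)
x = rng.uniform(-1.0, 1.0, size=(100, 2))
c, s = fourier_basis(x, D=5)                # each (100, 2, 5)
```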
- update_grid(x, D_new)[source]
For the case of FourierKAN there is no concept of grid. However, a fine-graining approach can be followed by progressively increasing the number of summands.
- Parameters:
x (jnp.array) – Inputs, shape (batch, n_in).
D_new (int) – New value for the Fourier sum’s order.
Example
>>> layer = FourierLayer(n_in = 2, n_out = 5, D = 5, smooth_init = True, add_bias = True, seed = 42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> layer.update_grid(x=x_batch, D_new=8)
- __call__(x)[source]
The layer’s forward pass.
- Parameters:
x (jnp.array) – Inputs, shape (batch, n_in).
- Returns:
Output of the forward pass, shape (batch, n_out).
- Return type:
y (jnp.array)
Example
>>> layer = FourierLayer(n_in = 2, n_out = 5, D = 5, smooth_init = True, add_bias = True, seed = 42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> output = layer(x_batch)
- class jaxkan.layers.RBF.RBFLayer(n_in=2, n_out=5, D=5, kernel=None, grid_range=(-2.0, 2.0), grid_e=1.0, residual=None, external_weights=False, init_scheme=None, add_bias=True, seed=42)[source]
Bases: Module
RBFLayer class. Corresponds to the RBF version of KANs using different kernels.
- n_in
Number of layer’s incoming nodes.
- Type:
int
- n_out
Number of layer’s outgoing nodes.
- Type:
int
- D
Number of basis functions.
- Type:
int
- kernel
Kernel configuration for the RBFs.
- Type:
dict
- residual
Function that is applied on samples to calculate residual activation.
- Type:
Union[nnx.Module, None]
- rngs
Random number generator state.
- Type:
nnx.Rngs
- bias
Bias parameter if add_bias is True, else None.
- Type:
Union[nnx.Param, None]
- c_ext
External weights if external_weights is True, else None.
- Type:
Union[nnx.Param, None]
- c_basis
Trainable coefficients for the basis functions.
- Type:
nnx.Param
- c_res
Trainable coefficients for residual activation if residual is not None.
- Type:
Union[nnx.Param, None]
- __init__(n_in=2, n_out=5, D=5, kernel=None, grid_range=(-2.0, 2.0), grid_e=1.0, residual=None, external_weights=False, init_scheme=None, add_bias=True, seed=42)[source]
Initializes an RBFLayer instance.
- Parameters:
n_in (int) – Number of layer’s incoming nodes.
n_out (int) – Number of layer’s outgoing nodes.
D (int) – Number of basis functions.
kernel (Union[dict, None]) – Kernel to be used for the RBFs.
grid_range (tuple) – An initial range for the grid’s ends, although adaptivity can completely change it.
grid_e (float) – Parameter that defines if the grids are uniform (grid_e = 1.0) or sample-dependent (grid_e = 0.0). Intermediate values correspond to a linear mixing of the two cases.
residual (Union[nnx.Module, None]) – Function that is applied on samples to calculate residual activation.
external_weights (bool) – Boolean that controls if the trainable weights (n_out, n_in) should be applied to the activations.
init_scheme (Union[dict, None]) – Dictionary that defines how the trainable parameters of the layer are initialized.
add_bias (bool) – Boolean that controls whether bias terms are included in the forward pass.
seed (int) – Random key selection for initializations wherever necessary.
Example
>>> layer = RBFLayer(n_in = 2, n_out = 5, D = 4, kernel = "gaussian",
...     grid_range = (-2.0, 2.0), grid_e = 1.0, residual = None,
...     external_weights = False, init_scheme = None, add_bias = True,
...     seed = 42)
- basis(x)[source]
Based on the number of basis functions and the kernel, the values of the radial basis functions are calculated on the input.
- Parameters:
x (jnp.array) – Inputs, shape (batch, n_in).
- Returns:
Radial basis functions applied on inputs, shape (batch, n_in, D).
- Return type:
rbf (jnp.array)
Example
>>> layer = RBFLayer(n_in = 2, n_out = 5, D = 4, kernel = "gaussian",
...     grid_range = (-2.0, 2.0), grid_e = 1.0, residual = None,
...     external_weights = False, init_scheme = None, add_bias = True,
...     seed = 42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-2.0, maxval=2.0)
>>>
>>> output = layer.basis(x_batch)
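For a Gaussian kernel, the basis values can be sketched as exp(-((x - c_j) / w)^2) over D centers c_j. The sketch below places the centers uniformly on the grid range and sets the width w to the center spacing; both choices, and the function name, are assumptions for illustration rather than the layer's actual kernel configuration.

```python
import numpy as np


def gaussian_rbf_basis(x, centers, width):
    """Gaussian radial basis functions evaluated elementwise, shape (..., D)."""
    z = (np.asarray(x, dtype=float)[..., None] - centers) / width
    return np.exp(-z**2)


D = 4
centers = np.linspace(-2.0, 2.0, D)         # uniform centers on the grid range
width = centers[1] - centers[0]             # assumed width: the center spacing
rng = np.random.default_rng(0)
x = rng.uniform(-2.0, 2.0, size=(100, 2))
rbf = gaussian_rbf_basis(x, centers, width) # (100, 2, 4)
```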
- update_grid(x, D_new)[source]
Performs a grid update given a new value for D (i.e., D_new) and adapts it to the given data, x. Additionally, re-initializes the c_basis parameters to a better estimate, based on the new grid.
- Parameters:
x (jnp.array) – Inputs, shape (batch, n_in).
D_new (int) – New number of radial basis functions.
Example
>>> layer = RBFLayer(n_in = 2, n_out = 5, D = 4, kernel = "gaussian",
...     grid_range = (-2.0, 2.0), grid_e = 1.0, residual = None,
...     external_weights = False, init_scheme = None, add_bias = True,
...     seed = 42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-2.0, maxval=2.0)
>>>
>>> layer.update_grid(x=x_batch, D_new=8)
- __call__(x)[source]
The layer’s forward pass.
- Parameters:
x (jnp.array) – Inputs, shape (batch, n_in).
- Returns:
Output of the forward pass, shape (batch, n_out).
- Return type:
y (jnp.array)
Example
>>> layer = RBFLayer(n_in = 2, n_out = 5, D = 4, kernel = "gaussian",
...     grid_range = (-2.0, 2.0), grid_e = 1.0, residual = None,
...     external_weights = False, init_scheme = None, add_bias = True,
...     seed = 42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-2.0, maxval=2.0)
>>>
>>> output = layer(x_batch)
- class jaxkan.layers.Sine.SineLayer(n_in=2, n_out=5, D=5, residual=None, external_weights=False, init_scheme=None, add_bias=True, seed=42)[source]
Bases: Module
SineLayer class, inspired by the sine basis functions introduced in https://arxiv.org/pdf/2410.01990
- n_in
Number of layer’s incoming nodes.
- Type:
int
- n_out
Number of layer’s outgoing nodes.
- Type:
int
- D
Number of basis functions.
- Type:
int
- residual
Function that is applied on samples to calculate residual activation.
- Type:
Union[nnx.Module, None]
- rngs
Random number generator state.
- Type:
nnx.Rngs
- bias
Bias parameter if add_bias is True, else None.
- Type:
Union[nnx.Param, None]
- omega
Trainable frequency parameters.
- Type:
nnx.Param
- phase
Trainable phase parameters.
- Type:
nnx.Param
- c_ext
External weights if external_weights is True, else None.
- Type:
Union[nnx.Param, None]
- c_basis
Trainable coefficients for the basis functions.
- Type:
nnx.Param
- c_res
Trainable coefficients for residual activation if residual is not None.
- Type:
Union[nnx.Param, None]
- __init__(n_in=2, n_out=5, D=5, residual=None, external_weights=False, init_scheme=None, add_bias=True, seed=42)[source]
Initializes a SineLayer instance.
- Parameters:
n_in (int) – Number of layer’s incoming nodes.
n_out (int) – Number of layer’s outgoing nodes.
D (int) – Number of basis functions.
residual (Union[nnx.Module, None]) – Function that is applied on samples to calculate residual activation.
external_weights (bool) – Boolean that controls if the trainable weights (n_out, n_in) should be applied to the activations.
init_scheme (Union[dict, None]) – Dictionary that defines how the trainable parameters of the layer are initialized.
add_bias (bool) – Boolean that controls whether bias terms are included in the forward pass.
seed (int) – Random key selection for initializations wherever necessary.
Example
>>> layer = SineLayer(n_in = 2, n_out = 5, D = 5, residual = None, external_weights = False,
...     init_scheme = None, add_bias = True, seed = 42)
- basis(x)[source]
Calculates the application of the sine basis functions on the input.
- Parameters:
x (jnp.array) – Inputs, shape (batch, n_in).
- Returns:
Sine basis functions applied on inputs, shape (batch, n_in, D).
- Return type:
B (jnp.array)
Example
>>> layer = SineLayer(n_in = 2, n_out = 5, D = 5, residual = None, external_weights = False,
...     init_scheme = None, add_bias = True, seed = 42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> output = layer.basis(x_batch)
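With frequencies omega and phases phase, a sine basis of the documented shape can be sketched as sin(omega_j * x + phase_j). The shapes of omega and phase (one value per basis function) and their initial values below are assumptions; in the layer they are trainable parameters whose layout may differ.

```python
import numpy as np


def sine_basis(x, omega, phase):
    """Sine basis sin(omega_j * x + phase_j), shape (..., D)."""
    return np.sin(omega * np.asarray(x, dtype=float)[..., None] + phase)


D = 5
omega = np.arange(1, D + 1, dtype=float)    # illustrative initial frequencies
phase = np.zeros(D)                         # illustrative initial phases
rng = np.random.default_rng(0)
x = rng.uniform(-1.0, 1.0, size=(100, 2))
B = sine_basis(x, omega, phase)             # (100, 2, 5)
```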
- update_grid(x, D_new)[source]
For the case of sine-based KANs there is no concept of grid. However, a fine-graining approach can be followed by progressively increasing the number of basis functions, and by extension phases and omegas.
- Parameters:
x (jnp.array) – Inputs, shape (batch, n_in).
D_new (int) – New number of basis functions.
Example
>>> layer = SineLayer(n_in = 2, n_out = 5, D = 5, residual = None, external_weights = False,
...     init_scheme = None, add_bias = True, seed = 42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> layer.update_grid(x=x_batch, D_new=8)
- __call__(x)[source]
The layer’s forward pass.
- Parameters:
x (jnp.array) – Inputs, shape (batch, n_in).
- Returns:
Output of the forward pass, shape (batch, n_out).
- Return type:
y (jnp.array)
Example
>>> layer = SineLayer(n_in = 2, n_out = 5, D = 5, residual = None, external_weights = False,
...     init_scheme = None, add_bias = True, seed = 42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> output = layer(x_batch)
- class jaxkan.layers.Dense.DenseLayer(n_in, n_out, activation=None, RWF={'mean': 1.0, 'std': 0.1}, add_bias=True, seed=42)[source]
Bases: Module
Dense layer with random weight factorization (RWF) for use in MLP architectures.
Note: This is not a KAN layer, but a standard MLP building block used in advanced KAN architectures like KKAN (see jaxkan.models module).
- g
Scale factor vector of shape (n_out,) from the RWF reparameterization.
- Type:
nnx.Param
- v
Direction matrix of shape (n_in, n_out) from the RWF reparameterization.
- Type:
nnx.Param
- b
Bias vector of shape (n_out,), or None if add_bias is False.
- Type:
nnx.Param or None
- activation
Activation function applied after the linear transformation, or None.
- Type:
callable or None
- __init__(n_in, n_out, activation=None, RWF={'mean': 1.0, 'std': 0.1}, add_bias=True, seed=42)[source]
Initializes a Dense layer with RWF.
- Parameters:
n_in (int) – Number of input features.
n_out (int) – Number of output features.
activation (callable, optional) – Activation function applied after the linear transformation. Defaults to None.
RWF (dict, optional) – Dictionary with keys 'mean' and 'std' controlling the log-normal scale of the RWF reparameterization. Defaults to {"mean": 1.0, "std": 0.1}.
add_bias (bool, optional) – Whether to include a learnable bias term. Defaults to True.
seed (int, optional) – Random seed for parameter initialization. Defaults to 42.
Example
>>> layer = DenseLayer(n_in=64, n_out=32, add_bias=True, seed=42)
- __call__(x)[source]
Applies the dense layer to the input.
- Parameters:
x (jnp.ndarray) – Input array of shape (batch, n_in).
- Returns:
Output array of shape (batch, n_out).
- Return type:
jnp.ndarray
Example
>>> layer = DenseLayer(n_in=4, n_out=2)
>>> x = jnp.ones((3, 4))
>>> y = layer(x)  # shape: (3, 2)
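The roles of g and v can be illustrated with a sketch of random weight factorization, in which the effective kernel is the product g * v, with g drawn from a log-normal distribution. The Glorot-uniform draw for the underlying weights and the function names are assumptions for illustration, not the layer's actual initializer.

```python
import numpy as np


def rwf_init(rng, n_in, n_out, mean=1.0, std=0.1):
    """Factorize an initial kernel w into a scale g (n_out,) and direction v (n_in, n_out)."""
    limit = np.sqrt(6.0 / (n_in + n_out))                 # assumed Glorot-uniform init
    w = rng.uniform(-limit, limit, size=(n_in, n_out))
    g = np.exp(rng.normal(mean, std, size=(n_out,)))      # log-normal scale factors
    v = w / g                                             # so that g * v recovers w
    return g, v


def dense_forward(x, g, v, b=None, activation=None):
    """Apply the dense layer with effective kernel g * v."""
    y = x @ (g * v)
    if b is not None:
        y = y + b
    return activation(y) if activation is not None else y


rng = np.random.default_rng(42)
g, v = rwf_init(rng, n_in=4, n_out=2)
b = np.zeros(2)
x = np.ones((3, 4))
y = dense_forward(x, g, v, b)                             # (3, 2)
```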
- jaxkan.layers.utils.solve_single_lstsq(A_single, B_single)[source]
Simulates linalg.lstsq by reformulating the problem AX = B via the normal equations: (A^T A) X = A^T B. This is used instead of linalg.lstsq because it’s much faster.
- Parameters:
A_single (jnp.array) – Matrix A of AX = B, shape (M, N).
B_single (jnp.array) – Matrix B of AX = B, shape (M, K).
- Returns:
Matrix X of AX = B, shape (N, K).
- Return type:
single_solution (jnp.array)
Example
>>> A = jnp.array([[2.0, 1.0], [1.0, 3.0]])
>>> B = jnp.array([[1.0], [2.0]])
>>>
>>> solution = solve_single_lstsq(A, B)
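The normal-equations reformulation can be checked against a reference least-squares solver. The sketch below uses NumPy rather than JAX and a hypothetical function name; it illustrates the technique, not the library's code.

```python
import numpy as np


def solve_normal_equations(A, B):
    """Least-squares solution of A X = B via (A^T A) X = A^T B."""
    return np.linalg.solve(A.T @ A, A.T @ B)


A = np.array([[2.0, 1.0], [1.0, 3.0]])
B = np.array([[1.0], [2.0]])
X = solve_normal_equations(A, B)                    # shape (N, K) = (2, 1)
X_ref = np.linalg.lstsq(A, B, rcond=None)[0]        # reference solution
```

One trade-off to keep in mind: forming A^T A squares the condition number of the problem, so this route can lose accuracy when A is ill-conditioned.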
- jaxkan.layers.utils.solve_full_lstsq(A_full, B_full)[source]
Parallelizes the single case, so that the problem can be solved for matrices with dimensions higher than 2. Essentially, solve_single_lstsq and solve_full_lstsq combined are a workaround, because (unlike PyTorch for example), JAX’s libraries do not support lstsq for dims > 2.
- Parameters:
A_full (jnp.array) – Matrix A of AX = B, shape (batch, M, N).
B_full (jnp.array) – Matrix B of AX = B, shape (batch, M, K).
- Returns:
Matrix X of AX = B, shape (batch, N, K).
- Return type:
full_solution (jnp.array)
Example
>>> A = jnp.array([[[2.0, 1.0], [1.0, 3.0]], [[1.0, 2.0], [2.0, 1.0]]])
>>> B = jnp.array([[[1.0], [2.0]], [[2.0], [3.0]]])
>>>
>>> solution = solve_full_lstsq(A, B)
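The batched case can be sketched by applying the same normal-equations step to every matrix along the leading axis. The NumPy version below relies on the batch support of np.linalg.solve; in JAX, a similar effect could presumably be obtained by mapping a single-problem solver over the batch axis with jax.vmap, though that is an assumption about the implementation.

```python
import numpy as np


def solve_batched_normal_equations(A, B):
    """Solve (A^T A) X = A^T B independently for each matrix in the batch."""
    At = np.swapaxes(A, -1, -2)                     # transpose the last two axes
    return np.linalg.solve(At @ A, At @ B)


A = np.array([[[2.0, 1.0], [1.0, 3.0]], [[1.0, 2.0], [2.0, 1.0]]])
B = np.array([[[1.0], [2.0]], [[2.0], [3.0]]])
X = solve_batched_normal_equations(A, B)            # shape (batch, N, K) = (2, 2, 1)
```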
- jaxkan.layers.utils.interpolate_moments(mu_old, nu_old, new_shape)[source]
Performs a linear interpolation to assign values to the first- and second-order moments of the gradients of the basis-function coefficients c_i after grid extension.
- Parameters:
mu_old (jnp.array) – First-order moments before extension, shape (n_in*n_out, num_basis) or (n_out, n_in, num_basis).
nu_old (jnp.array) – Second-order moments before extension, shape (n_in*n_out, num_basis) or (n_out, n_in, num_basis).
new_shape (tuple) – The new desired shape, either (n_in*n_out, new_num_basis) or (n_out, n_in, new_num_basis).
- Returns:
First- and second-order moments after extension, shape new_shape.
- Return type:
mu_new, nu_new (tuple)
Example
>>> mu_old = jnp.array([[1, 2, 3], [4, 5, 6]])
>>> nu_old = jnp.array([[7, 8, 9], [10, 11, 12]])
>>> new_shape = (2, 5)
>>>
>>> mu_new, nu_new = interpolate_moments(mu_old, nu_old, new_shape)
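To make the linear interpolation of moments concrete, here is a minimal plain-Python sketch that resamples each row onto a longer, evenly spaced index grid with the endpoints kept fixed. The endpoint-aligned resampling rule and the helper name resample_linear are assumptions for illustration; the library's exact interpolation scheme may differ.

```python
def resample_linear(row, new_len):
    """Linearly resample a 1-D list onto new_len evenly spaced positions."""
    old_len = len(row)
    if new_len == 1 or old_len == 1:
        return [row[0]] * new_len
    out = []
    for j in range(new_len):
        pos = j * (old_len - 1) / (new_len - 1)  # position in old index space
        lo = min(int(pos), old_len - 2)          # left neighbour index
        frac = pos - lo                          # weight of the right neighbour
        out.append((1 - frac) * row[lo] + frac * row[lo + 1])
    return out

mu_old = [[1, 2, 3], [4, 5, 6]]
mu_new = [resample_linear(row, 5) for row in mu_old]  # shape (2, 5)
```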
- jaxkan.layers.utils.adam_transition(old_state, model_state)[source]
Performs the state transition for the Adam optimizer with scheduler after grid extension. Note that the transition happens in-place, i.e. nothing is returned, the optimizer is simply transitioned from the old state to the new.
- Parameters:
old_state (tuple) – Collection of Adam state and scheduler state before extension.
model_state (dict) – Dict of KAN model state after split.
Example
>>> old_state = optimizer.opt_state
>>> _, model_state = nnx.split(model)
>>> adam_transition(old_state, model_state)
KAN Grid Classes
- class jaxkan.grids.BaseGrid.BaseGrid(n_in=2, n_out=5, k=3, G=3, grid_range=(-1, 1), grid_e=0.05)[source]
Bases:
object
BaseGrid class, corresponding to the grid of the BaseLayer class. It comprises an initialization as well as an update procedure.
- n_in
Number of layer’s incoming nodes.
- Type:
int
- n_out
Number of layer’s outgoing nodes.
- Type:
int
- k
Order of the spline basis functions.
- Type:
int
- G
Number of grid intervals.
- Type:
int
- grid_range
An initial range for the grid’s ends, although adaptivity can completely change it.
- Type:
tuple
- grid_e
Parameter that defines if the grids are uniform (grid_e = 1.0) or sample-dependent (grid_e = 0.0). Intermediate values correspond to a linear mixing of the two cases.
- Type:
float
- item
The actual grid array, shape (n_in*n_out, G + 2k + 1).
- Type:
jnp.array
- __init__(n_in=2, n_out=5, k=3, G=3, grid_range=(-1, 1), grid_e=0.05)[source]
Initializes a BaseGrid instance.
- Parameters:
n_in (int) – Number of layer’s incoming nodes.
n_out (int) – Number of layer’s outgoing nodes.
k (int) – Order of the spline basis functions.
G (int) – Number of grid intervals.
grid_range (tuple) – An initial range for the grid’s ends, although adaptivity can completely change it.
grid_e (float) – Parameter that defines if the grids are uniform (grid_e = 1.0) or sample-dependent (grid_e = 0.0). Intermediate values correspond to a linear mixing of the two cases.
Example
>>> grid_type = BaseGrid(n_in=2, n_out=5, k=3, G=3, grid_range=(-1, 1), grid_e=0.05)
>>> grid = grid_type.item
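The stated grid shape, G + 2k + 1 points per row, is consistent with a uniform knot vector that covers grid_range with G intervals and is padded by k extra intervals on each side. The sketch below builds one such row in plain Python; the padding rule and the helper name uniform_knots are assumptions for illustration, not taken from the jaxkan source.

```python
def uniform_knots(G, k, grid_range):
    """One row of a uniform knot vector with G + 2k + 1 points."""
    lo, hi = grid_range
    h = (hi - lo) / G  # width of one grid interval
    # Index k lands on lo and index k + G lands on hi; the rest is padding.
    return [lo + (i - k) * h for i in range(G + 2 * k + 1)]

knots = uniform_knots(G=3, k=3, grid_range=(-1.0, 1.0))
```

With G = 3 and k = 3 the row has 10 points, matching the second dimension of item for these settings.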
- update(x, G_new)[source]
Update the grid based on input data and new grid size.
- Parameters:
x (jnp.ndarray) – Input data, shape (batch, n_in).
G_new (int) – New grid size in terms of intervals.
Example
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> grid = BaseGrid(n_in=2, n_out=5, k=3, G=3, grid_range=(-1, 1), grid_e=0.05)
>>> grid.update(x=x_batch, G_new=5)
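The role of grid_e during an update can be sketched as a linear mix between a uniform grid and a sample-dependent one. The version below works on a single input dimension in plain Python and picks the sample-dependent points as evenly spaced order statistics; that choice, the use of the sample minimum and maximum as the grid ends, and the helper name mixed_grid are assumptions for illustration only.

```python
def mixed_grid(samples, G, grid_e):
    """Mix a uniform grid (grid_e = 1.0) with a sample-dependent grid (grid_e = 0.0)."""
    xs = sorted(samples)
    n = len(xs)
    lo, hi = xs[0], xs[-1]
    # Sample-dependent points: G + 1 evenly spaced order statistics.
    adaptive = [xs[round(i * (n - 1) / G)] for i in range(G + 1)]
    # Uniform points over the same range.
    uniform = [lo + i * (hi - lo) / G for i in range(G + 1)]
    return [grid_e * u + (1 - grid_e) * a for u, a in zip(uniform, adaptive)]

samples = [0.0, 0.1, 0.2, 0.3, 1.0]
grid_uniform = mixed_grid(samples, G=2, grid_e=1.0)
grid_adaptive = mixed_grid(samples, G=2, grid_e=0.0)
grid_half = mixed_grid(samples, G=2, grid_e=0.5)
```

Here the samples cluster near 0, so the sample-dependent grid places its middle point at 0.2 while the uniform grid places it at 0.5; grid_e = 0.5 lands halfway between.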
- class jaxkan.grids.SplineGrid.SplineGrid(n_nodes=2, k=3, G=3, grid_range=(-1, 1), grid_e=0.05)[source]
Bases:
object
SplineGrid class, corresponding to the grid of the SplineLayer class. It comprises an initialization as well as an update procedure.
- n_nodes
Number of layer nodes.
- Type:
int
- k
Order of the spline basis functions.
- Type:
int
- G
Number of grid intervals.
- Type:
int
- grid_range
An initial range for the grid’s ends, although adaptivity can completely change it.
- Type:
tuple
- grid_e
Parameter that defines if the grids are uniform (grid_e = 1.0) or sample-dependent (grid_e = 0.0). Intermediate values correspond to a linear mixing of the two cases.
- Type:
float
- item
The actual grid array, shape (n_nodes, G + 2k + 1).
- Type:
jnp.array
- __init__(n_nodes=2, k=3, G=3, grid_range=(-1, 1), grid_e=0.05)[source]
Initializes a SplineGrid instance.
- Parameters:
n_nodes (int) – Number of layer nodes.
k (int) – Order of the spline basis functions.
G (int) – Number of grid intervals.
grid_range (tuple) – An initial range for the grid’s ends, although adaptivity can completely change it.
grid_e (float) – Parameter that defines if the grids are uniform (grid_e = 1.0) or sample-dependent (grid_e = 0.0). Intermediate values correspond to a linear mixing of the two cases.
Example
>>> grid_type = SplineGrid(n_nodes=2, k=3, G=3, grid_range=(-1, 1), grid_e=0.05)
>>> grid = grid_type.item
- update(x, G_new)[source]
Update the grid based on input data and new grid size.
- Parameters:
x (jnp.ndarray) – Input data, shape (batch, n_nodes).
G_new (int) – New grid size in terms of intervals.
Example
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> grid = SplineGrid(n_nodes=2, k=3, G=3, grid_range=(-1, 1), grid_e=0.05)
>>> grid.update(x=x_batch, G_new=5)
- class jaxkan.grids.RBFGrid.RBFGrid(n_nodes=2, D=3, grid_range=(-2.0, 2.0), grid_e=1.0)[source]
Bases:
object
RBFGrid class, corresponding to the grid of the RBFLayer class. It comprises an initialization as well as an update procedure.
- n_nodes
Number of layer nodes.
- Type:
int
- D
Number of radial basis functions.
- Type:
int
- grid_range
The range of the grid’s ends, on which the basis functions are defined.
- Type:
tuple
- grid_e
Parameter that defines if the grid is uniform (grid_e = 1.0) or sample-dependent (grid_e = 0.0). Intermediate values correspond to a linear mixing of the two cases.
- Type:
float
- item
The actual grid array, shape (n_nodes, D).
- Type:
jnp.array
- __init__(n_nodes=2, D=3, grid_range=(-2.0, 2.0), grid_e=1.0)[source]
Initializes an RBFGrid instance.
- Parameters:
n_nodes (int) – Number of layer nodes.
D (int) – Number of radial basis functions.
grid_range (tuple) – The range of the grid’s ends, on which the basis functions are defined.
grid_e (float) – Parameter that defines if the grid is uniform (grid_e = 1.0) or sample-dependent (grid_e = 0.0). Intermediate values correspond to a linear mixing of the two cases.
Example
>>> grid_type = RBFGrid(n_nodes=2, D=4, grid_range=(-2.0, 2.0), grid_e=1.0)
>>> grid = grid_type.item
- update(x, D_new)[source]
Update the grid based on input data and new grid size.
- Parameters:
x (jnp.ndarray) – Input data, shape (batch, n_nodes).
D_new (int) – New number of basis functions.
Example
>>> key = jax.random.key(42)
>>> grid_range = (-2.0, 2.0)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=grid_range[0], maxval=grid_range[-1])
>>>
>>> grid = RBFGrid(n_nodes=2, D=4, grid_range=grid_range, grid_e=1.0)
>>> grid.update(x=x_batch, D_new=8)
KAN Models
- class jaxkan.models.KAN.KAN(layer_dims, layer_type='base', required_parameters=None, seed=42)[source]
Bases:
Module
KAN class, corresponding to a network of KAN Layers.
- layers
List of KAN layer instances.
- Type:
nnx.List
- __init__(layer_dims, layer_type='base', required_parameters=None, seed=42)[source]
Initializes a KAN model.
- Parameters:
layer_dims (List[int]) – Defines the network in terms of nodes. E.g. [4,5,1] is a network with 2 layers: one with n_in=4 and n_out=5 and one with n_in=5 and n_out = 1.
layer_type (str) – Type of layer to use (e.g., ‘base’).
required_parameters (dict) – Dictionary containing parameters required for the chosen layer type.
seed (int) – Random key selection for initializations wherever necessary.
Example
>>> req_params = {'k': 3, 'G': 5}
>>> model = KAN(layer_dims=[2, 5, 1], layer_type='base', required_parameters=req_params, seed=42)
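The way layer_dims maps to layers can be spelled out by pairing consecutive entries, as in the [4,5,1] example from the parameter description. This is only a sketch of the pairing rule; layer_shapes is a hypothetical name, not an attribute of KAN.

```python
layer_dims = [4, 5, 1]

# Each consecutive pair (n_in, n_out) describes one layer.
layer_shapes = list(zip(layer_dims[:-1], layer_dims[1:]))
num_layers = len(layer_shapes)
```

A list of length n therefore describes n - 1 layers.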
- update_grids(x, G_new)[source]
Performs the grid update for each layer of the KAN architecture.
- Parameters:
x (jnp.array) – Inputs for the first layer, shape (batch, layer_dims[0]).
G_new (int) – Size of the new grid (in terms of intervals).
Example
>>> req_params = {'k': 3, 'G': 5}
>>> model = KAN(layer_dims=[2, 5, 1], layer_type='base', required_parameters=req_params, seed=42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> model.update_grids(x=x_batch, G_new=10)
- __call__(x)[source]
Equivalent to the network’s forward pass.
- Parameters:
x (jnp.array) – Inputs for the first layer, shape (batch, layer_dims[0]).
- Returns:
Network output, shape (batch, layer_dims[-1]).
- Return type:
x (jnp.array)
Example
>>> req_params = {'k': 3, 'G': 5}
>>> model = KAN(layer_dims=[2, 5, 1], layer_type='base', required_parameters=req_params, seed=42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> output = model(x_batch)
- class jaxkan.models.RGAKAN.RGABlock(n_in, n_out, n_hidden, D=5, flavor='exact', init_scheme=None, alpha=0.0, beta=1.0, seed=42)[source]
Bases:
Module
Residual-Gated Adaptive Block for RGAKAN architecture.
- InputLayer
First Chebyshev layer in the block.
- Type:
Layer
- OutputLayer
Second Chebyshev layer in the block.
- Type:
Layer
- alpha
Trainable residual connection weight for the output.
- Type:
nnx.Param
- beta
Trainable residual connection weight for the hidden state.
- Type:
nnx.Param
- __init__(n_in, n_out, n_hidden, D=5, flavor='exact', init_scheme=None, alpha=0.0, beta=1.0, seed=42)[source]
Initializes an RGABlock.
- Parameters:
n_in (int) – Input dimension.
n_out (int) – Output dimension.
n_hidden (int) – Hidden layer dimension.
D (int) – Degree of Chebyshev polynomials.
flavor (str) – Type of Chebyshev layer (‘exact’ or other variants).
init_scheme (dict, optional) – Initialization scheme for layer weights.
alpha (float) – Initial value for output residual connection weight.
beta (float) – Initial value for hidden residual connection weight.
seed (int) – Random seed for reproducible initialization.
Example
>>> block = RGABlock(n_in=64, n_out=64, n_hidden=64, D=5, flavor='exact', seed=42)
- __call__(x, u, v)[source]
Forward pass through the RGA block.
- Parameters:
x (jnp.array) – Input array, shape (batch, n_in).
u (jnp.array) – First gating signal, shape (batch, n_hidden).
v (jnp.array) – Second gating signal, shape (batch, n_hidden).
- Returns:
Output after applying gated attention and residual connections, shape (batch, n_out).
- Return type:
x (jnp.array)
Example
>>> block = RGABlock(n_in=64, n_out=64, n_hidden=64, D=5, seed=42)
>>> x = jnp.ones((32, 64))
>>> u = jnp.ones((32, 64))
>>> v = jnp.zeros((32, 64))
>>> output = block(x, u, v)  # Shape: (32, 64)
- class jaxkan.models.RGAKAN.RGAKAN(n_in, n_out, n_hidden, num_blocks, flavor='exact', D=5, init_scheme=None, alpha=0.0, beta=1.0, ref=None, period_axes=None, rff_std=None, sine_D=None, seed=42)[source]
Bases:
Module
Residual-Gated Adaptive Kolmogorov-Arnold Network (RGAKAN). See paper “Training Deep Physics-Informed Kolmogorov-Arnold Networks”: https://www.sciencedirect.com/science/article/pii/S0045782526000356
- pi_init
Whether physics-informed initialization is enabled.
- Type:
bool
- n_hidden
Hidden layer dimension.
- Type:
int
- D
Degree of Chebyshev polynomials.
- Type:
int
- PE
Periodic embedder if period_axes is provided.
- Type:
Union[PeriodEmbedder, None]
- FE
Random Fourier Features embedder if rff_std is provided.
- Type:
Union[RFFEmbedder, None]
- U
First gating network.
- Type:
Layer
- V
Second gating network.
- Type:
Layer
- blocks
List of RGABlock instances.
- Type:
nnx.List
- OutBasis
Physics-informed output coefficients if pi_init is True.
- Type:
Union[nnx.Param, None]
- OutLayer
Standard output layer if pi_init is False.
- Type:
Union[Layer, None]
- __init__(n_in, n_out, n_hidden, num_blocks, flavor='exact', D=5, init_scheme=None, alpha=0.0, beta=1.0, ref=None, period_axes=None, rff_std=None, sine_D=None, seed=42)[source]
Initializes an RGAKAN model.
- Parameters:
n_in (int) – Input dimension (before any embeddings).
n_out (int) – Output dimension.
n_hidden (int) – Hidden layer dimension.
num_blocks (int) – Number of RGA blocks to stack.
flavor (str) – Type of Chebyshev layer (‘exact’ or other variants).
D (int) – Degree of Chebyshev polynomials.
init_scheme (dict, optional) – Initialization scheme for layer weights.
alpha (float) – Initial value for output residual connection weights in blocks.
beta (float) – Initial value for hidden residual connection weights in blocks.
ref (dict, optional) – Reference data for physics-informed initialization. Must contain ‘t’, ‘x’, and ‘usol’.
period_axes (dict, optional) – Dictionary for periodic embedding: {axis: (period, trainable)}.
rff_std (float, optional) – Standard deviation for Random Fourier Features embedding.
sine_D (int, optional) – Degree for sine basis layer.
seed (int) – Random seed for reproducible initialization.
Example
>>> # Standard RGAKAN
>>> model = RGAKAN(n_in=2, n_out=1, n_hidden=64, num_blocks=4, D=5, seed=42)
>>>
>>> # RGAKAN with periodic embedding
>>> period_axes = {0: (2.0 * jnp.pi, False)}
>>> model = RGAKAN(n_in=2, n_out=1, n_hidden=64, num_blocks=4,
...                period_axes=period_axes, seed=42)
- __call__(x)[source]
Forward pass through the RGAKAN model.
- Parameters:
x (jnp.array) – Input array, shape (batch, n_in).
- Returns:
Model output, shape (batch, n_out).
- Return type:
y (jnp.array)
Example
>>> model = RGAKAN(n_in=2, n_out=1, n_hidden=64, num_blocks=4, seed=42)
>>> x = jnp.ones((32, 2))
>>> y = model(x)  # Shape: (32, 1)
- class jaxkan.models.ActNet.ActLayer(n_in=3, n_out=4, N=5, train_basis=True, seed=42)[source]
Bases:
Module
ActLayer implementation based on: “Deep Learning Alternatives of the Kolmogorov Superposition Theorem” by Leonardo Ferreira Guilhoto and Paris Perdikaris (arXiv:2410.01990)
Forward pass: ActLayer(x) = S(Λ ⊙ (β @ B(x)))
Where:
- B(x) is the basis expansion matrix with B(x)_{ij} = b_i(x_j), shape (N, d)
- β ∈ R^{m × N} are the basis coefficients
- Λ ∈ R^{m × d} are the mixing weights
- S is the row-sum function
- ⊙ is the Hadamard (element-wise) product
The k-th output is: (ActLayer(x))_k = Σ_i λ_{ki} Σ_j β_{kj} b_j(x_i)
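The forward-pass formula can be traced for a single input vector with nested lists. In this sketch, plain sin and cos stand in for the paper's normalized basis functions, and the function name act_layer and its argument layout are hypothetical; it only illustrates S(Λ ⊙ (β @ B(x))), not the library's batched implementation.

```python
import math

def act_layer(x, beta, lam, basis_fns):
    """x: d inputs; beta: m x N coefficients; lam: m x d weights; basis_fns: N callables."""
    m, N, d = len(beta), len(basis_fns), len(x)
    # B[j][i] = b_j(x_i), shape (N, d)
    B = [[b(xi) for xi in x] for b in basis_fns]
    # (beta @ B)[k][i] = sum_j beta[k][j] * B[j][i], shape (m, d)
    beta_B = [[sum(beta[k][j] * B[j][i] for j in range(N)) for i in range(d)]
              for k in range(m)]
    # Hadamard product with Lambda, then the row-sum S, shape (m,)
    return [sum(lam[k][i] * beta_B[k][i] for i in range(d)) for k in range(m)]

x = [0.0, math.pi / 2]         # d = 2
beta = [[1.0, 2.0]]            # m = 1, N = 2
lam = [[1.0, 3.0]]             # m = 1, d = 2
out = act_layer(x, beta, lam, [math.sin, math.cos])
```

For this input, β @ B(x) is approximately [2, 1], so the single output is 1·2 + 3·1 = 5 up to floating-point error.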
- rngs
Random number generator state.
- Type:
nnx.Rngs
- beta
Basis coefficients, shape (n_out, N).
- Type:
nnx.Param
- Lambda
Mixing weights, shape (n_out, n_in).
- Type:
nnx.Param
- omega
Frequency parameters for basis functions (trainable or fixed).
- Type:
Union[nnx.Param, jnp.array]
- phase
Phase parameters for basis functions (trainable or fixed).
- Type:
Union[nnx.Param, jnp.array]
- __init__(n_in=3, n_out=4, N=5, train_basis=True, seed=42)[source]
Initializes an ActLayer instance.
- Parameters:
n_in (int) – Number of layer’s incoming nodes.
n_out (int) – Number of layer’s outgoing nodes.
N (int) – Number of basis functions (paper recommends N=4).
train_basis (bool) – Whether the basis function parameters (omega and phase) are trainable.
seed (int) – Random key selection for initializations wherever necessary.
Example
>>> layer = ActLayer(n_in=2, n_out=5, N=4, train_basis=True, seed=42)
- basis(x)[source]
Compute normalized sinusoidal basis functions.
Paper Equation 11: b_i(t) = (sin(ω_i * t + p_i) - μ(ω_i, p_i)) / σ(ω_i, p_i)
Where μ and σ are computed assuming x ~ N(0, 1):
μ(ω, p) = exp(-ω²/2) * sin(p)
σ(ω, p) = sqrt(1/2 - 1/2 * exp(-2ω²) * cos(2p) - μ²)
- Parameters:
x (jnp.array) – Inputs, shape (batch, n_in).
- Returns:
Basis expansion matrix, shape (batch, N, n_in), where B[b, j, i] = b_j(x_{b,i}).
- Return type:
B (jnp.array)
Example
>>> layer = ActLayer(n_in=2, n_out=5, N=4, seed=42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> B = layer.basis(x_batch)
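The normalization in Equation 11 can be checked numerically: under x ~ N(0, 1), each b_i should have mean 0 and unit second moment. The scalar sketch below evaluates μ and σ from the formulas above and verifies both properties with a simple quadrature rule; the function names and the quadrature helper gauss_expect are illustrative and not part of jaxkan.

```python
import math

def mu(omega, p):
    return math.exp(-omega**2 / 2) * math.sin(p)

def sigma(omega, p):
    var = 0.5 - 0.5 * math.exp(-2 * omega**2) * math.cos(2 * p) - mu(omega, p) ** 2
    return math.sqrt(var)

def basis(t, omega, p):
    """Normalized sinusoidal basis function b(t)."""
    return (math.sin(omega * t + p) - mu(omega, p)) / sigma(omega, p)

def gauss_expect(f, n=4001, lim=8.0):
    """Approximate E[f(x)] for x ~ N(0, 1) with an evenly spaced quadrature sum."""
    h = 2 * lim / (n - 1)
    total = 0.0
    for i in range(n):
        t = -lim + i * h
        total += f(t) * math.exp(-t * t / 2) / math.sqrt(2 * math.pi) * h
    return total

omega, p = 1.5, 0.3
mean = gauss_expect(lambda t: basis(t, omega, p))
second_moment = gauss_expect(lambda t: basis(t, omega, p) ** 2)
```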
- __call__(x)[source]
The layer’s forward pass.
Paper Equation 6: ActLayer(x) = S(Λ ⊙ (β @ B(x)))
Paper Equation 9: (ActLayer(x))_k = Σ_i λ_{ki} Σ_j β_{kj} b_j(x_i)
- Parameters:
x (jnp.array) – Inputs, shape (batch, n_in).
- Returns:
Output of the forward pass, shape (batch, n_out).
- Return type:
y (jnp.array)
Example
>>> layer = ActLayer(n_in=2, n_out=5, N=4, seed=42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> output = layer(x_batch)
- class jaxkan.models.ActNet.ActNet(layer_dims, N=4, add_bias=True, omega0=1.0, use_projections=False, train_basis=True, seed=42)[source]
Bases:
Module
ActNet architecture based on: “Deep Learning Alternatives of the Kolmogorov Superposition Theorem” by Leonardo Ferreira Guilhoto and Paris Perdikaris (arXiv:2410.01990)
- rngs
Random number generator state.
- Type:
nnx.Rngs
- add_bias
Whether to add learnable bias after each ActLayer.
- Type:
bool
- omega0
Frequency multiplier for input.
- Type:
float
- use_projections
Whether to use input/output linear projections.
- Type:
bool
- input_proj
Input projection layer if use_projections is True.
- Type:
nnx.Linear, optional
- layers
List of ActLayer instances.
- Type:
nnx.List
- output_proj
Output projection layer if use_projections is True.
- Type:
nnx.Linear, optional
- biases
List of bias parameters if add_bias is True.
- Type:
nnx.List, optional
- __init__(layer_dims, N=4, add_bias=True, omega0=1.0, use_projections=False, train_basis=True, seed=42)[source]
Initializes an ActNet model.
- Parameters:
layer_dims (List[int]) – Defines the network in terms of nodes. E.g. [2,5,1] is a network with 2 layers.
N (int) – Number of basis functions per ActLayer (paper recommends N=4).
add_bias (bool) – Whether to add learnable bias after each ActLayer.
omega0 (float) – Frequency multiplier for input (paper’s Appendix D.1).
use_projections (bool) – Whether to use input/output linear projections.
train_basis (bool) – Whether the basis function parameters (omega and phase) are trainable.
seed (int) – Random key selection for initializations wherever necessary.
Example
>>> model = ActNet(layer_dims=[2, 5, 1], N=4, add_bias=True, train_basis=True, seed=42)
- __call__(x)[source]
Equivalent to the network’s forward pass.
- Parameters:
x (jnp.array) – Inputs for the first layer, shape (batch, layer_dims[0]).
- Returns:
Network output, shape (batch, layer_dims[-1]).
- Return type:
x (jnp.array)
Example
>>> model = ActNet(layer_dims=[2, 5, 1], N=4, add_bias=True, seed=42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> output = model(x_batch)
- class jaxkan.models.KKAN.ChebyshevEmbedding(D_e)[source]
Bases:
Module
Chebyshev polynomial embedding layer with trainable coefficients.
For an input x, computes: [c_0 * T_0(x), c_1 * T_1(x), …, c_{D_e} * T_{D_e}(x)] where T_n are Chebyshev polynomials of the first kind and c_n are trainable parameters.
- D_e
Degree of Chebyshev polynomial expansion.
- Type:
int
- use_exact
Whether to use exact Chebyshev polynomials from Cb dictionary.
- Type:
bool
- C
Trainable coefficients for the Chebyshev polynomials.
- Type:
nnx.Param
- __init__(D_e)[source]
Initializes a ChebyshevEmbedding layer.
- Parameters:
D_e (int) – Degree of Chebyshev polynomial expansion.
Example
>>> embedding = ChebyshevEmbedding(D_e=5)
- __call__(x)[source]
Applies Chebyshev embedding to input.
- Parameters:
x (jnp.array) – Input tensor, shape (batch, n_features) or (batch,).
- Returns:
Chebyshev embedded tensor, shape (batch, n_features * (D_e + 1)).
- Return type:
embedded (jnp.array)
Example
>>> embedding = ChebyshevEmbedding(D_e=5)
>>> x = jax.random.uniform(jax.random.key(0), (100, 1))
>>> y = embedding(x)  # shape: (100, 6)
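For a single scalar input, the embedding [c_0 * T_0(x), …, c_{D_e} * T_{D_e}(x)] can be sketched with the standard three-term recurrence T_{n+1}(x) = 2x T_n(x) - T_{n-1}(x). Using the recurrence is a choice made here for brevity; the use_exact attribute suggests the library can also evaluate closed-form polynomials, and the function name chebyshev_embedding is hypothetical.

```python
import math

def chebyshev_embedding(x, coeffs):
    """Return [c_0*T_0(x), ..., c_D*T_D(x)] for a scalar x and D + 1 coefficients."""
    D = len(coeffs) - 1
    T = [1.0, x]  # T_0 and T_1
    for n in range(1, D):
        T.append(2 * x * T[n] - T[n - 1])  # T_{n+1} = 2x T_n - T_{n-1}
    return [c * t for c, t in zip(coeffs, T)]

D_e = 5
embedded = chebyshev_embedding(0.5, [1.0] * (D_e + 1))  # D_e + 1 = 6 features
```

With all coefficients set to 1, the entries equal T_n(0.5) = cos(n · arccos(0.5)), which gives a quick consistency check.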
- class jaxkan.models.KKAN.InnerBlock(D_e=7, H=32, L=4, m=32, activation='tanh', seed=42)[source]
Bases:
Module
Inner Block for KKAN architecture.
The Inner Block processes a single input dimension x_p through:
1. Chebyshev embedding: x_p -> [c_0*T_0(x_p), …, c_{D_e}*T_{D_e}(x_p)] -> (D_e+1)-dim
2. Input Dense layer: (D_e+1)-dim -> H-dim
3. L hidden Dense layers: H-dim -> H-dim (each followed by activation)
4. Output Chebyshev embedding: H-dim -> H*(D_e+1)-dim (flattened)
5. Final Dense layer: H*(D_e+1)-dim -> m-dim
- activation
Activation function.
- Type:
callable
- input_embedding
Chebyshev embedding layer for input.
- Type:
- input_layer
Dense layer after input embedding.
- Type:
List of hidden Dense layers.
- Type:
nnx.List
- output_embedding
Chebyshev embedding layer for output.
- Type:
- output_layer
Final Dense layer.
- Type:
Dense
- __init__(D_e=7, H=32, L=4, m=32, activation='tanh', seed=42)[source]
Initializes an InnerBlock.
- Parameters:
D_e (int) – Degree of Chebyshev polynomial expansion.
H (int) – Hidden dimension for MLP layers.
L (int) – Number of hidden layers.
m (int) – Output dimension.
activation (str) – Activation function.
seed (int) – Random seed.
Example
>>> inner_block = InnerBlock(D_e=5, H=32, L=2, m=10, activation='tanh', seed=42)
- __call__(x_p)[source]
Forward pass through the inner block for a single input component.
- Parameters:
x_p (jnp.array) – Single input component, shape (batch, 1).
- Returns:
Output tensor, shape (batch, m).
- Return type:
out (jnp.array)
Example
>>> inner_block = InnerBlock(D_e=5, H=32, L=2, m=10, seed=42)
>>> x_p = jax.random.uniform(jax.random.key(0), (100, 1))
>>> y = inner_block(x_p)  # shape: (100, 10)
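The five processing steps listed in the InnerBlock description imply a fixed sequence of feature widths. The helper below, a hypothetical function that is not part of jaxkan, simply tabulates those widths for a given configuration so the shapes can be checked before building a model.

```python
def inner_block_widths(D_e, H, L, m):
    """Feature width after each stage of the inner block, starting from the scalar input."""
    widths = [1, D_e + 1, H]        # input x_p, Chebyshev embedding, input Dense layer
    widths += [H] * L               # L hidden Dense layers
    widths += [H * (D_e + 1), m]    # flattened output embedding, final Dense layer
    return widths

widths = inner_block_widths(D_e=5, H=32, L=2, m=10)
```

For the configuration used in the example above, the final Dense layer maps 192 features to m = 10.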
- class jaxkan.models.KKAN.OuterBlock(m, n_out, layer_type='sine', layer_params={'D': 7, 'init_scheme': {'type': 'glorot_fine'}}, seed=42)[source]
Bases:
Module
Outer Block for KKAN architecture.
This is a wrapper around existing KAN layers from jaxkan.layers. It applies the selected KAN layer to map from m dimensions to n_out dimensions.
- layer
The underlying KAN layer instance.
- Type:
nnx.Module
- __init__(m, n_out, layer_type='sine', layer_params={'D': 7, 'init_scheme': {'type': 'glorot_fine'}}, seed=42)[source]
Initializes an OuterBlock.
- Parameters:
m (int) – Input dimension.
n_out (int) – Output dimension.
layer_type (str) – Type of KAN layer (‘chebyshev’, ‘legendre’, ‘rbf’, ‘sine’, ‘fourier’, etc.).
layer_params (dict, optional) – Additional parameters for the KAN layer (e.g., D, flavor, kernel).
seed (int) – Random seed.
Example
>>> outer_block = OuterBlock(m=10, n_out=1, layer_type='chebyshev',
...                          layer_params={'D': 5}, seed=42)
- __call__(xi)[source]
Forward pass through the outer block.
- Parameters:
xi (jnp.array) – Input from combination layer, shape (batch, m).
- Returns:
Output tensor, shape (batch, n_out).
- Return type:
y (jnp.array)
Example
>>> outer_block = OuterBlock(m=10, n_out=1, layer_type='chebyshev', seed=42)
>>> xi = jax.random.uniform(jax.random.key(0), (100, 10))
>>> y = outer_block(xi)  # shape: (100, 1)
- class jaxkan.models.KKAN.KKAN(n_in, n_out, m=32, D_e=7, H=32, L=4, activation='tanh', outer_layer_type='sine', outer_layer_params={'D': 7, 'init_scheme': {'type': 'glorot_fine'}}, seed=42)[source]
Bases:
Module
KKAN architecture based on: “KKANs: Kůrková-Kolmogorov-Arnold Networks and Their Learning Dynamics” by Juan Diego Toscano, Li-Lian Wang, and George Em Karniadakis
- n_in
Input dimension (d).
- Type:
int
- inner_blocks
List of InnerBlock modules, one for each input dimension.
- Type:
nnx.List
- outer_block
The outer block that produces the final output.
- Type:
- __init__(n_in, n_out, m=32, D_e=7, H=32, L=4, activation='tanh', outer_layer_type='sine', outer_layer_params={'D': 7, 'init_scheme': {'type': 'glorot_fine'}}, seed=42)[source]
Initializes a KKAN model.
- Parameters:
n_in (int) – Input dimension.
n_out (int) – Output dimension.
m (int) – Intermediate dimension.
D_e (int) – Degree of Chebyshev expansion in inner blocks.
H (int) – Hidden dimension for inner block MLP.
L (int) – Number of hidden layers in inner block MLP.
activation (str) – Activation function (‘tanh’, ‘relu’, ‘silu’, ‘gelu’).
outer_layer_type (str) – Type of KAN layer for outer block (‘chebyshev’, ‘legendre’, ‘rbf’, ‘sine’, etc.).
outer_layer_params (dict, optional) – Additional parameters for outer block layer (e.g., {‘D’: 5, ‘flavor’: ‘exact’}).
seed (int) – Random seed.
Example
>>> model = KKAN(n_in=2, n_out=1, m=10, D_e=5, H=32, L=2,
...              outer_layer_type='chebyshev', seed=42)
- __call__(x)[source]
Forward pass through the KKAN model.
- Parameters:
x (jnp.array) – Input tensor, shape (batch, n_in).
- Returns:
Output tensor, shape (batch, n_out).
- Return type:
y (jnp.array)
Example
>>> model = KKAN(n_in=2, n_out=1, m=10, D_e=5, H=32, L=2, seed=42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> output = model(x_batch)
- jaxkan.models.utils.get_activation(activation='tanh')[source]
Returns the corresponding activation function based on user input.
- Parameters:
activation (str) – Name of the activation function. Options include:
- ‘celu’: Continuously Differentiable ELU
- ‘elu’: Exponential Linear Unit
- ‘gelu’: Gaussian Error Linear Unit
- ‘hard_sigmoid’: Hard sigmoid
- ‘hard_silu’ / ‘hard_swish’: Hard SiLU
- ‘hard_tanh’: Hard hyperbolic tangent
- ‘identity’: Identity function (no activation)
- ‘leaky_relu’: Leaky ReLU
- ‘log_sigmoid’: Log-sigmoid function
- ‘relu’: Rectified Linear Unit
- ‘selu’: Scaled ELU
- ‘sigmoid’: Sigmoid function
- ‘silu’ / ‘swish’: Sigmoid Linear Unit
- ‘soft_sign’: Soft sign function
- ‘softplus’: Softplus function
- ‘tanh’: Hyperbolic tangent (default)
- Returns:
The activation function.
- Return type:
callable
Example
>>> act_fn = get_activation('tanh')
>>> x = jnp.linspace(-1.0, 1.0, 5)
>>> y = act_fn(x)
- class jaxkan.models.utils.PeriodEmbedder(period_axes)[source]
Bases:
Module
Periodic embedding module that applies trigonometric transformations to specified input axes.
- axes
Dictionary storing period values for each axis. Values can be trainable (nnx.Param) or fixed.
- Type:
nnx.Dict
- __init__(period_axes)[source]
Initializes a PeriodEmbedder module.
- Parameters:
period_axes (dict) – Dictionary mapping input axis indices to (period, trainable) tuples. The key is the axis index (int), and the value is a tuple where:
- period (float): The period value for the trigonometric transformation.
- trainable (bool): If True, period is stored as nnx.Param and can be optimized during training.
Example
>>> # Fixed period on axis 0, trainable period on axis 1
>>> period_axes = {0: (2.0 * jnp.pi, False), 1: (jnp.pi, True)}
>>> embedder = PeriodEmbedder(period_axes)
- __call__(x)[source]
Applies periodic embedding to the input.
- Parameters:
x (jnp.array) – Input array, shape (batch, n_in).
- Returns:
Embedded output. For each axis with a period, the original feature is replaced by cos(period * x) and sin(period * x). Non-periodic axes are passed through unchanged. Shape (batch, n_out) where n_out depends on the number of periodic axes.
- Return type:
y (jnp.array)
Example
>>> period_axes = {1: (jnp.pi, False)}
>>> embedder = PeriodEmbedder(period_axes)
>>> x = jnp.array([[1.0, 0.5], [2.0, 1.0]])
>>> y = embedder(x)  # Shape: (2, 3) - axis 0 unchanged, axis 1 → [cos, sin]
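The embedding rule described in the return value, where each periodic feature is replaced by cos(period * x) and sin(period * x) and other features pass through, can be sketched per row in plain Python. For brevity this sketch maps each axis directly to its period and drops the trainable flag; the function name period_embed is hypothetical.

```python
import math

def period_embed(row, periods):
    """Replace each periodic feature by [cos(period * x), sin(period * x)]."""
    out = []
    for axis, value in enumerate(row):
        if axis in periods:
            period = periods[axis]
            out += [math.cos(period * value), math.sin(period * value)]
        else:
            out.append(value)  # non-periodic axes pass through unchanged
    return out

x = [[1.0, 0.5], [2.0, 1.0]]
y = [period_embed(row, {1: math.pi}) for row in x]  # shape (2, 3)
```

Each periodic axis adds one extra column, so n_out equals n_in plus the number of periodic axes.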
- class jaxkan.models.utils.RFFEmbedder(std=1.0, n_in=1, embed_dim=256, seed=42)[source]
Bases:
Module
Random Fourier Features (RFF) embedding module for nonlinear feature transformation.
- B
Random projection matrix, shape (n_in, embed_dim//2).
- Type:
nnx.Param
- __init__(std=1.0, n_in=1, embed_dim=256, seed=42)[source]
Initializes an RFFEmbedder module.
- Parameters:
std (float) – Standard deviation for the normal distribution used to initialize the random projection matrix.
n_in (int) – Input dimension.
embed_dim (int) – Output embedding dimension. Must be even: the random matrix has embed_dim//2 columns, and the cos and sin halves together give embed_dim features.
seed (int) – Random seed for reproducible initialization.
Example
>>> embedder = RFFEmbedder(std=1.0, n_in=2, embed_dim=256, seed=42)
- __call__(x)[source]
Applies Random Fourier Features transformation to the input.
- Parameters:
x (jnp.array) – Input array, shape (batch, n_in).
- Returns:
Embedded output using random Fourier features: [cos(xB), sin(xB)]. Shape (batch, embed_dim).
- Return type:
y (jnp.array)
Example
>>> embedder = RFFEmbedder(std=1.0, n_in=2, embed_dim=256, seed=42)
>>> x = jnp.array([[1.0, 0.5], [2.0, 1.0]])
>>> y = embedder(x)  # Shape: (2, 256)
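The mapping [cos(xB), sin(xB)] with B of shape (n_in, embed_dim//2) can be sketched with the standard library alone. The helper names make_rff_matrix and rff_embed are hypothetical, and Python's random module stands in for JAX's random number generation, so the values will not match the library's output.

```python
import math
import random

def make_rff_matrix(n_in, embed_dim, std, seed):
    """Random projection matrix B, shape (n_in, embed_dim // 2), entries ~ N(0, std^2)."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, std) for _ in range(embed_dim // 2)] for _ in range(n_in)]

def rff_embed(row, B):
    """Embed one input row as [cos(row @ B), sin(row @ B)]."""
    proj = [sum(x * B[i][j] for i, x in enumerate(row)) for j in range(len(B[0]))]
    return [math.cos(p) for p in proj] + [math.sin(p) for p in proj]

B = make_rff_matrix(n_in=2, embed_dim=8, std=1.0, seed=42)
features = rff_embed([1.0, 0.5], B)  # length embed_dim = 8
```

Larger std values spread the projections over more periods of the sinusoids, which corresponds to higher-frequency features.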
- jaxkan.models.utils.count_params(model)[source]
Count the total number of trainable parameters in a model.
- Parameters:
model (nnx.Module) – Flax model instance.
- Returns:
Total number of trainable parameters in the model.
- Return type:
total_params (int)
Example
>>> model = KAN([2, 8, 8, 1], 'spline', {'k': 4, 'G': 3}, 42)
>>> num_params = count_params(model)
>>> print(f"Model has {num_params} parameters")
- jaxkan.models.utils.get_frob(model, x)[source]
Compute the squared Frobenius norm of the model’s gradient at a given input point.
- Parameters:
model (nnx.Module) – Flax model instance.
x (jnp.array) – Input point, shape (d,) or (1, d).
- Returns:
Squared Frobenius norm of the gradient ||∇f(x)||²_F.
- Return type:
fro_sq (float)
Example
>>> model = KAN([2, 8, 1], 'spline', {}, 42)
>>> x = jnp.array([0.5, 0.3])
>>> frob_norm_sq = get_frob(model, x)
- jaxkan.models.utils.batched_frob(*args, **kwargs)
- jaxkan.models.utils.get_complexity(model, pde_collocs, bc_collocs=None)[source]
Compute model complexity as the average squared Frobenius norm of gradients over collocation points.
- Parameters:
model (nnx.Module) – Flax model instance.
pde_collocs (jnp.array) – Collocation points for PDE/equation domain, shape (N, d).
bc_collocs (jnp.array, optional) – Initial/boundary condition collocation points, shape (M, d). If None, only pde_collocs is used.
- Returns:
Average squared Frobenius norm of gradients: mean(||∇f(x)||²_F).
- Return type:
complexity (float)
Example
>>> model = KAN([2, 8, 1], 'spline', {}, 42)
>>> pde_collocs = jnp.array([[0.5, 0.3], [0.2, 0.7]])
>>> bc_collocs = jnp.array([[0.0, 0.5]])
>>> complexity = get_complexity(model, pde_collocs, bc_collocs)
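The complexity measure, mean(||∇f(x)||²_F) over collocation points, can be illustrated on a simple scalar function. This sketch assumes the gradient is taken with respect to the input point and approximates it with central finite differences rather than automatic differentiation; both choices, along with the helper names, are assumptions for illustration.

```python
def grad_sq_norm(f, x, eps=1e-5):
    """Squared norm of the input gradient of f at x, via central differences."""
    total = 0.0
    for i in range(len(x)):
        xp, xm = list(x), list(x)
        xp[i] += eps
        xm[i] -= eps
        total += ((f(xp) - f(xm)) / (2 * eps)) ** 2
    return total

def complexity(f, pde_collocs, bc_collocs=None):
    """Average squared gradient norm over all collocation points."""
    points = list(pde_collocs) + list(bc_collocs or [])
    return sum(grad_sq_norm(f, p) for p in points) / len(points)

f = lambda x: x[0] ** 2 + 3.0 * x[1]  # gradient is (2 * x0, 3)
value = complexity(f, [[0.5, 0.3], [0.2, 0.7]], [[0.0, 0.5]])
```

For this f the squared gradient norm is 4·x0² + 9, so the three points contribute 10, 9.16 and 9, and the average is about 9.3867.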
- jaxkan.models.utils.get_adam(learning_rate=0.001, schedule_type=None, decay_steps=5000, decay_rate=0.9, warmup_steps=0, staircase=False, b1=0.9, b2=0.999, eps=1e-08, **schedule_kwargs)[source]
Create an Adam optimizer with optional learning rate scheduling and warmup.
- Parameters:
learning_rate (float) – Base learning rate. Default is 1e-3.
schedule_type (str, optional) – Type of learning rate schedule. Options:
- None: Constant learning rate (default)
- ‘exponential’: Exponential decay schedule
- ‘cosine’: Cosine annealing schedule
- ‘polynomial’: Polynomial decay schedule
- ‘piecewise_constant’: Piecewise constant schedule (requires ‘boundaries’ and ‘values’ in schedule_kwargs)
decay_steps (int) – Number of steps for the learning rate decay schedule. Default is 5000. Used for exponential, cosine, and polynomial schedules.
decay_rate (float) – Decay rate for exponential schedule. Default is 0.9. For polynomial schedule, this is the ‘power’ parameter.
warmup_steps (int) – Number of warmup steps with linear learning rate increase from 0 to learning_rate. Default is 0 (no warmup).
staircase (bool) – If True, decay the learning rate at discrete intervals (staircase function). Default is False (smooth decay).
b1 (float) – Exponential decay rate for first moment. Default is 0.9.
b2 (float) – Exponential decay rate for second moment. Default is 0.999.
eps (float) – Small constant for numerical stability. Default is 1e-8.
**schedule_kwargs –
Additional keyword arguments for specific schedules. For piecewise_constant schedule:
boundaries (list): List of step boundaries
values (list): List of learning rate values (must be len(boundaries) + 1)
- Returns:
Configured Adam optimizer with learning rate schedule.
- Return type:
optax.GradientTransformation
Example
>>> # Adam with exponential decay and warmup
>>> optimizer = get_adam(
...     learning_rate=1e-3,
...     schedule_type='exponential',
...     decay_steps=5000,
...     decay_rate=0.9,
...     warmup_steps=1000,
...     b1=0.9,
...     b2=0.999
... )
>>> # Adam with cosine annealing
>>> optimizer = get_adam(
...     learning_rate=1e-3,
...     schedule_type='cosine',
...     decay_steps=10000,
...     warmup_steps=500
... )
>>> # Adam with constant learning rate
>>> optimizer = get_adam(learning_rate=1e-3)
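To make the interaction of warmup_steps, decay_steps, decay_rate and staircase concrete, the following is a minimal sketch of a warmup-plus-exponential-decay learning rate written in plain Python. The helper name warmup_exponential_lr is hypothetical, and the sketch assumes that the decay is counted from the end of the warmup phase; the schedule that get_adam actually builds may differ in such details.

```python
import math

def warmup_exponential_lr(step, learning_rate=1e-3, decay_steps=5000,
                          decay_rate=0.9, warmup_steps=0, staircase=False):
    """Learning rate at `step`: linear warmup followed by exponential decay."""
    if step < warmup_steps:
        # Linear ramp from 0 up to the base learning rate.
        return learning_rate * step / warmup_steps
    # Assumption: the decay clock starts once the warmup has finished.
    exponent = (step - warmup_steps) / decay_steps
    if staircase:
        # Decay in discrete jumps, once every `decay_steps` steps.
        exponent = math.floor(exponent)
    return learning_rate * decay_rate ** exponent

for step in (0, 500, 1000, 6000, 11000):
    print(step, warmup_exponential_lr(step, warmup_steps=1000))
```

With staircase=True the rate stays flat inside each block of decay_steps steps and only drops at block boundaries.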
- jaxkan.models.utils.get_lbfgs(learning_rate=None, memory_size=10, scale_init_precond=True, linesearch=None)[source]
Create an L-BFGS optimizer.
Note: L-BFGS requires special handling when used with Flax NNX. You must pass value, value_fn, and model to the optimizer’s update method. The value_fn should be a function that takes the model and returns the loss value.
- Parameters:
learning_rate (float, optional) – Initial learning rate. If None, the optimizer uses its own line search to determine the step size. Default is None.
memory_size (int) – Number of past updates to keep in memory to approximate the Hessian inverse. Larger values require more memory but may lead to better convergence. Default is 10.
scale_init_precond (bool) – Whether to use a scaled identity as the initial preconditioner. Default is True.
linesearch (optax.GradientTransformation, optional) – Custom line search transformation. If None, uses the default zoom line search. Default is None.
- Returns:
Configured L-BFGS optimizer.
- Return type:
optax.GradientTransformationExtraArgs
PIKAN Training
- jaxkan.pikan.pde.get_ac_res(D=0.0001, c=5.0)[source]
Returns the Allen-Cahn equation residual function (2D).
- Parameters:
D (float) – Diffusion coefficient. Default is 1e-4.
c (float) – Reaction coefficient. Default is 5.0.
- Returns:
Residual function for the Allen-Cahn equation.
- Return type:
ac_res (function)
Example
>>> # Use default parameters
>>> ac_res = get_ac_res()
>>>
>>> # Use custom parameters
>>> ac_res = get_ac_res(D=0.001, c=10.0)
- jaxkan.pikan.pde.get_diffusion_res(D=0.25)[source]
Returns the diffusion equation residual function (2D).
- Parameters:
D (float) – Diffusion coefficient. Default is 0.25.
- Returns:
Residual function for the diffusion equation.
- Return type:
diffusion_res (function)
Example
>>> # Use default parameters
>>> diffusion_res = get_diffusion_res()
>>>
>>> # Use custom parameters
>>> diffusion_res = get_diffusion_res(D=0.5)
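As an illustration of what a PDE residual function computes, here is a minimal sketch of a pointwise diffusion residual u_t - D * u_xx built with jax.grad. The helper diffusion_residual, its signature, and the (t, x) coordinate order are assumptions made for this example, not the interface of the function returned by get_diffusion_res. An exact solution stands in for a trained model, so the residual should vanish up to floating-point error.

```python
import jax
import jax.numpy as jnp

def diffusion_residual(u, t, x, D=0.25):
    """Pointwise residual u_t - D * u_xx for a scalar field u(t, x)."""
    u_t = jax.grad(u, argnums=0)(t, x)
    u_xx = jax.grad(jax.grad(u, argnums=1), argnums=1)(t, x)
    return u_t - D * u_xx

D = 0.25
# Exact solution of u_t = D * u_xx, used here in place of a trained model.
u_exact = lambda t, x: jnp.exp(-D * jnp.pi**2 * t) * jnp.sin(jnp.pi * x)

t = jnp.linspace(0.0, 1.0, 5)
x = jnp.linspace(-1.0, 1.0, 5)
# Evaluate the residual at each (t, x) pair.
res = jax.vmap(lambda ti, xi: diffusion_residual(u_exact, ti, xi, D))(t, x)
print(jnp.max(jnp.abs(res)))
```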
- jaxkan.pikan.pde.get_burgers_res(nu=0.003183098861837907)[source]
Returns the Burgers equation residual function (2D).
- Parameters:
nu (float) – Viscosity coefficient. Default is 0.01/π.
- Returns:
Residual function for the Burgers equation.
- Return type:
burgers_res (function)
Example
>>> # Use default parameters
>>> burgers_res = get_burgers_res()
>>>
>>> # Use custom parameters
>>> burgers_res = get_burgers_res(nu=0.001)
- jaxkan.pikan.pde.get_kdv_res(eta=1.0, mu=0.022)[source]
Returns the Korteweg-de Vries equation residual function (2D).
- Parameters:
eta (float) – Nonlinearity coefficient. Default is 1.0.
mu (float) – Dispersion coefficient. Default is 0.022.
- Returns:
Residual function for the Korteweg-de Vries equation.
- Return type:
kdv_res (function)
Example
>>> # Use default parameters
>>> kdv_res = get_kdv_res()
>>>
>>> # Use custom parameters
>>> kdv_res = get_kdv_res(eta=2.0, mu=0.01)
- jaxkan.pikan.pde.get_sg_res(D=1.0)[source]
Returns the sine-Gordon equation residual function (2D).
- Parameters:
D (float) – Wave speed coefficient. Default is 1.0.
- Returns:
Residual function for the sine-Gordon equation.
- Return type:
sg_res (function)
Example
>>> # Use default parameters
>>> sg_res = get_sg_res()
>>>
>>> # Use custom parameters
>>> sg_res = get_sg_res(D=2.0)
- jaxkan.pikan.pde.get_advection_res(c=20.0)[source]
Returns the advection equation residual function (2D).
- Parameters:
c (float) – Wave speed coefficient. Default is 20.0.
- Returns:
Residual function for the advection equation.
- Return type:
advection_res (function)
Example
>>> # Use default parameters
>>> advection_res = get_advection_res()
>>>
>>> # Use custom parameters
>>> advection_res = get_advection_res(c=10.0)
- jaxkan.pikan.pde.get_helmholtz_res(a1=1.0, a2=4.0, k=1.0)[source]
Returns the Helmholtz equation residual function (2D).
- Parameters:
a1 (float) – First frequency parameter for source term. Default is 1.0.
a2 (float) – Second frequency parameter for source term. Default is 4.0.
k (float) – Wave number. Default is 1.0.
- Returns:
Residual function for the Helmholtz equation.
- Return type:
helmholtz_res (function)
Example
>>> # Use default parameters
>>> helmholtz_res = get_helmholtz_res()
>>>
>>> # Use custom parameters
>>> helmholtz_res = get_helmholtz_res(a1=2.0, a2=3.0, k=2.0)
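The role of the source-term frequencies a1 and a2 can be illustrated with a manufactured-solution sketch. It assumes the residual has the form u_xx + u_yy + k^2 u - q(x, y), with q chosen so that sin(a1*pi*x) * sin(a2*pi*y) is an exact solution. This is a common benchmark setup, but it is an assumption here rather than a statement of what get_helmholtz_res implements, and helmholtz_residual is a hypothetical helper.

```python
import jax
import jax.numpy as jnp

def helmholtz_residual(u, x, y, a1=1.0, a2=4.0, k=1.0):
    """Residual u_xx + u_yy + k^2 u - q(x, y) at a single point (x, y)."""
    u_xx = jax.grad(jax.grad(u, argnums=0), argnums=0)(x, y)
    u_yy = jax.grad(jax.grad(u, argnums=1), argnums=1)(x, y)
    # Assumed source term: makes sin(a1*pi*x) * sin(a2*pi*y) an exact solution.
    shape = jnp.sin(a1 * jnp.pi * x) * jnp.sin(a2 * jnp.pi * y)
    q = (k**2 - (a1 * jnp.pi) ** 2 - (a2 * jnp.pi) ** 2) * shape
    return u_xx + u_yy + k**2 * u(x, y) - q

# The manufactured solution for a1=1, a2=4 gives a residual close to zero.
u_exact = lambda x, y: jnp.sin(jnp.pi * x) * jnp.sin(4.0 * jnp.pi * y)
res = helmholtz_residual(u_exact, 0.3, 0.2)
print(res)
```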
- jaxkan.pikan.pde.get_poisson_res(a1=4.0, a2=4.0)[source]
Returns the Poisson equation residual function (2D).
- Parameters:
a1 (float) – First frequency parameter for source term. Default is 4.0.
a2 (float) – Second frequency parameter for source term. Default is 4.0.
- Returns:
Residual function for the Poisson equation.
- Return type:
poisson_res (function)
Example
>>> # Use default parameters
>>> poisson_res = get_poisson_res()
>>>
>>> # Use custom parameters
>>> poisson_res = get_poisson_res(a1=2.0, a2=3.0)
- jaxkan.pikan.pde.get_wave_res(c=4.0)[source]
Returns the wave equation residual function (2D).
- Parameters:
c (float) – Wave speed coefficient. Default is 4.0.
- Returns:
Residual function for the wave equation.
- Return type:
wave_res (function)
Example
>>> # Use default parameters
>>> wave_res = get_wave_res()
>>>
>>> # Use custom parameters
>>> wave_res = get_wave_res(c=2.0)
- jaxkan.pikan.pde.get_ks_res(alpha=6.25, beta=0.390625, gamma=0.00152587890625)[source]
Returns the Kuramoto-Sivashinsky equation residual function (2D).
- Parameters:
alpha (float) – Nonlinearity coefficient. Default is 100/16.
beta (float) – Second-order dispersion coefficient. Default is 100/(16²).
gamma (float) – Fourth-order dispersion coefficient. Default is 100/(16⁴).
- Returns:
Residual function for the Kuramoto-Sivashinsky equation.
- Return type:
ks_res (function)
Example
>>> # Use default parameters
>>> ks_res = get_ks_res()
>>>
>>> # Use custom parameters
>>> ks_res = get_ks_res(alpha=6.25, beta=0.390625, gamma=0.0009765625)
- jaxkan.pikan.sampling.get_collocs_grid(ranges)[source]
Generate grid-based collocation points across arbitrary dimensions.
- Parameters:
ranges (list[tuple[float, float, int]]) – List of tuples where each tuple contains (lower_bound, upper_bound, sample_size) for each dimension.
- Returns:
Array of shape (total_points, n_dims) containing all grid points.
- Return type:
jnp.array
Example
>>> # 2D grid: x in [0, 1] with 10 points, t in [0, 2] with 20 points
>>> collocs = get_collocs_grid(ranges=[(0.0, 1.0, 10), (0.0, 2.0, 20)])
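The (total_points, n_dims) output shape follows from taking the Cartesian product of the per-dimension samples. A minimal sketch of that construction is shown below; grid_collocs is a hypothetical stand-in, and the use of endpoint-inclusive jnp.linspace is an assumption about how each axis is sampled.

```python
import jax.numpy as jnp

def grid_collocs(ranges):
    """Cartesian grid of points from (lower, upper, n_points) per dimension."""
    # Assumption: each axis is sampled uniformly, endpoints included.
    axes = [jnp.linspace(lo, hi, n) for lo, hi, n in ranges]
    mesh = jnp.meshgrid(*axes, indexing="ij")
    # Flatten every coordinate array and stack them as columns.
    return jnp.stack([m.reshape(-1) for m in mesh], axis=-1)

collocs = grid_collocs([(0.0, 1.0, 10), (0.0, 2.0, 20)])
print(collocs.shape)  # (200, 2): 10 * 20 points, 2 coordinates each
```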
- jaxkan.pikan.sampling.get_collocs_random(ranges, total_points, seed=42)[source]
Generate random collocation points across arbitrary dimensions.
- Parameters:
ranges (list[tuple[float, float]]) – List of tuples where each tuple contains (lower_bound, upper_bound) for each dimension.
total_points (int) – Total number of random points to generate.
seed (int) – Random seed for reproducibility.
- Returns:
Array of shape (total_points, n_dims) containing randomly sampled points.
- Return type:
jnp.array
Example
>>> # 2D random samples: 100 points in [0, 1] x [0, 2]
>>> collocs = get_collocs_random(ranges=[(0.0, 1.0), (0.0, 2.0)], total_points=100, seed=42)
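Uniform random collocation points can be drawn in the unit hypercube and then rescaled to the requested ranges. The sketch below shows this under the assumption of independent uniform sampling per dimension; random_collocs is a hypothetical helper, not the library function.

```python
import jax
import jax.numpy as jnp

def random_collocs(ranges, total_points, seed=42):
    """Uniform random points inside the box defined by (lower, upper) pairs."""
    lows = jnp.array([lo for lo, _ in ranges])
    highs = jnp.array([hi for _, hi in ranges])
    key = jax.random.PRNGKey(seed)
    # Sample in [0, 1)^n_dims, then rescale each column to its own range.
    unit = jax.random.uniform(key, (total_points, len(ranges)))
    return lows + (highs - lows) * unit

collocs = random_collocs([(0.0, 1.0), (0.0, 2.0)], total_points=100, seed=42)
print(collocs.shape)  # (100, 2)
```

Because the key is derived from seed, calling the helper twice with the same seed returns the same points.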
- jaxkan.pikan.sampling.get_collocs_sobol(ranges, total_points, seed=42)[source]
Generate Sobol quasi-random collocation points across arbitrary dimensions.
- Parameters:
ranges (list[tuple[float, float]]) – List of tuples where each tuple contains (lower_bound, upper_bound) for each dimension.
total_points (int) – Total number of Sobol points to generate. Should be a power of 2 for optimal coverage.
seed (int) – Random seed for reproducibility. Default is 42.
- Returns:
Array of shape (total_points, n_dims) containing Sobol sampled points.
- Return type:
jnp.array
Example
>>> # 2D Sobol samples: 1024 points in [0, 1] x [0, 2]
>>> collocs = get_collocs_sobol(ranges=[(0.0, 1.0), (0.0, 2.0)], total_points=1024, seed=42)
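For comparison, a scrambled Sobol sequence can be generated with SciPy and rescaled to the requested box. This sketch assumes SciPy's scipy.stats.qmc module as the generator, which may not be what get_collocs_sobol uses internally; sobol_collocs is a hypothetical helper.

```python
import jax.numpy as jnp
from scipy.stats import qmc

def sobol_collocs(ranges, total_points, seed=42):
    """Scrambled Sobol points rescaled to the given (lower, upper) pairs."""
    lows = [lo for lo, _ in ranges]
    highs = [hi for _, hi in ranges]
    sampler = qmc.Sobol(d=len(ranges), scramble=True, seed=seed)
    # SciPy warns when total_points is not a power of 2, since the
    # balance properties of the sequence are then lost.
    unit = sampler.random(total_points)
    return jnp.asarray(qmc.scale(unit, lows, highs))

collocs = sobol_collocs([(0.0, 1.0), (0.0, 2.0)], total_points=1024, seed=42)
print(collocs.shape)  # (1024, 2)
```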
- jaxkan.pikan.adaptive.get_colloc_indices(collocs_pool, batch_size, px, seed)[source]
Sample collocation point indices from a pool based on probability weights and sort by first coordinate (time).
- Parameters:
collocs_pool (jnp.array) – Pool of collocation points, shape (N, n_dims).
batch_size (int) – Number of points to sample.
px (jnp.array) – Probability weights for sampling, shape (N,).
seed (int) – Random seed for reproducibility.
- Returns:
Indices of sampled points sorted by first coordinate, shape (batch_size,).
- Return type:
sorted_pool_indices (jnp.array)
Example
>>> collocs_pool = jnp.array([[0.5, 0.1], [0.2, 0.3], [0.8, 0.7]])
>>> px = jnp.array([0.5, 0.3, 0.2])
>>> indices = get_colloc_indices(collocs_pool, batch_size=2, px=px, seed=42)
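The two steps described above, weighted sampling followed by sorting on the first coordinate, can be sketched with jax.random.choice and jnp.argsort. The helper sample_sorted_indices is hypothetical, and sampling without replacement is an assumption made for this example.

```python
import jax
import jax.numpy as jnp

def sample_sorted_indices(collocs_pool, batch_size, px, seed):
    """Draw indices with probabilities px, ordered by the first coordinate."""
    key = jax.random.PRNGKey(seed)
    # Assumption: sampling without replacement, so no point is drawn twice.
    idx = jax.random.choice(key, collocs_pool.shape[0], shape=(batch_size,),
                            replace=False, p=px)
    # Reorder the drawn indices by the first coordinate of their points.
    order = jnp.argsort(collocs_pool[idx, 0])
    return idx[order]

collocs_pool = jnp.array([[0.5, 0.1], [0.2, 0.3], [0.8, 0.7]])
px = jnp.array([0.5, 0.3, 0.2])
indices = sample_sorted_indices(collocs_pool, batch_size=2, px=px, seed=42)
print(indices, collocs_pool[indices, 0])
```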
- jaxkan.pikan.adaptive.update_rba_weights(residuals, weights, gamma, eta, eps=1e-12)[source]
Update residual-based attention weights from the current residual magnitudes.
- Parameters:
residuals (jnp.array) – Residual values, typically shape (N, 1).
weights (jnp.array) – Current RBA weights with the same shape as residuals.
gamma (float) – Exponential moving average coefficient.
eta (float) – Residual scaling coefficient.
eps (float) – Numerical stabilizer for normalization.
- Returns:
Updated RBA weights.
- Return type:
jnp.array
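A minimal sketch of a residual-based attention (RBA) update is given below, assuming the rule commonly stated in the RBA literature: new_weights = gamma * weights + eta * |r| / max|r|. Whether update_rba_weights normalizes by the maximum in exactly this way is an assumption, and rba_update is a hypothetical name.

```python
import jax.numpy as jnp

def rba_update(residuals, weights, gamma, eta, eps=1e-12):
    """Decay the old weights and add the normalized residual magnitudes."""
    abs_res = jnp.abs(residuals)
    # Assumption: residuals are normalized by their largest magnitude.
    return gamma * weights + eta * abs_res / (jnp.max(abs_res) + eps)

residuals = jnp.array([[1.0], [-2.0], [4.0]])
weights = jnp.zeros_like(residuals)
new_weights = rba_update(residuals, weights, gamma=0.999, eta=0.01)
print(new_weights.ravel())  # the point with the largest |residual| gains most
```

With gamma below 1, repeated updates keep the weights bounded by roughly eta / (1 - gamma).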
- jaxkan.pikan.adaptive.apply_rba_weights(residuals, weights)[source]
Apply RBA weights to residuals while keeping the weights outside the backward graph.
- Parameters:
residuals (jnp.array) – Residual values.
weights (jnp.array) – RBA weights with matching shape.
- Returns:
Weighted residuals.
- Return type:
jnp.array
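Keeping the weights outside the backward graph can be done with jax.lax.stop_gradient. The sketch below, with a hypothetical helper weighted_residuals, checks that the gradient of a loss with respect to the weights is zero while the gradient with respect to the residuals is not.

```python
import jax
import jax.numpy as jnp

def weighted_residuals(residuals, weights):
    """Scale residuals by weights that are treated as constants in autodiff."""
    return jax.lax.stop_gradient(weights) * residuals

def loss(residuals, weights):
    return jnp.mean(weighted_residuals(residuals, weights) ** 2)

r = jnp.array([[0.5], [-1.0]])
w = jnp.array([[2.0], [3.0]])
g_r, g_w = jax.grad(loss, argnums=(0, 1))(r, w)
print(g_r.ravel(), g_w.ravel())  # g_w is all zeros
```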
- jaxkan.pikan.adaptive.get_causal_weights(losses, causal_matrix, causal_tol)[source]
Compute causal training weights and stop their gradients.
- Parameters:
losses (jnp.array) – Chunk-wise loss values.
causal_matrix (jnp.array) – Upper-triangular causal coupling matrix.
causal_tol (float) – Exponential weighting coefficient.
- Returns:
Stopped-gradient causal weights.
- Return type:
jnp.array
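Causal weighting down-weights later time chunks until the earlier ones are well resolved. A minimal sketch follows, assuming weights of the form exp(-causal_tol * sum of losses of all earlier chunks) and a strictly upper-triangular matrix of ones applied as losses @ causal_matrix. Both the formula and the multiplication order are assumptions, and causal_weights is a hypothetical helper.

```python
import jax
import jax.numpy as jnp

def causal_weights(losses, causal_matrix, causal_tol):
    """Weight each chunk by exp(-tol * cumulative loss of earlier chunks)."""
    # With a strictly upper-triangular matrix of ones, entry i of
    # losses @ causal_matrix is the sum of losses[j] for all j < i.
    cumulative = losses @ causal_matrix
    return jax.lax.stop_gradient(jnp.exp(-causal_tol * cumulative))

n_chunks = 4
M = jnp.triu(jnp.ones((n_chunks, n_chunks)), k=1)
losses = jnp.array([0.1, 0.5, 1.0, 2.0])
w = causal_weights(losses, M, causal_tol=1.0)
print(w)  # first chunk has weight 1, later chunks decay
```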
- jaxkan.pikan.adaptive.get_rad_probabilities(weighted_residuals, rad_a, rad_c, eps=1e-12)[source]
Convert weighted residuals into normalized RAD sampling probabilities.
- Parameters:
weighted_residuals (jnp.array) – Residuals already multiplied by their current adaptive weights.
rad_a (float) – Density-control exponent.
rad_c (float) – Baseline offset added to the density.
eps (float) – Numerical stabilizer for normalization.
- Returns:
Probability vector of shape (N,).
- Return type:
jnp.array
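The conversion from residuals to sampling probabilities can be sketched as follows, assuming the residual-based adaptive distribution (RAD) density |r|^a / mean(|r|^a) + c, normalized to sum to one. The helper rad_probabilities is hypothetical and the exact placement of eps is an assumption.

```python
import jax.numpy as jnp

def rad_probabilities(weighted_residuals, rad_a, rad_c, eps=1e-12):
    """Normalized sampling probabilities from residual magnitudes."""
    err = jnp.abs(weighted_residuals).reshape(-1) ** rad_a
    # Density: relative error raised to rad_a, plus a uniform baseline rad_c.
    density = err / (jnp.mean(err) + eps) + rad_c
    return density / (jnp.sum(density) + eps)

res = jnp.array([[0.1], [0.2], [0.7]])
px = rad_probabilities(res, rad_a=1.0, rad_c=1.0)
print(px, px.sum())
```

A larger rad_a concentrates sampling on high-residual points, while a larger rad_c pushes the distribution toward uniform.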
- jaxkan.pikan.adaptive.get_rad_indices(collocs_pool, residuals, old_indices, batch_weights, pool_weights, batch_size, rad_a, rad_c, seed, eps=1e-12)[source]
Update pool weights and draw a new RAD-resampled batch of collocation indices.
- Parameters:
collocs_pool (jnp.array) – Full collocation pool, shape (N, d).
residuals (jnp.array) – Residual values over the full pool, shape (N, 1).
old_indices (jnp.array) – Previously active batch indices.
batch_weights (jnp.array) – Current adaptive weights for the active batch.
pool_weights (jnp.array) – Full pool of adaptive weights.
batch_size (int) – Number of indices to sample.
rad_a (float) – Density-control exponent.
rad_c (float) – Baseline offset added to the density.
seed (int) – Sampling seed.
eps (float) – Numerical stabilizer.
- Returns:
Sampled indices, updated full-pool weights, and normalized probabilities.
- Return type:
tuple[jnp.array, jnp.array, jnp.array]
- jaxkan.pikan.adaptive.lr_anneal(grads, lambdas, grad_mixing)[source]
Perform the learning rate annealing algorithm introduced in “Understanding and mitigating gradient pathologies in physics-informed neural networks”.
- Parameters:
grads (tuple | list) – Tuple/list of gradient pytrees, one per loss term.
lambdas (tuple | list) – Tuple/list of current global loss weights, matching grads in order.
grad_mixing (float) – Exponential moving average coefficient.
- Returns:
Updated global loss weights in the same logical order as the inputs.
- Return type:
tuple
Example
>>> λs_new = lr_anneal((grads_E, grads_B, grads_aux), (1.0, 1.0, 1.0), 0.9)
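The sketch below illustrates one gradient-based way to rebalance loss weights. It uses a gradient-norm balancing variant (candidate weight = sum of all gradient norms divided by the norm of that term's gradient) rather than the max/mean statistics of the cited paper, and it assumes that grad_mixing multiplies the old weight in the moving average. Both choices are assumptions, so the exact rule in lr_anneal may differ; anneal_weights and grad_norm are hypothetical helpers.

```python
import jax
import jax.numpy as jnp

def grad_norm(tree):
    """Global L2 norm of all leaves of a gradient pytree."""
    leaves = jax.tree_util.tree_leaves(tree)
    return jnp.sqrt(sum(jnp.sum(leaf ** 2) for leaf in leaves))

def anneal_weights(grads, lambdas, grad_mixing, eps=1e-12):
    """Move each loss weight toward total_norm / own_norm via an EMA."""
    norms = [grad_norm(g) for g in grads]
    total = sum(norms)
    # Candidate weights: larger for loss terms with smaller gradients.
    hats = [total / (n + eps) for n in norms]
    # Assumption: grad_mixing is the share kept from the old weight.
    return tuple(grad_mixing * lam + (1.0 - grad_mixing) * hat
                 for lam, hat in zip(lambdas, hats))

grads_a = {"w": jnp.array([3.0, 4.0])}   # norm 5.0
grads_b = {"w": jnp.array([0.3, 0.4])}   # norm 0.5
new_lambdas = anneal_weights((grads_a, grads_b), (1.0, 1.0), 0.9)
print(new_lambdas)
```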
- jaxkan.pikan.utils.model_eval(model, coords, refsol)[source]
Compute the relative L2 error between model predictions and reference solution.
- Parameters:
model (nnx.Module) – Flax model instance.
coords (jnp.array) – Input coordinates for evaluation, shape (N, n_dims).
refsol (jnp.array) – Reference solution values for comparison.
- Returns:
Relative L2 error: ||prediction - reference|| / ||reference||
- Return type:
l2err (float)
Example
>>> model = KAN([2,8,1], 'spline', {}, 42)
>>> coords = jnp.array([[0.5, 0.5], [0.1, 0.2]])
>>> refsol = jnp.array([[0.25], [0.02]])
>>> error = model_eval(model, coords, refsol)
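The relative L2 error in the Returns field can be checked by hand on a small case. The sketch below applies the formula directly to prediction and reference arrays; relative_l2 is a hypothetical helper that takes predictions instead of a model.

```python
import jax.numpy as jnp

def relative_l2(pred, ref):
    """||pred - ref|| / ||ref|| using the Euclidean (Frobenius) norm."""
    return jnp.linalg.norm(pred - ref) / jnp.linalg.norm(ref)

ref = jnp.array([[3.0], [4.0]])    # norm 5.0
pred = jnp.array([[3.0], [4.5]])   # differs from ref by 0.5 in one entry
print(relative_l2(pred, ref))      # 0.5 / 5.0 = 0.1
```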