API Reference

KAN Layer Classes

jaxkan.layers.__init__.get_layer(layer_type)

Helper function that maps a layer type code to the corresponding layer class.

Parameters:

layer_type (str) – Code of layer to be used.

Returns:

The jaxkan.layers layer class (not an instance) that corresponds to layer_type, to be used as the building block of a KAN.

Return type:

layer (jaxkan.layers.Layer)

Example

>>> from jaxkan.layers import get_layer
>>> LayerClass = get_layer("base")
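Because get_layer returns a class rather than an instance, the caller is responsible for instantiating it. The dictionary-dispatch pattern behind a helper of this kind can be sketched as follows. The stub classes, the `_LAYER_REGISTRY` table, `get_layer_sketch`, and every code except "base" are hypothetical stand-ins, not jaxkan's actual implementation.

```python
class BaseLayerStub:
    """Hypothetical stand-in for a layer class such as BaseLayer."""

    def __init__(self, n_in=2, n_out=5):
        self.n_in, self.n_out = n_in, n_out


class ChebyshevLayerStub(BaseLayerStub):
    """Hypothetical stand-in for a second layer class."""


# Hypothetical code -> class table; only "base" appears in the documentation.
_LAYER_REGISTRY = {
    "base": BaseLayerStub,
    "chebyshev": ChebyshevLayerStub,
}


def get_layer_sketch(layer_type: str) -> type:
    """Return the class registered under layer_type (the class itself, not an instance)."""
    try:
        return _LAYER_REGISTRY[layer_type]
    except KeyError:
        valid = ", ".join(sorted(_LAYER_REGISTRY))
        raise ValueError(f"Unknown layer_type {layer_type!r}; expected one of: {valid}") from None


LayerClass = get_layer_sketch("base")
layer = LayerClass(n_in=3, n_out=4)  # instantiation happens at the call site
```

Raising a ValueError that lists the valid codes is one possible design choice. The real helper may handle unknown codes differently.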
class jaxkan.layers.Spline.BaseLayer(n_in=2, n_out=5, k=3, G=3, grid_range=(-1, 1), grid_e=0.05, residual=nnx.silu, external_weights=True, init_scheme=None, add_bias=True, seed=42)

Bases: Module

BaseLayer class. Corresponds to the spline-based KAN layer introduced in the original KAN paper. Ref: https://arxiv.org/abs/2404.19756

n_in

Number of layer’s incoming nodes.

Type:

int

n_out

Number of layer’s outgoing nodes.

Type:

int

k

Order of the spline basis functions.

Type:

int

residual

Function that is applied on samples to calculate residual activation.

Type:

Union[nnx.Module, None]

rngs

Random number generator state.

Type:

nnx.Rngs

grid

Grid object for spline basis functions.

Type:

BaseGrid

c_spl

Spline weights if external_weights is True, else None.

Type:

Union[nnx.Param, None]

c_basis

Trainable coefficients for the basis functions.

Type:

nnx.Param

c_res

Trainable coefficients for residual activation if residual is not None.

Type:

Union[nnx.Param, None]

bias

Bias parameter if add_bias is True, else None.

Type:

Union[nnx.Param, None]

__init__(n_in=2, n_out=5, k=3, G=3, grid_range=(-1, 1), grid_e=0.05, residual=nnx.silu, external_weights=True, init_scheme=None, add_bias=True, seed=42)

Initializes a BaseLayer instance.

Parameters:
  • n_in (int) – Number of layer’s incoming nodes.

  • n_out (int) – Number of layer’s outgoing nodes.

  • k (int) – Order of the spline basis functions.

  • G (int) – Number of grid intervals.

  • grid_range (tuple) – An initial range for the grid’s ends, although adaptivity can completely change it.

  • grid_e (float) – Parameter that defines if the grids are uniform (grid_e = 1.0) or sample-dependent (grid_e = 0.0). Intermediate values correspond to a linear mixing of the two cases.

  • residual (Union[nnx.Module, None]) – Function that is applied on samples to calculate residual activation.

  • external_weights (bool) – Whether to use trainable weights of shape (n_out, n_in) that scale the spline activations.

  • init_scheme (Union[dict, None]) – Dictionary that defines how the trainable parameters of the layer are initialized.

  • add_bias (bool) – Whether bias terms are included in the forward pass.

  • seed (int) – Seed for the random keys used wherever initialization requires randomness.

Example

>>> from flax import nnx
>>> from jaxkan.layers.Spline import BaseLayer
>>>
>>> layer = BaseLayer(n_in = 2, n_out = 5, k = 3,
...                   G = 3, grid_range = (-1,1), grid_e = 0.05, residual = nnx.silu,
...                   external_weights = True, init_scheme = None, add_bias = True,
...                   seed = 42)
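The grid_e parameter describes a linear blend between a uniform grid and a sample-dependent one. Below is a minimal pure-Python sketch of that blend for one input dimension. It assumes the sample-dependent knots sit at empirical quantiles of the data, which is a common convention. `uniform_grid`, `adaptive_grid` and `mixed_grid` are hypothetical helpers. The real BaseGrid also extends the grid by k knots on each side, and that step is omitted here.

```python
def uniform_grid(samples, G):
    """G + 1 equally spaced knots spanning the range of the samples."""
    lo, hi = min(samples), max(samples)
    h = (hi - lo) / G
    return [lo + i * h for i in range(G + 1)]


def adaptive_grid(samples, G):
    """G + 1 knots at evenly spaced positions in the sorted samples (empirical quantiles)."""
    xs = sorted(samples)
    n = len(xs)
    return [xs[(i * (n - 1)) // G] for i in range(G + 1)]


def mixed_grid(samples, G, grid_e):
    """Linear blend: grid_e = 1.0 gives the uniform grid, grid_e = 0.0 the sample-dependent one."""
    u = uniform_grid(samples, G)
    a = adaptive_grid(samples, G)
    return [grid_e * ui + (1.0 - grid_e) * ai for ui, ai in zip(u, a)]


samples = [0.0, 0.1, 0.2, 0.3, 1.0]
print(mixed_grid(samples, G=4, grid_e=0.05))
```

Both grids share the sample minimum and maximum as endpoints, so every blend keeps those endpoints and only the interior knots move. With the default grid_e = 0.05, the knot placement is driven mostly by the data.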
basis(x)

Uses k and the current grid to calculate the values of spline basis functions on the input.

Parameters:

x (jnp.array) – Inputs, shape (batch, n_in).

Returns:

Spline basis functions applied on inputs, shape (n_in*n_out, G+k, batch).

Return type:

basis_splines (jnp.array)

Example

>>> import jax
>>> from flax import nnx
>>> from jaxkan.layers.Spline import BaseLayer
>>>
>>> layer = BaseLayer(n_in = 2, n_out = 5, k = 3,
...                   G = 3, grid_range = (-1,1), grid_e = 0.05, residual = nnx.silu,
...                   external_weights = True, init_scheme = None, add_bias = True,
...                   seed = 42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> output = layer.basis(x_batch)
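The spline basis values can be computed with the Cox-de Boor recursion. Below is a minimal pure-Python sketch for one scalar input. It assumes a uniform grid of G intervals on grid_range, extended by k knots on each side, which yields exactly the G+k basis functions in the documented shape. "Order" k is read as the polynomial degree, so k = 3 gives cubic splines. `extended_grid` and `bspline_basis` are hypothetical helpers, and after adaptation the real grid need not be uniform.

```python
def extended_grid(G, k, grid_range=(-1.0, 1.0)):
    """Uniform grid with G intervals on grid_range, extended by k knots on each side."""
    lo, hi = grid_range
    h = (hi - lo) / G
    return [lo + (i - k) * h for i in range(G + 2 * k + 1)]


def bspline_basis(x, knots, k):
    """Values of all degree-k B-splines on `knots` at scalar x (Cox-de Boor recursion)."""
    # Degree 0: indicator function of each knot interval [t_i, t_{i+1}).
    B = [1.0 if knots[i] <= x < knots[i + 1] else 0.0 for i in range(len(knots) - 1)]
    # Each pass raises the degree by one and yields one fewer basis function.
    for p in range(1, k + 1):
        B = [
            (x - knots[i]) / (knots[i + p] - knots[i]) * B[i]
            + (knots[i + p + 1] - x) / (knots[i + p + 1] - knots[i + 1]) * B[i + 1]
            for i in range(len(B) - 1)
        ]
    return B


G, k = 3, 3
knots = extended_grid(G, k)
values = bspline_basis(0.3, knots, k)
print(len(values))  # G + k = 6 basis functions
```

Inside grid_range the values are non-negative and sum to 1 (partition of unity). This is a convenient sanity check for any basis implementation.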
update_grid(x, G_new)

Performs a grid update given a new value for G (i.e., G_new) and adapts it to the given data, x. Additionally, re-initializes the c_basis parameters to a better estimate, based on the new grid.

Parameters:
  • x (jnp.array) – Inputs, shape (batch, n_in).

  • G_new (int) – Size of the new grid (in terms of intervals).

Example

>>> import jax
>>> from flax import nnx
>>> from jaxkan.layers.Spline import BaseLayer
>>>
>>> layer = BaseLayer(n_in = 2, n_out = 5, k = 3,
...                   G = 3, grid_range = (-1,1), grid_e = 0.05, residual = nnx.silu,
...                   external_weights = True, init_scheme = None, add_bias = True,
...                   seed = 42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> layer.update_grid(x=x_batch, G_new=5)
__call__(x)

The layer’s forward pass.

Parameters:

x (jnp.array) – Inputs, shape (batch, n_in).

Returns:

Output of the forward pass, corresponding to the weighted sum of the B-spline activation and the residual activation, shape (batch, n_out).

Return type:

y (jnp.array)

Example

>>> import jax
>>> from flax import nnx
>>> from jaxkan.layers.Spline import BaseLayer
>>>
>>> layer = BaseLayer(n_in = 2, n_out = 5, k = 3,
...                   G = 3, grid_range = (-1,1), grid_e = 0.05, residual = nnx.silu,
...                   external_weights = True, init_scheme = None, add_bias = True,
...                   seed = 42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> output = layer(x_batch)
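The weighted sum described above can be sketched for a single sample in plain Python. The sketch assumes the formulation of the original KAN paper. Each edge i → j computes c_res·residual(x_i) + c_spl·Σ_m c_basis·B_m(x_i), and the edge outputs are summed per output node before the bias is added. The nested-list parameter layout and `kan_layer_forward` are hypothetical. `basis_fn` stands for any function that returns the basis values of a scalar input. c_spl only exists when external_weights is True; in this sketch, passing all ones mimics its absence.

```python
import math


def silu(z):
    """SiLU activation z * sigmoid(z), the documented default residual function."""
    return z / (1.0 + math.exp(-z))


def kan_layer_forward(x, basis_fn, c_basis, c_spl, c_res, bias, residual=silu):
    """Forward pass for one sample x (length n_in); returns a list of length n_out.

    Hypothetical parameter layout:
      c_basis[j][i][m]          coefficient of basis function m on edge i -> j
      c_spl[j][i], c_res[j][i]  per-edge spline and residual weights
      bias[j]                   per-output bias
    """
    basis = [basis_fn(xi) for xi in x]  # basis values of each input, computed once
    y = []
    for j in range(len(c_spl)):
        total = bias[j]
        for i, xi in enumerate(x):
            spline = sum(c * b for c, b in zip(c_basis[j][i], basis[i]))
            total += c_spl[j][i] * spline
            if residual is not None:
                total += c_res[j][i] * residual(xi)
        y.append(total)
    return y


# Toy usage with a polynomial stand-in for the B-spline basis (illustration only).
y = kan_layer_forward(
    x=[0.5, -1.0],
    basis_fn=lambda t: [1.0, t, t * t],
    c_basis=[[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]] for _ in range(2)],
    c_spl=[[1.0, 1.0] for _ in range(2)],
    c_res=[[0.0, 0.0] for _ in range(2)],
    bias=[0.25, 0.25],
)
```

Computing the basis once per input and reusing it for every output node mirrors the way a vectorized implementation would share the (n_in, G+k) basis values across all n_out outputs.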
class jaxkan.layers.Spline.SplineLayer(n_in=2, n_out=5, k=3, G=3, grid_range=(-1, 1), grid_e=0.05, residual=nnx.silu, external_weights=True, init_scheme=None, add_bias=True, seed=42)

Bases: Module

SplineLayer class. Corresponds to the “efficient” version of the spline-based KAN Layer. Ref: https://github.com/Blealtan/efficient-kan

n_in

Number of layer’s incoming nodes.

Type:

int

n_out

Number of layer’s outgoing nodes.

Type:

int

k

Order of the spline basis functions.

Type:

int

residual

Function that is applied on samples to calculate residual activation.

Type:

Union[nnx.Module, None]

rngs

Random number generator state.

Type:

nnx.Rngs

grid

Grid object for spline basis functions.

Type:

SplineGrid

c_spl

Spline weights if external_weights is True, else None.

Type:

Union[nnx.Param, None]

c_basis

Trainable coefficients for the basis functions.

Type:

nnx.Param

c_res

Trainable coefficients for residual activation if residual is not None.

Type:

Union[nnx.Param, None]

bias

Bias parameter if add_bias is True, else None.

Type:

Union[nnx.Param, None]

__init__(n_in=2, n_out=5, k=3, G=3, grid_range=(-1, 1), grid_e=0.05, residual=nnx.silu, external_weights=True, init_scheme=None, add_bias=True, seed=42)

Initializes a SplineLayer instance.

Parameters:
  • n_in (int) – Number of layer’s incoming nodes.

  • n_out (int) – Number of layer’s outgoing nodes.

  • k (int) – Order of the spline basis functions.

  • G (int) – Number of grid intervals.

  • grid_range (tuple) – An initial range for the grid’s ends, although adaptivity can completely change it.

  • grid_e (float) – Parameter that defines if the grids are uniform (grid_e = 1.0) or sample-dependent (grid_e = 0.0). Intermediate values correspond to a linear mixing of the two cases.

  • residual (Union[nnx.Module, None]) – Function that is applied on samples to calculate residual activation.

  • external_weights (bool) – Whether to use trainable weights of shape (n_out, n_in) that scale the spline activations.

  • init_scheme (Union[dict, None]) – Dictionary that defines how the trainable parameters of the layer are initialized.

  • add_bias (bool) – Whether bias terms are included in the forward pass.

  • seed (int) – Seed for the random keys used wherever initialization requires randomness.

Example

>>> from flax import nnx
>>> from jaxkan.layers.Spline import SplineLayer
>>>
>>> layer = SplineLayer(n_in = 2, n_out = 5, k = 3,
...                     G = 3, grid_range = (-1,1), grid_e = 0.05, residual = nnx.silu,
...                     external_weights = True, init_scheme = None, add_bias = True,
...                     seed = 42)
basis(x)

Uses k and the current grid to calculate the values of spline basis functions on the input.

Parameters:

x (jnp.array) – Inputs, shape (batch, n_in).

Returns:

Spline basis functions applied on inputs, shape (n_in*n_out, G+k, batch).

Return type:

basis_splines (jnp.array)

Example

>>> import jax
>>> from flax import nnx
>>> from jaxkan.layers.Spline import SplineLayer
>>>
>>> layer = SplineLayer(n_in = 2, n_out = 5, k = 3,
...                     G = 3, grid_range = (-1,1), grid_e = 0.05, residual = nnx.silu,
...                     external_weights = True, init_scheme = None, add_bias = True,
...                     seed = 42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> output = layer.basis(x_batch)
update_grid(x, G_new)

Performs a grid update given a new value for G (i.e., G_new) and adapts it to the given data, x. Additionally, re-initializes the c_basis parameters to a better estimate, based on the new grid.

Parameters:
  • x (jnp.array) – Inputs, shape (batch, n_in).

  • G_new (int) – Size of the new grid (in terms of intervals).

Example

>>> import jax
>>> from flax import nnx
>>> from jaxkan.layers.Spline import SplineLayer
>>>
>>> layer = SplineLayer(n_in = 2, n_out = 5, k = 3,
...                     G = 3, grid_range = (-1,1), grid_e = 0.05, residual = nnx.silu,
...                     external_weights = True, init_scheme = None, add_bias = True,
...                     seed = 42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> layer.update_grid(x=x_batch, G_new=5)
__call__(x)

The layer’s forward pass.

Parameters:

x (jnp.array) – Inputs, shape (batch, n_in).

Returns:

Output of the forward pass, corresponding to the weighted sum of the B-spline activation and the residual activation, shape (batch, n_out).

Return type:

y (jnp.array)

Example

>>> import jax
>>> from flax import nnx
>>> from jaxkan.layers.Spline import SplineLayer
>>>
>>> layer = SplineLayer(n_in = 2, n_out = 5, k = 3,
...                     G = 3, grid_range = (-1,1), grid_e = 0.05, residual = nnx.silu,
...                     external_weights = True, init_scheme = None, add_bias = True,
...                     seed = 42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> output = layer(x_batch)
class jaxkan.layers.Chebyshev.ChebyshevLayer(n_in=2, n_out=5, D=5, flavor=None, residual=None, external_weights=False, init_scheme=None, add_bias=True, seed=42)

Bases: Module

ChebyshevLayer class. Corresponds to the Chebyshev version of KANs and comes in three “flavors”:

  • “default”: the version presented in https://arxiv.org/pdf/2405.07200

  • “modified”: the version presented in https://www.sciencedirect.com/science/article/pii/S0045782524005462

  • “exact”: uses pre-defined functions for higher efficiency, but cannot scale up to arbitrary degrees.

n_in

Number of layer’s incoming nodes.

Type:

int

n_out

Number of layer’s outgoing nodes.

Type:

int

D

Degree of Chebyshev polynomial (1st kind).

Type:

int

flavor

One of “default”, “modified”, or “exact” - chooses basis implementation.

Type:

Union[str, None]

residual

Function that is applied on samples to calculate residual activation.

Type:

Union[nnx.Module, None]

rngs

Random number generator state.

Type:

nnx.Rngs

bias

Bias parameter if add_bias is True, else None.

Type:

Union[nnx.Param, None]

c_ext

External weights if external_weights is True, else None.

Type:

Union[nnx.Param, None]

c_basis

Trainable coefficients for the basis functions.

Type:

nnx.Param

c_res

Trainable coefficients for residual activation if residual is not None.

Type:

Union[nnx.Param, None]

__init__(n_in=2, n_out=5, D=5, flavor=None, residual=None, external_weights=False, init_scheme=None, add_bias=True, seed=42)

Initializes a ChebyshevLayer instance.

Parameters:
  • n_in (int) – Number of layer’s incoming nodes.

  • n_out (int) – Number of layer’s outgoing nodes.

  • D (int) – Degree of Chebyshev polynomial (1st kind).

  • flavor (Union[str, None]) – One of “default”, “modified”, or “exact” - chooses basis implementation.

  • residual (Union[nnx.Module, None]) – Function that is applied on samples to calculate residual activation.

  • external_weights (bool) – Boolean that controls if the trainable weights (n_out, n_in) should be applied to the activations.

  • init_scheme (Union[dict, None]) – Dictionary that defines how the trainable parameters of the layer are initialized.

  • add_bias (bool) – Whether bias terms are included in the forward pass.

  • seed (int) – Seed for the random keys used wherever initialization requires randomness.

Example

>>> from jaxkan.layers.Chebyshev import ChebyshevLayer
>>>
>>> layer = ChebyshevLayer(n_in = 2, n_out = 5, D = 5, flavor = "default",
...                        residual = None, external_weights = False, init_scheme = None,
...                        add_bias = True, seed = 42)
basis(x)

Based on the degree and flavor, the values of the Chebyshev basis functions are calculated on the input.

Parameters:

x (jnp.array) – Inputs, shape (batch, n_in).

Returns:

Chebyshev basis functions applied on inputs, shape (batch, n_in, D+1).

Return type:

cheb (jnp.array)

Example

>>> import jax
>>> from jaxkan.layers.Chebyshev import ChebyshevLayer
>>>
>>> layer = ChebyshevLayer(n_in = 2, n_out = 5, D = 5, flavor = "default",
...                        residual = None, external_weights = False, init_scheme = None,
...                        add_bias = True, seed = 42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> output = layer.basis(x_batch)
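Chebyshev polynomials of the first kind satisfy the three-term recurrence T_0(x) = 1, T_1(x) = x, T_{n+1}(x) = 2x·T_n(x) - T_{n-1}(x). Below is a minimal pure-Python sketch for one scalar input in [-1, 1], assuming the "default" flavor evaluates an equivalent recurrence. `chebyshev_basis` is a hypothetical helper. The sketch squashes the input with tanh, as the ChebyKAN paper does; whether ChebyshevLayer applies the same normalization is not confirmed here.

```python
import math


def chebyshev_basis(x, D):
    """[T_0(x), ..., T_D(x)] via T_{n+1}(x) = 2x T_n(x) - T_{n-1}(x)."""
    T = [1.0, x]
    for n in range(1, D):
        T.append(2.0 * x * T[n] - T[n - 1])
    return T[: D + 1]  # slicing also makes D = 0 return [1.0]


x = math.tanh(0.8)  # squash an arbitrary input into (-1, 1)
values = chebyshev_basis(x, D=5)
print(len(values))  # D + 1 = 6 values per input
```

Stacking these D+1 values for every input gives the documented (batch, n_in, D+1) array, which the layer then combines with the trainable c_basis coefficients.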
update_grid(x, D_new)

ChebyKANs have no grid to update. Instead, the layer can be refined progressively by increasing the degree of the polynomials.

Parameters:
  • x (jnp.array) – Inputs, shape (batch, n_in).

  • D_new (int) – New Chebyshev polynomial degree.

Example

>>> import jax
>>> from jaxkan.layers.Chebyshev import ChebyshevLayer
>>>
>>> layer = ChebyshevLayer(n_in = 2, n_out = 5, D = 5, flavor = "default",
...                        residual = None, external_weights = False, init_scheme = None,
...                        add_bias = True, seed = 42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> layer.update_grid(x=x_batch, D_new=8)
__call__(x)

The layer’s forward pass.

Parameters:

x (jnp.array) – Inputs, shape (batch, n_in).

Returns:

Output of the forward pass, shape (batch, n_out).

Return type:

y (jnp.array)

Example

>>> import jax
>>> from jaxkan.layers.Chebyshev import ChebyshevLayer
>>>
>>> layer = ChebyshevLayer(n_in = 2, n_out = 5, D = 5, flavor = "default",
...                        residual = None, external_weights = False, init_scheme = None,
...                        add_bias = True, seed = 42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> output = layer(x_batch)
class jaxkan.layers.Legendre.LegendreLayer(n_in=2, n_out=5, D=5, flavor=None, residual=None, external_weights=False, init_scheme=None, add_bias=True, seed=42)

Bases: Module

LegendreLayer class. Corresponds to the Legendre version of KANs and comes in two “flavors”:

  • “default”: uses the recursion formula to calculate polynomials up to arbitrary degree.

  • “exact”: uses pre-defined functions for higher efficiency, but cannot scale up to arbitrary degrees.

n_in

Number of layer’s incoming nodes.

Type:

int

n_out

Number of layer’s outgoing nodes.

Type:

int

D

Degree of Legendre polynomial.

Type:

int

flavor

One of “default” or “exact” - chooses basis implementation.

Type:

str

residual

Function that is applied on samples to calculate residual activation.

Type:

Union[nnx.Module, None]

rngs

Random number generator state.

Type:

nnx.Rngs

bias

Bias parameter if add_bias is True, else None.

Type:

Union[nnx.Param, None]

c_ext

External weights if external_weights is True, else None.

Type:

Union[nnx.Param, None]

c_basis

Trainable coefficients for the basis functions.

Type:

nnx.Param

c_res

Trainable coefficients for residual activation if residual is not None.

Type:

Union[nnx.Param, None]

__init__(n_in=2, n_out=5, D=5, flavor=None, residual=None, external_weights=False, init_scheme=None, add_bias=True, seed=42)

Initializes a LegendreLayer instance.

Parameters:
  • n_in (int) – Number of layer’s incoming nodes.

  • n_out (int) – Number of layer’s outgoing nodes.

  • D (int) – Degree of Legendre polynomial.

  • flavor (Union[str, None]) – One of “default” or “exact” - chooses basis implementation.

  • residual (Union[nnx.Module, None]) – Function that is applied on samples to calculate residual activation.

  • external_weights (bool) – Boolean that controls if the trainable weights (n_out, n_in) should be applied to the activations.

  • init_scheme (Union[dict, None]) – Dictionary that defines how the trainable parameters of the layer are initialized.

  • add_bias (bool) – Whether bias terms are included in the forward pass.

  • seed (int) – Seed for the random keys used wherever initialization requires randomness.

Example

>>> from jaxkan.layers.Legendre import LegendreLayer
>>>
>>> layer = LegendreLayer(n_in = 2, n_out = 5, D = 5, flavor = "default",
...                       residual = None, external_weights = False, init_scheme = None,
...                       add_bias = True, seed = 42)
basis(x)

Based on the degree and flavor, the values of the Legendre basis functions are calculated on the input.

Parameters:

x (jnp.array) – Inputs, shape (batch, n_in).

Returns:

Legendre basis functions applied on inputs, shape (batch, n_in, D+1).

Return type:

leg (jnp.array)

Example

>>> import jax
>>> from jaxkan.layers.Legendre import LegendreLayer
>>>
>>> layer = LegendreLayer(n_in = 2, n_out = 5, D = 5, flavor = "default",
...                       residual = None, external_weights = False, init_scheme = None,
...                       add_bias = True, seed = 42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> output = layer.basis(x_batch)
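A standard recursion for Legendre polynomials is Bonnet's recurrence: (n+1)·P_{n+1}(x) = (2n+1)·x·P_n(x) - n·P_{n-1}(x), with P_0 = 1 and P_1 = x. Below is a minimal pure-Python sketch for one scalar input, assuming the "default" flavor uses an equivalent recurrence. `legendre_basis` is a hypothetical helper, not the library's code.

```python
def legendre_basis(x, D):
    """[P_0(x), ..., P_D(x)] via (n + 1) P_{n+1} = (2n + 1) x P_n - n P_{n-1}."""
    P = [1.0, x]
    for n in range(1, D):
        P.append(((2 * n + 1) * x * P[n] - n * P[n - 1]) / (n + 1))
    return P[: D + 1]  # slicing also makes D = 0 return [1.0]


values = legendre_basis(0.5, D=5)
print(len(values))  # D + 1 = 6 values per input
```

Every P_n satisfies P_n(1) = 1, which makes a convenient sanity check. The low-order closed forms, P_2(x) = (3x² - 1)/2 and P_3(x) = (5x³ - 3x)/2, give another.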
update_grid(x, D_new)

LegendreKANs have no grid to update. Instead, the layer can be refined progressively by increasing the degree of the polynomials.

Parameters:
  • x (jnp.array) – Inputs, shape (batch, n_in).

  • D_new (int) – New Legendre polynomial degree.

Example

>>> import jax
>>> from jaxkan.layers.Legendre import LegendreLayer
>>>
>>> layer = LegendreLayer(n_in = 2, n_out = 5, D = 5, flavor = "default",
...                       residual = None, external_weights = False, init_scheme = None,
...                       add_bias = True, seed = 42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> layer.update_grid(x=x_batch, D_new=8)
__call__(x)

The layer’s forward pass.

Parameters:

x (jnp.array) – Inputs, shape (batch, n_in).

Returns:

Output of the forward pass, shape (batch, n_out).

Return type:

y (jnp.array)

Example

>>> import jax
>>> from jaxkan.layers.Legendre import LegendreLayer
>>>
>>> layer = LegendreLayer(n_in = 2, n_out = 5, D = 5, flavor = "default",
...                       residual = None, external_weights = False, init_scheme = None,
...                       add_bias = True, seed = 42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> output = layer(x_batch)
class jaxkan.layers.Fourier.FourierLayer(n_in=2, n_out=5, D=5, smooth_init=True, add_bias=True, seed=42)

Bases: Module

FourierLayer class. Corresponds to the Fourier-based version of KANs (FourierKAN). Ref: https://github.com/GistNoesis/FourierKAN

n_in

Number of layer’s incoming nodes.

Type:

int

n_out

Number of layer’s outgoing nodes.

Type:

int

D

Order of Fourier sum.

Type:

int

bias

Bias parameter if add_bias is True, else None.

Type:

Union[nnx.Param, None]

c_cos

Trainable cosine coefficients.

Type:

nnx.Param

c_sin

Trainable sine coefficients.

Type:

nnx.Param

__init__(n_in=2, n_out=5, D=5, smooth_init=True, add_bias=True, seed=42)

Initializes a FourierLayer instance.

Parameters:
  • n_in (int) – Number of layer’s incoming nodes.

  • n_out (int) – Number of layer’s outgoing nodes.

  • D (int) – Order of Fourier sum.

  • smooth_init (bool) – Whether to use a smoothed initialization of the Fourier coefficients.

  • add_bias (bool) – Whether bias terms are included in the forward pass.

  • seed (int) – Seed for the random keys used wherever initialization requires randomness.

Example

>>> from jaxkan.layers.Fourier import FourierLayer
>>>
>>> layer = FourierLayer(n_in = 2, n_out = 5, D = 5, smooth_init = True, add_bias = True, seed = 42)
basis(x)

Calculates the cosine and sine activations of the input x.

Parameters:

x (jnp.array) – Inputs, shape (batch, n_in).

Returns:

Cosines and sines of the inputs, each of shape (batch, n_in, D).

Return type:

c, s (tuple)

Example

>>> import jax
>>> from jaxkan.layers.Fourier import FourierLayer
>>>
>>> layer = FourierLayer(n_in = 2, n_out = 5, D = 5, smooth_init = True, add_bias = True, seed = 42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> output_1, output_2 = layer.basis(x_batch)
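The cosine/sine features and the sum they feed into can be sketched in plain Python. The sketch assumes the FourierKAN convention of integer frequencies 1, ..., D. Under that convention, each edge i → j contributes Σ_k (a_{jik}·cos(k·x_i) + b_{jik}·sin(k·x_i)) to output j. `fourier_basis`, `fourier_forward` and the nested-list coefficient layout, which stands in for c_cos and c_sin, are hypothetical.

```python
import math


def fourier_basis(x, D):
    """cos(k x) and sin(k x) for frequencies k = 1, ..., D (one scalar input)."""
    c = [math.cos(k * x) for k in range(1, D + 1)]
    s = [math.sin(k * x) for k in range(1, D + 1)]
    return c, s


def fourier_forward(x, a, b, bias):
    """y_j = bias_j + sum_i sum_k (a[j][i][k] cos((k+1) x_i) + b[j][i][k] sin((k+1) x_i))."""
    D = len(a[0][0])
    feats = [fourier_basis(xi, D) for xi in x]  # computed once per input
    y = []
    for j in range(len(a)):
        total = bias[j]
        for i, (c, s) in enumerate(feats):
            total += sum(ac * cv for ac, cv in zip(a[j][i], c))
            total += sum(bs * sv for bs, sv in zip(b[j][i], s))
        y.append(total)
    return y


n_in, n_out, D = 2, 2, 3
a = [[[1.0] * D for _ in range(n_in)] for _ in range(n_out)]  # cosine coefficients
b = [[[2.0] * D for _ in range(n_in)] for _ in range(n_out)]  # sine coefficients
y = fourier_forward([0.0, 0.0], a, b, bias=[0.5, -0.5])
```

At x = 0 every cosine equals 1 and every sine equals 0, so each output reduces to its bias plus the sum of its cosine coefficients.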
update_grid(x, D_new)

FourierKANs have no grid to update. Instead, the layer can be refined progressively by increasing the number of summands.

Parameters:
  • x (jnp.array) – Inputs, shape (batch, n_in).

  • D_new (int) – New order of the Fourier sum.

Example

>>> import jax
>>> from jaxkan.layers.Fourier import FourierLayer
>>>
>>> layer = FourierLayer(n_in = 2, n_out = 5, D = 5, smooth_init = True, add_bias = True, seed = 42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> layer.update_grid(x=x_batch, D_new=8)
__call__(x)

The layer’s forward pass.

Parameters:

x (jnp.array) – Inputs, shape (batch, n_in).

Returns:

Output of the forward pass, shape (batch, n_out).

Return type:

y (jnp.array)

Example

>>> import jax
>>> from jaxkan.layers.Fourier import FourierLayer
>>>
>>> layer = FourierLayer(n_in = 2, n_out = 5, D = 5, smooth_init = True, add_bias = True, seed = 42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> output = layer(x_batch)
class jaxkan.layers.RBF.RBFLayer(n_in=2, n_out=5, D=5, kernel=None, grid_range=(-2.0, 2.0), grid_e=1.0, residual=None, external_weights=False, init_scheme=None, add_bias=True, seed=42)

Bases: Module

RBFLayer class. Corresponds to the radial basis function (RBF) version of KANs, with support for different kernels.

n_in

Number of layer’s incoming nodes.

Type:

int

n_out

Number of layer’s outgoing nodes.

Type:

int

D

Number of basis functions.

Type:

int

kernel

Kernel configuration for the RBFs.

Type:

dict

residual

Function that is applied on samples to calculate residual activation.

Type:

Union[nnx.Module, None]

rngs

Random number generator state.

Type:

nnx.Rngs

grid

Grid object for RBF basis functions.

Type:

RBFGrid

bias

Bias parameter if add_bias is True, else None.

Type:

Union[nnx.Param, None]

c_ext

External weights if external_weights is True, else None.

Type:

Union[nnx.Param, None]

c_basis

Trainable coefficients for the basis functions.

Type:

nnx.Param

c_res

Trainable coefficients for residual activation if residual is not None.

Type:

Union[nnx.Param, None]

__init__(n_in=2, n_out=5, D=5, kernel=None, grid_range=(-2.0, 2.0), grid_e=1.0, residual=None, external_weights=False, init_scheme=None, add_bias=True, seed=42)

Initializes a RBFLayer instance.

Parameters:
  • n_in (int) – Number of layer’s incoming nodes.

  • n_out (int) – Number of layer’s outgoing nodes.

  • D (int) – Number of basis functions.

  • kernel (Union[dict, None]) – Kernel to be used for the RBFs.

  • grid_range (tuple) – An initial range for the grid’s ends, although adaptivity can completely change it.

  • grid_e (float) – Parameter that defines if the grids are uniform (grid_e = 1.0) or sample-dependent (grid_e = 0.0). Intermediate values correspond to a linear mixing of the two cases.

  • residual (Union[nnx.Module, None]) – Function that is applied on samples to calculate residual activation.

  • external_weights (bool) – Boolean that controls if the trainable weights (n_out, n_in) should be applied to the activations.

  • init_scheme (Union[dict, None]) – Dictionary that defines how the trainable parameters of the layer are initialized.

  • add_bias (bool) – Whether bias terms are included in the forward pass.

  • seed (int) – Seed for the random keys used wherever initialization requires randomness.

Example

>>> from jaxkan.layers.RBF import RBFLayer
>>>
>>> layer = RBFLayer(n_in = 2, n_out = 5, D = 4, kernel = "gaussian",
...                  grid_range = (-2.0, 2.0), grid_e = 1.0, residual = None,
...                  external_weights = False, init_scheme = None, add_bias = True,
...                  seed = 42)
basis(x)

Evaluates the radial basis functions on the input, based on the number of basis functions and the chosen kernel.

Parameters:

x (jnp.array) – Inputs, shape (batch, n_in).

Returns:

Radial basis functions applied on inputs, shape (batch, n_in, D).

Return type:

rbf (jnp.array)

Example

>>> import jax
>>> from jaxkan.layers.RBF import RBFLayer
>>>
>>> layer = RBFLayer(n_in = 2, n_out = 5, D = 4, kernel = "gaussian",
...                  grid_range = (-2.0, 2.0), grid_e = 1.0, residual = None,
...                  external_weights = False, init_scheme = None, add_bias = True,
...                  seed = 42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-2.0, maxval=2.0)
>>>
>>> output = layer.basis(x_batch)
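With a Gaussian kernel, each basis function is a bump exp(-((x - c_m)/h)²) centred on a grid point c_m. Below is a minimal pure-Python sketch. It assumes D centres spaced uniformly over grid_range and a shared width h equal to the centre spacing, which is a common choice. `rbf_centers` and `gaussian_rbf_basis` are hypothetical helpers. The actual kernel configuration and grid adaptation in RBFLayer may differ.

```python
import math


def rbf_centers(D, grid_range=(-2.0, 2.0)):
    """D centres spaced uniformly over grid_range (requires D >= 2)."""
    lo, hi = grid_range
    h = (hi - lo) / (D - 1)
    return [lo + m * h for m in range(D)]


def gaussian_rbf_basis(x, centers, width):
    """exp(-((x - c_m) / width)^2) for every centre c_m (one scalar input)."""
    return [math.exp(-(((x - c) / width) ** 2)) for c in centers]


D = 4
centers = rbf_centers(D)
width = centers[1] - centers[0]  # shared width = centre spacing (assumption)
values = gaussian_rbf_basis(centers[2], centers, width)
```

Each value lies in (0, 1] and peaks at exactly 1 when the input coincides with a centre, so the basis acts as a soft one-hot encoding of the input's position on the grid.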
update_grid(x, D_new)

Performs a grid update given a new value for D (i.e., D_new) and adapts it to the given data, x. Additionally, re-initializes the c_basis parameters to a better estimate, based on the new grid.

Parameters:
  • x (jnp.array) – Inputs, shape (batch, n_in).

  • D_new (int) – New number of radial basis functions.

Example

>>> import jax
>>> from jaxkan.layers.RBF import RBFLayer
>>>
>>> layer = RBFLayer(n_in = 2, n_out = 5, D = 4, kernel = "gaussian",
...                  grid_range = (-2.0, 2.0), grid_e = 1.0, residual = None,
...                  external_weights = False, init_scheme = None, add_bias = True,
...                  seed = 42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-2.0, maxval=2.0)
>>>
>>> layer.update_grid(x=x_batch, D_new=8)
__call__(x)

The layer’s forward pass.

Parameters:

x (jnp.array) – Inputs, shape (batch, n_in).

Returns:

Output of the forward pass, shape (batch, n_out).

Return type:

y (jnp.array)

Example

>>> import jax
>>> from jaxkan.layers.RBF import RBFLayer
>>>
>>> layer = RBFLayer(n_in = 2, n_out = 5, D = 4, kernel = "gaussian",
...                  grid_range = (-2.0, 2.0), grid_e = 1.0, residual = None,
...                  external_weights = False, init_scheme = None, add_bias = True,
...                  seed = 42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-2.0, maxval=2.0)
>>>
>>> output = layer(x_batch)
class jaxkan.layers.Sine.SineLayer(n_in=2, n_out=5, D=5, residual=None, external_weights=False, init_scheme=None, add_bias=True, seed=42)

Bases: Module

SineLayer class, inspired by the sine basis functions introduced in https://arxiv.org/pdf/2410.01990

n_in

Number of layer’s incoming nodes.

Type:

int

n_out

Number of layer’s outgoing nodes.

Type:

int

D

Number of basis functions.

Type:

int

residual

Function that is applied on samples to calculate residual activation.

Type:

Union[nnx.Module, None]

rngs

Random number generator state.

Type:

nnx.Rngs

bias

Bias parameter if add_bias is True, else None.

Type:

Union[nnx.Param, None]

omega

Trainable frequency parameters.

Type:

nnx.Param

phase

Trainable phase parameters.

Type:

nnx.Param

c_ext

External weights if external_weights is True, else None.

Type:

Union[nnx.Param, None]

c_basis

Trainable coefficients for the basis functions.

Type:

nnx.Param

c_res

Trainable coefficients for residual activation if residual is not None.

Type:

Union[nnx.Param, None]

__init__(n_in=2, n_out=5, D=5, residual=None, external_weights=False, init_scheme=None, add_bias=True, seed=42)

Initializes a SineLayer instance.

Parameters:
  • n_in (int) – Number of layer’s incoming nodes.

  • n_out (int) – Number of layer’s outgoing nodes.

  • D (int) – Number of basis functions.

  • residual (Union[nnx.Module, None]) – Function that is applied on samples to calculate residual activation.

  • external_weights (bool) – Boolean that controls if the trainable weights (n_out, n_in) should be applied to the activations.

  • init_scheme (Union[dict, None]) – Dictionary that defines how the trainable parameters of the layer are initialized.

  • add_bias (bool) – Whether bias terms are included in the forward pass.

  • seed (int) – Seed for the random keys used wherever initialization requires randomness.

Example

>>> from jaxkan.layers.Sine import SineLayer
>>>
>>> layer = SineLayer(n_in = 2, n_out = 5, D = 5, residual = None, external_weights = False,
...                   init_scheme = None, add_bias = True, seed = 42)
basis(x)

Evaluates the sine basis functions on the input.

Parameters:

x (jnp.array) – Inputs, shape (batch, n_in).

Returns:

Sine basis functions applied on inputs, shape (batch, n_in, D).

Return type:

B (jnp.array)

Example

>>> import jax
>>> from jaxkan.layers.Sine import SineLayer
>>>
>>> layer = SineLayer(n_in = 2, n_out = 5, D = 5, residual = None, external_weights = False,
...                   init_scheme = None, add_bias = True, seed = 42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> output = layer.basis(x_batch)
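Each sine basis function can be viewed as sin(ω_k·x + φ_k), with a trainable frequency ω_k (omega) and phase φ_k (phase). An edge activation is then a weighted sum of D such terms. Below is a minimal pure-Python sketch for one scalar input that assumes this form. It does not show how SineLayer initializes omega and phase or shares them across inputs. `sine_basis` and `sine_edge` are hypothetical helpers.

```python
import math


def sine_basis(x, omega, phase):
    """sin(omega_k * x + phase_k) for k = 1, ..., D (one scalar input)."""
    return [math.sin(w * x + p) for w, p in zip(omega, phase)]


def sine_edge(x, omega, phase, coeffs):
    """Weighted sum of the D sine basis values: one edge activation."""
    return sum(c * b for c, b in zip(coeffs, sine_basis(x, omega, phase)))


omega = [1.0, 2.0, 3.0]
phase = [0.0, math.pi / 2, 0.0]
values = sine_basis(0.0, omega, phase)  # [sin 0, sin(pi/2), sin 0]
```

Because frequencies and phases are trainable, the basis functions themselves adapt during training, unlike the fixed polynomial bases above.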
update_grid(x, D_new)

Sine-based KANs have no grid to update. Instead, the layer can be refined progressively by increasing the number of basis functions and, with them, the number of trainable phases and frequencies (omega).

Parameters:
  • x (jnp.array) – Inputs, shape (batch, n_in).

  • D_new (int) – New number of basis functions.

Example

>>> import jax
>>> from jaxkan.layers.Sine import SineLayer
>>>
>>> layer = SineLayer(n_in = 2, n_out = 5, D = 5, residual = None, external_weights = False,
...                   init_scheme = None, add_bias = True, seed = 42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> layer.update_grid(x=x_batch, D_new=8)
__call__(x)

The layer’s forward pass.

Parameters:

x (jnp.array) – Inputs, shape (batch, n_in).

Returns:

Output of the forward pass, shape (batch, n_out).

Return type:

y (jnp.array)

Example

>>> import jax
>>> from jaxkan.layers.Sine import SineLayer
>>>
>>> layer = SineLayer(n_in = 2, n_out = 5, D = 5, residual = None, external_weights = False,
...                   init_scheme = None, add_bias = True, seed = 42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> output = layer(x_batch)
class jaxkan.layers.Dense.DenseLayer(n_in, n_out, activation=None, RWF={'mean': 1.0, 'std': 0.1}, add_bias=True, seed=42)

Bases: Module

Dense layer with random weight factorization (RWF) for use in MLP architectures.

Note: This is not a KAN layer, but a standard MLP building block used in advanced KAN architectures like KKAN (see jaxkan.models module).

g

Scale factor vector of shape (n_out,) from the RWF reparameterization.

Type:

nnx.Param

v

Direction matrix of shape (n_in, n_out) from the RWF reparameterization.

Type:

nnx.Param

b

Bias vector of shape (n_out,), or None if add_bias is False.

Type:

nnx.Param or None

activation

Activation function applied after the linear transformation, or None.

Type:

callable or None

__init__(n_in, n_out, activation=None, RWF={'mean': 1.0, 'std': 0.1}, add_bias=True, seed=42)

Initializes a Dense layer with RWF.

Parameters:
  • n_in (int) – Number of input features.

  • n_out (int) – Number of output features.

  • activation (callable, optional) – Activation function applied after the linear transformation. Defaults to None.

  • RWF (dict, optional) – Dictionary with keys 'mean' and 'std' controlling the log-normal scale of the RWF reparameterization. Defaults to {"mean": 1.0, "std": 0.1}.

  • add_bias (bool, optional) – Whether to include a learnable bias term. Defaults to True.

  • seed (int, optional) – Random seed for parameter initialization. Defaults to 42.

Example

>>> layer = DenseLayer(n_in=64, n_out=32, add_bias=True, seed=42)
__call__(x)[source]

Applies the dense layer to the input.

Parameters:

x (jnp.ndarray) – Input array of shape (batch, n_in).

Returns:

Output array of shape (batch, n_out).

Return type:

jnp.ndarray

Example

>>> layer = DenseLayer(n_in=4, n_out=2)
>>> x = jnp.ones((3, 4))
>>> y = layer(x)  # shape: (3, 2)
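The RWF reparameterization can be sketched as follows, assuming the usual random weight factorization scheme: a Glorot-initialized kernel W is rewritten as v · diag(g), where g = exp(s) and s is drawn from N(mean, std). The helpers init_rwf_dense and rwf_dense are hypothetical, and whether DenseLayer stores g directly or its logarithm is an assumption not confirmed by this reference.

```python
import jax
import jax.numpy as jnp

def init_rwf_dense(key, n_in, n_out, mean=1.0, std=0.1):
    """Hypothetical RWF initializer returning (g, v, b)."""
    k_w, k_s = jax.random.split(key)
    # Glorot-normal base kernel W, shape (n_in, n_out)
    w = jax.nn.initializers.glorot_normal()(k_w, (n_in, n_out))
    # Log-normal scale per output unit: g = exp(s) with s ~ N(mean, std)
    s = mean + std * jax.random.normal(k_s, (n_out,))
    g = jnp.exp(s)
    # Direction matrix chosen so that v * g equals W at initialization
    v = w / g
    b = jnp.zeros((n_out,))
    return g, v, b

def rwf_dense(params, x, activation=None):
    g, v, b = params
    # Effective kernel: column j of v rescaled by g[j]
    y = x @ (v * g) + b
    return y if activation is None else activation(y)

params = init_rwf_dense(jax.random.key(0), n_in=4, n_out=2)
y = rwf_dense(params, jnp.ones((3, 4)), activation=jnp.tanh)  # shape (3, 2)
```

Because v · diag(g) reproduces W exactly at initialization, the factorization changes only the training dynamics, not the initial function.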
jaxkan.layers.utils.solve_single_lstsq(A_single, B_single)[source]

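Emulates linalg.lstsq by solving AX = B through the normal equations (A^T A) X = A^T B. This is used instead of linalg.lstsq because it is much faster. The trade-off is that forming A^T A squares the condition number of A, so the solution can be less accurate when A is ill-conditioned.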

Parameters:
  • A_single (jnp.array) – Matrix A of AX = B, shape (M, N).

  • B_single (jnp.array) – Matrix B of AX = B, shape (M, K).

Returns:

Matrix X of AX = B, shape (N, K).

Return type:

single_solution (jnp.array)

Example

>>> A = jnp.array([[2.0, 1.0], [1.0, 3.0]])
>>> B = jnp.array([[1.0], [2.0]])
>>>
>>> solution = solve_single_lstsq(A, B)
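A minimal sketch of the normal-equations approach, written independently of jaxkan (the internals of solve_single_lstsq may differ). It compares the result against jnp.linalg.lstsq on a small overdetermined system.

```python
import jax.numpy as jnp

def normal_eq_lstsq(A, B):
    """Least squares for AX = B via (A^T A) X = A^T B. A: (M, N), B: (M, K)."""
    AtA = A.T @ A   # (N, N)
    AtB = A.T @ B   # (N, K)
    # Solving the small N x N system avoids an SVD, but squares cond(A)
    return jnp.linalg.solve(AtA, AtB)

A = jnp.array([[2.0, 1.0], [1.0, 3.0], [0.0, 1.0]])  # overdetermined: M=3 > N=2
B = jnp.array([[1.0], [2.0], [0.5]])
X = normal_eq_lstsq(A, B)                 # shape (2, 1)
X_ref, *_ = jnp.linalg.lstsq(A, B)        # SVD-based reference solution
```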
jaxkan.layers.utils.solve_full_lstsq(A_full, B_full)[source]

Applies the single case across a leading batch dimension, so that the problem can be solved for arrays with more than two dimensions. Together, solve_single_lstsq and solve_full_lstsq are a workaround for the fact that JAX's lstsq (unlike PyTorch's, for example) does not support inputs with more than two dimensions.

Parameters:
  • A_full (jnp.array) – Matrix A of AX = B, shape (batch, M, N).

  • B_full (jnp.array) – Matrix B of AX = B, shape (batch, M, K).

Returns:

Matrix X of AX = B, shape (batch, N, K).

Return type:

full_solution (jnp.array)

Example

>>> A = jnp.array([[[2.0, 1.0], [1.0, 3.0]], [[1.0, 2.0], [2.0, 1.0]]])
>>> B = jnp.array([[[1.0], [2.0]], [[2.0], [3.0]]])
>>>
>>> solution = solve_full_lstsq(A, B)
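One way to batch the single-system solver is jax.vmap over the leading axis. This is a sketch assuming solve_full_lstsq works along these lines; the function names here are illustrative, not jaxkan's.

```python
import jax
import jax.numpy as jnp

def single_solve(A, B):
    # Normal equations for one system: A (M, N), B (M, K) -> X (N, K)
    return jnp.linalg.solve(A.T @ A, A.T @ B)

# Map over axis 0 of both inputs: (batch, M, N), (batch, M, K) -> (batch, N, K)
full_solve = jax.vmap(single_solve, in_axes=(0, 0))

A = jnp.array([[[2.0, 1.0], [1.0, 3.0]], [[1.0, 2.0], [2.0, 1.0]]])
B = jnp.array([[[1.0], [2.0]], [[2.0], [3.0]]])
X = full_solve(A, B)  # shape (2, 2, 1)
```

Since both example matrices are square and invertible, each batch entry recovers the exact solution of AX = B.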
jaxkan.layers.utils.interpolate_moments(mu_old, nu_old, new_shape)[source]

Uses linear interpolation to assign values to the first- and second-order moments of the gradients of the basis-function coefficients c_i after grid extension, when the number of basis functions changes.

Parameters:
  • mu_old (jnp.array) – First-order moments before extension, shape (n_in*n_out, num_basis) or (n_out, n_in, num_basis).

  • nu_old (jnp.array) – Second-order moments before extension, shape (n_in*n_out, num_basis) or (n_out, n_in, num_basis).

  • new_shape (tuple) – The new desired shape, either (n_in*n_out, new_num_basis) or (n_out, n_in, new_num_basis).

Returns:

First- and second-order moments after extension, shape new_shape.

Return type:

mu_new, nu_new (tuple)

Example

>>> mu_old = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
>>> nu_old = jnp.array([[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]])
>>> new_shape = (2, 5)
>>>
>>> mu_new, nu_new = interpolate_moments(mu_old, nu_old, new_shape)
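A sketch of such an interpolation, under the assumption that each row of moments is treated as samples at evenly spaced positions on [0, 1] and resampled onto new_num_basis points with jnp.interp. The helper names are hypothetical, and jaxkan's exact resampling positions may differ. It handles both the 2-D and 3-D layouts because only the last axis is resampled.

```python
import jax
import jax.numpy as jnp

def resample_last_axis(arr, new_len):
    """Linearly resample the last axis of arr to length new_len."""
    old_len = arr.shape[-1]
    old_pos = jnp.linspace(0.0, 1.0, old_len)
    new_pos = jnp.linspace(0.0, 1.0, new_len)
    flat = arr.reshape(-1, old_len)                          # (rows, old_len)
    out = jax.vmap(lambda row: jnp.interp(new_pos, old_pos, row))(flat)
    return out.reshape(arr.shape[:-1] + (new_len,))

def interpolate_moments_sketch(mu_old, nu_old, new_shape):
    new_len = new_shape[-1]
    return resample_last_axis(mu_old, new_len), resample_last_axis(nu_old, new_len)

mu_old = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
nu_old = jnp.array([[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]])
mu_new, nu_new = interpolate_moments_sketch(mu_old, nu_old, (2, 5))
```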
jaxkan.layers.utils.adam_transition(old_state, model_state)[source]

Performs the state transition for the Adam optimizer (with scheduler) after grid extension. The transition happens in place: nothing is returned, and the optimizer is moved from its old state to the new one.

Parameters:
  • old_state (tuple) – Collection of Adam state and scheduler state before extension.

  • model_state (dict) – State of the KAN model after grid extension, as obtained from nnx.split(model).

Example

>>> old_state = optimizer.opt_state
>>> _, model_state = nnx.split(model)
>>> adam_transition(old_state, model_state)

KAN Grid Classes

class jaxkan.grids.BaseGrid.BaseGrid(n_in=2, n_out=5, k=3, G=3, grid_range=(-1, 1), grid_e=0.05)[source]

Bases: object

BaseGrid class, corresponding to the grid of the BaseLayer class. It comprises an initialization as well as an update procedure.

n_in

Number of layer’s incoming nodes.

Type:

int

n_out

Number of layer’s outgoing nodes.

Type:

int

k

Order of the spline basis functions.

Type:

int

G

Number of grid intervals.

Type:

int

grid_range

An initial range for the grid’s ends, although adaptivity can completely change it.

Type:

tuple

grid_e

Parameter that defines if the grids are uniform (grid_e = 1.0) or sample-dependent (grid_e = 0.0). Intermediate values correspond to a linear mixing of the two cases.

Type:

float

item

The actual grid array, shape (n_in*n_out, G + 2k + 1).

Type:

jnp.array

__init__(n_in=2, n_out=5, k=3, G=3, grid_range=(-1, 1), grid_e=0.05)[source]

Initializes a BaseGrid instance.

Parameters:
  • n_in (int) – Number of layer’s incoming nodes.

  • n_out (int) – Number of layer’s outgoing nodes.

  • k (int) – Order of the spline basis functions.

  • G (int) – Number of grid intervals.

  • grid_range (tuple) – An initial range for the grid’s ends, although adaptivity can completely change it.

  • grid_e (float) – Parameter that defines if the grids are uniform (grid_e = 1.0) or sample-dependent (grid_e = 0.0). Intermediate values correspond to a linear mixing of the two cases.

Example

>>> grid_type = BaseGrid(n_in = 2, n_out = 5, k = 3, G = 3, grid_range = (-1,1), grid_e = 0.05)
>>> grid = grid_type.item
update(x, G_new)[source]

Update the grid based on input data and new grid size.

Parameters:
  • x (jnp.ndarray) – Input data, shape (batch, n_in).

  • G_new (int) – New grid size in terms of intervals.

Example

>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> grid = BaseGrid(n_in = 2, n_out = 5, k = 3, G = 3, grid_range = (-1,1), grid_e = 0.05)
>>> grid.update(x=x_batch, G_new=5)
class jaxkan.grids.SplineGrid.SplineGrid(n_nodes=2, k=3, G=3, grid_range=(-1, 1), grid_e=0.05)[source]

Bases: object

SplineGrid class, corresponding to the grid of the SplineLayer class. It comprises an initialization as well as an update procedure.

n_nodes

Number of layer nodes.

Type:

int

k

Order of the spline basis functions.

Type:

int

G

Number of grid intervals.

Type:

int

grid_range

An initial range for the grid’s ends, although adaptivity can completely change it.

Type:

tuple

grid_e

Parameter that defines if the grids are uniform (grid_e = 1.0) or sample-dependent (grid_e = 0.0). Intermediate values correspond to a linear mixing of the two cases.

Type:

float

item

The actual grid array, shape (n_nodes, G + 2k + 1).

Type:

jnp.array

__init__(n_nodes=2, k=3, G=3, grid_range=(-1, 1), grid_e=0.05)[source]

Initializes a SplineGrid instance.

Parameters:
  • n_nodes (int) – Number of layer nodes.

  • k (int) – Order of the spline basis functions.

  • G (int) – Number of grid intervals.

  • grid_range (tuple) – An initial range for the grid’s ends, although adaptivity can completely change it.

  • grid_e (float) – Parameter that defines if the grids are uniform (grid_e = 1.0) or sample-dependent (grid_e = 0.0). Intermediate values correspond to a linear mixing of the two cases.

Example

>>> grid_type = SplineGrid(n_nodes = 2, k = 3, G = 3, grid_range = (-1,1), grid_e = 0.05)
>>> grid = grid_type.item
update(x, G_new)[source]

Update the grid based on input data and new grid size.

Parameters:
  • x (jnp.ndarray) – Input data, shape (batch, n_nodes).

  • G_new (int) – New grid size in terms of intervals.

Example

>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> grid = SplineGrid(n_nodes = 2, k = 3, G = 3, grid_range = (-1,1), grid_e = 0.05)
>>> grid.update(x=x_batch, G_new=5)
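The grid_e mixing and the (n_nodes, G + 2k + 1) knot layout can be illustrated with a sketch modeled on the original KAN grid update. It assumes the sample-dependent grid takes G_new + 1 points at evenly spaced ranks of the sorted data and that the grid is extended by k knots on each side. The function name is hypothetical and this is not jaxkan's code.

```python
import jax
import jax.numpy as jnp

def spline_grid_update_sketch(x, G_new, k=3, grid_e=0.05):
    """x: (batch, n_nodes) -> knots of shape (n_nodes, G_new + 2k + 1)."""
    batch = x.shape[0]
    x_sorted = jnp.sort(x, axis=0)                               # (batch, n_nodes)
    # Sample-dependent grid: G_new + 1 points at evenly spaced ranks
    idx = jnp.round(jnp.linspace(0, batch - 1, G_new + 1)).astype(jnp.int32)
    grid_adaptive = x_sorted[idx].T                              # (n_nodes, G_new + 1)
    # Uniform grid over the same per-node [min, max]
    lo, hi = x_sorted[0], x_sorted[-1]
    h = (hi - lo) / G_new                                        # (n_nodes,)
    grid_uniform = lo[:, None] + jnp.arange(G_new + 1)[None, :] * h[:, None]
    # grid_e = 1.0 -> uniform, grid_e = 0.0 -> sample-dependent
    grid = grid_e * grid_uniform + (1.0 - grid_e) * grid_adaptive
    # Extend by k knots on each side with spacing h, as order-k B-splines need
    left = grid[:, :1] - h[:, None] * jnp.arange(k, 0, -1)[None, :]
    right = grid[:, -1:] + h[:, None] * jnp.arange(1, k + 1)[None, :]
    return jnp.concatenate([left, grid, right], axis=1)

x_batch = jax.random.uniform(jax.random.key(42), (100, 2), minval=-1.0, maxval=1.0)
knots = spline_grid_update_sketch(x_batch, G_new=5, k=3, grid_e=0.05)  # (2, 12)
```

The convex mix keeps the interior knots between the data minimum and maximum, while the k extra knots on each side give every interval a full set of order-k basis functions.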
class jaxkan.grids.RBFGrid.RBFGrid(n_nodes=2, D=3, grid_range=(-2.0, 2.0), grid_e=1.0)[source]

Bases: object

RBFGrid class, corresponding to the grid of the RBFLayer class. It comprises an initialization as well as an update procedure.

n_nodes

Number of layer nodes.

Type:

int

D

Number of radial basis functions.

Type:

int

grid_range

The range of the grid’s ends, on which the basis functions are defined.

Type:

tuple

grid_e

Parameter that defines if the grid is uniform (grid_e = 1.0) or sample-dependent (grid_e = 0.0). Intermediate values correspond to a linear mixing of the two cases.

Type:

float

item

The actual grid array, shape (n_nodes, D).

Type:

jnp.array

__init__(n_nodes=2, D=3, grid_range=(-2.0, 2.0), grid_e=1.0)[source]

Initializes a RBFGrid instance.

Parameters:
  • n_nodes (int) – Number of layer nodes.

  • D (int) – Number of radial basis functions.

  • grid_range (tuple) – The range of the grid’s ends, on which the basis functions are defined.

  • grid_e (float) – Parameter that defines if the grid is uniform (grid_e = 1.0) or sample-dependent (grid_e = 0.0). Intermediate values correspond to a linear mixing of the two cases.

Example

>>> grid_type = RBFGrid(n_nodes = 2, D = 4, grid_range = (-2.0, 2.0), grid_e = 1.0)
>>> grid = grid_type.item
update(x, D_new)[source]

Update the grid based on input data and new grid size.

Parameters:
  • x (jnp.ndarray) – Input data, shape (batch, n_nodes).

  • D_new (int) – New number of basis functions.

Example

>>> key = jax.random.key(42)
>>> grid_range = (-2.0, 2.0)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=grid_range[0], maxval=grid_range[-1])
>>>
>>> grid = RBFGrid(n_nodes = 2, D = 4, grid_range = grid_range, grid_e = 1.0)
>>> grid.update(x=x_batch, D_new=8)

KAN Models

class jaxkan.models.KAN.KAN(layer_dims, layer_type='base', required_parameters=None, seed=42)[source]

Bases: Module

KAN class, corresponding to a network of KAN Layers.

layers

List of KAN layer instances.

Type:

nnx.List

__init__(layer_dims, layer_type='base', required_parameters=None, seed=42)[source]

Initializes a KAN model.

Parameters:
  • layer_dims (List[int]) – Defines the network in terms of nodes. E.g. [4,5,1] is a network with 2 layers: one with n_in=4 and n_out=5 and one with n_in=5 and n_out = 1.

  • layer_type (str) – Type of layer to use (e.g., ‘base’).

  • required_parameters (dict) – Dictionary containing parameters required for the chosen layer type.

  • seed (int) – Random seed used for initializations wherever necessary.

Example

>>> req_params = {'k': 3, 'G': 5}
>>> model = KAN(layer_dims = [2,5,1], layer_type='base', required_parameters=req_params, seed=42)
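The layer_dims convention pairs consecutive entries into per-layer (n_in, n_out) shapes. A short illustration follows. That KAN builds exactly one layer per pair is inferred from the layer_dims description above.

```python
layer_dims = [4, 5, 1]
# Consecutive entries give the (n_in, n_out) pair of each layer
layer_shapes = list(zip(layer_dims[:-1], layer_dims[1:]))
print(layer_shapes)  # [(4, 5), (5, 1)]
```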
update_grids(x, G_new)[source]

Performs the grid update for each layer of the KAN architecture.

Parameters:
  • x (jnp.array) – Inputs for the first layer, shape (batch, layer_dims[0]).

  • G_new (int) – Size of the new grid (in terms of intervals).

Example

>>> req_params = {'k': 3, 'G': 5}
>>> model = KAN(layer_dims = [2,5,1], layer_type='base', required_parameters=req_params, seed=42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> model.update_grids(x=x_batch, G_new=10)
__call__(x)[source]

Equivalent to the network’s forward pass.

Parameters:

x (jnp.array) – Inputs for the first layer, shape (batch, layer_dims[0]).

Returns:

Network output, shape (batch, layer_dims[-1]).

Return type:

x (jnp.array)

Example

>>> req_params = {'k': 3, 'G': 5}
>>> model = KAN(layer_dims = [2,5,1], layer_type='base', required_parameters=req_params, seed=42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> output = model(x_batch)
class jaxkan.models.RGAKAN.RGABlock(n_in, n_out, n_hidden, D=5, flavor='exact', init_scheme=None, alpha=0.0, beta=1.0, seed=42)[source]

Bases: Module

Residual-Gated Adaptive Block for RGAKAN architecture.

InputLayer

First Chebyshev layer in the block.

Type:

Layer

OutputLayer

Second Chebyshev layer in the block.

Type:

Layer

alpha

Trainable residual connection weight for the output.

Type:

nnx.Param

beta

Trainable residual connection weight for the hidden state.

Type:

nnx.Param

__init__(n_in, n_out, n_hidden, D=5, flavor='exact', init_scheme=None, alpha=0.0, beta=1.0, seed=42)[source]

Initializes an RGABlock.

Parameters:
  • n_in (int) – Input dimension.

  • n_out (int) – Output dimension.

  • n_hidden (int) – Hidden layer dimension.

  • D (int) – Degree of Chebyshev polynomials.

  • flavor (str) – Type of Chebyshev layer (‘exact’ or other variants).

  • init_scheme (dict, optional) – Initialization scheme for layer weights.

  • alpha (float) – Initial value for output residual connection weight.

  • beta (float) – Initial value for hidden residual connection weight.

  • seed (int) – Random seed for reproducible initialization.

Example

>>> block = RGABlock(n_in=64, n_out=64, n_hidden=64, D=5, flavor='exact', seed=42)
__call__(x, u, v)[source]

Forward pass through the RGA block.

Parameters:
  • x (jnp.array) – Input array, shape (batch, n_in).

  • u (jnp.array) – First gating signal, shape (batch, n_hidden).

  • v (jnp.array) – Second gating signal, shape (batch, n_hidden).

Returns:

Output after applying gated attention and residual connections, shape (batch, n_out).

Return type:

x (jnp.array)

Example

>>> block = RGABlock(n_in=64, n_out=64, n_hidden=64, D=5, seed=42)
>>> x = jnp.ones((32, 64))
>>> u = jnp.ones((32, 64))
>>> v = jnp.zeros((32, 64))
>>> output = block(x, u, v)  # Shape: (32, 64)
class jaxkan.models.RGAKAN.RGAKAN(n_in, n_out, n_hidden, num_blocks, flavor='exact', D=5, init_scheme=None, alpha=0.0, beta=1.0, ref=None, period_axes=None, rff_std=None, sine_D=None, seed=42)[source]

Bases: Module

Residual-Gated Adaptive Kolmogorov-Arnold Network (RGAKAN). See paper “Training Deep Physics-Informed Kolmogorov-Arnold Networks”. https://www.sciencedirect.com/science/article/pii/S0045782526000356

pi_init

Whether physics-informed initialization is enabled.

Type:

bool

n_hidden

Hidden layer dimension.

Type:

int

D

Degree of Chebyshev polynomials.

Type:

int

PE

Periodic embedder if period_axes is provided.

Type:

Union[PeriodEmbedder, None]

FE

Random Fourier Features embedder if rff_std is provided.

Type:

Union[RFFEmbedder, None]

SineBasis

Sine basis layer if sine_D is provided.

Type:

Union[SineLayer, None]

U

First gating network.

Type:

Layer

V

Second gating network.

Type:

Layer

blocks

List of RGABlock instances.

Type:

nnx.List

OutBasis

Physics-informed output coefficients if pi_init is True.

Type:

Union[nnx.Param, None]

OutLayer

Standard output layer if pi_init is False.

Type:

Union[Layer, None]

__init__(n_in, n_out, n_hidden, num_blocks, flavor='exact', D=5, init_scheme=None, alpha=0.0, beta=1.0, ref=None, period_axes=None, rff_std=None, sine_D=None, seed=42)[source]

Initializes an RGAKAN model.

Parameters:
  • n_in (int) – Input dimension (before any embeddings).

  • n_out (int) – Output dimension.

  • n_hidden (int) – Hidden layer dimension.

  • num_blocks (int) – Number of RGA blocks to stack.

  • flavor (str) – Type of Chebyshev layer (‘exact’ or other variants).

  • D (int) – Degree of Chebyshev polynomials.

  • init_scheme (dict, optional) – Initialization scheme for layer weights.

  • alpha (float) – Initial value for output residual connection weights in blocks.

  • beta (float) – Initial value for hidden residual connection weights in blocks.

  • ref (dict, optional) – Reference data for physics-informed initialization. Must contain ‘t’, ‘x’, and ‘usol’.

  • period_axes (dict, optional) – Dictionary for periodic embedding: {axis: (period, trainable)}.

  • rff_std (float, optional) – Standard deviation for Random Fourier Features embedding.

  • sine_D (int, optional) – Degree for sine basis layer.

  • seed (int) – Random seed for reproducible initialization.

Example

>>> # Standard RGAKAN
>>> model = RGAKAN(n_in=2, n_out=1, n_hidden=64, num_blocks=4, D=5, seed=42)
>>>
>>> # RGAKAN with periodic embedding
>>> period_axes = {0: (2.0 * jnp.pi, False)}
>>> model = RGAKAN(n_in=2, n_out=1, n_hidden=64, num_blocks=4,
...                period_axes=period_axes, seed=42)
__call__(x)[source]

Forward pass through the RGAKAN model.

Parameters:

x (jnp.array) – Input array, shape (batch, n_in).

Returns:

Model output, shape (batch, n_out).

Return type:

y (jnp.array)

Example

>>> model = RGAKAN(n_in=2, n_out=1, n_hidden=64, num_blocks=4, seed=42)
>>> x = jnp.ones((32, 2))
>>> y = model(x)  # Shape: (32, 1)
class jaxkan.models.ActNet.ActLayer(n_in=3, n_out=4, N=5, train_basis=True, seed=42)[source]

Bases: Module

ActLayer implementation based on: “Deep Learning Alternatives of the Kolmogorov Superposition Theorem” by Leonardo Ferreira Guilhoto and Paris Perdikaris (arXiv:2410.01990)

Forward pass: ActLayer(x) = S(Λ ⊙ (β @ B(x)))

Where:
  • B(x) is the basis expansion matrix with B(x)_{ij} = b_i(x_j), shape (N, d)
  • β ∈ R^{m × N} are the basis coefficients
  • Λ ∈ R^{m × d} are the mixing weights
  • S is the row-sum function
  • ⊙ is the Hadamard (element-wise) product

Here m = n_out and d = n_in.

The k-th output is: (ActLayer(x))_k = Σ_i λ_{ki} Σ_j β_{kj} b_j(x_i)

rngs

Random number generator state.

Type:

nnx.Rngs

beta

Basis coefficients, shape (n_out, N).

Type:

nnx.Param

Lambda

Mixing weights, shape (n_out, n_in).

Type:

nnx.Param

omega

Frequency parameters for basis functions (trainable or fixed).

Type:

Union[nnx.Param, jnp.array]

phase

Phase parameters for basis functions (trainable or fixed).

Type:

Union[nnx.Param, jnp.array]

__init__(n_in=3, n_out=4, N=5, train_basis=True, seed=42)[source]

Initializes an ActLayer instance.

Parameters:
  • n_in (int) – Number of layer’s incoming nodes.

  • n_out (int) – Number of layer’s outgoing nodes.

  • N (int) – Number of basis functions (paper recommends N=4).

  • train_basis (bool) – Whether the basis function parameters (omega and phase) are trainable.

  • seed (int) – Random seed used for initializations wherever necessary.

Example

>>> layer = ActLayer(n_in=2, n_out=5, N=4, train_basis=True, seed=42)
basis(x)[source]

Compute normalized sinusoidal basis functions.

Paper Equation 11: b_i(t) = (sin(ω_i * t + p_i) - μ(ω_i, p_i)) / σ(ω_i, p_i)

Where μ and σ are computed assuming x ~ N(0, 1):
  μ(ω, p) = exp(-ω²/2) * sin(p)
  σ(ω, p) = sqrt(1/2 - 1/2 * exp(-2ω²) * cos(2p) - μ²)

Parameters:

x (jnp.array) – Inputs, shape (batch, n_in).

Returns:

Basis expansion matrix, shape (batch, N, n_in), where B[b, j, i] = b_j(x_{b,i}).

Return type:

B (jnp.array)

Example

>>> layer = ActLayer(n_in=2, n_out=5, N=4, seed=42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> B = layer.basis(x_batch)
__call__(x)[source]

The layer’s forward pass.

Paper Equation 6: ActLayer(x) = S(Λ ⊙ (β @ B(x)))
Paper Equation 9: (ActLayer(x))_k = Σ_i λ_{ki} Σ_j β_{kj} b_j(x_i)

Parameters:

x (jnp.array) – Inputs, shape (batch, n_in).

Returns:

Output of the forward pass, shape (batch, n_out).

Return type:

y (jnp.array)

Example

>>> layer = ActLayer(n_in=2, n_out=5, N=4, seed=42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> output = layer(x_batch)
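The basis (Equation 11) and forward pass (Equations 6 and 9) can be sketched in a few lines of JAX. This is a functional sketch, not the ActLayer module. The fixed frequencies and zero phases below are hypothetical values: ω must be nonzero, since σ vanishes at ω = 0.

```python
import jax
import jax.numpy as jnp

def act_basis(x, omega, phase):
    """Normalized sinusoidal basis. x: (batch, n_in) -> (batch, N, n_in)."""
    mu = jnp.exp(-omega**2 / 2) * jnp.sin(phase)                      # (N,)
    sigma = jnp.sqrt(0.5 - 0.5 * jnp.exp(-2 * omega**2) * jnp.cos(2 * phase) - mu**2)
    s = jnp.sin(omega[None, :, None] * x[:, None, :] + phase[None, :, None])
    return (s - mu[None, :, None]) / sigma[None, :, None]

def act_layer(x, beta, Lam, omega, phase):
    """(ActLayer(x))_k = sum_i Lam[k, i] * sum_j beta[k, j] * b_j(x_i)."""
    B = act_basis(x, omega, phase)                  # (batch, N, n_in)
    inner = jnp.einsum('kj,bji->bki', beta, B)      # beta @ B(x): (batch, n_out, n_in)
    return jnp.sum(Lam[None] * inner, axis=-1)      # row sum of Λ ⊙ (β @ B): (batch, n_out)

N, n_in, n_out = 4, 2, 5
k1, k2, k3 = jax.random.split(jax.random.key(42), 3)
beta = jax.random.normal(k1, (n_out, N))
Lam = jax.random.normal(k2, (n_out, n_in))
omega = jnp.arange(1.0, N + 1.0)   # hypothetical fixed frequencies (must be nonzero)
phase = jnp.zeros((N,))            # hypothetical phases
x = jax.random.uniform(k3, (100, n_in), minval=-1.0, maxval=1.0)
y = act_layer(x, beta, Lam, omega, phase)   # shape (100, 5)
```

The μ/σ normalization makes each basis function have zero mean and unit variance when its input is standard normal. This is why the paper derives the formulas under x ~ N(0, 1).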
class jaxkan.models.ActNet.ActNet(layer_dims, N=4, add_bias=True, omega0=1.0, use_projections=False, train_basis=True, seed=42)[source]

Bases: Module

ActNet architecture based on: “Deep Learning Alternatives of the Kolmogorov Superposition Theorem” by Leonardo Ferreira Guilhoto and Paris Perdikaris (arXiv:2410.01990)

rngs

Random number generator state.

Type:

nnx.Rngs

add_bias

Whether to add learnable bias after each ActLayer.

Type:

bool

omega0

Frequency multiplier for input.

Type:

float

use_projections

Whether to use input/output linear projections.

Type:

bool

input_proj

Input projection layer if use_projections is True.

Type:

nnx.Linear, optional

layers

List of ActLayer instances.

Type:

nnx.List

output_proj

Output projection layer if use_projections is True.

Type:

nnx.Linear, optional

biases

List of bias parameters if add_bias is True.

Type:

nnx.List, optional

__init__(layer_dims, N=4, add_bias=True, omega0=1.0, use_projections=False, train_basis=True, seed=42)[source]

Initializes an ActNet model.

Parameters:
  • layer_dims (List[int]) – Defines the network in terms of nodes. E.g. [2,5,1] is a network with 2 layers.

  • N (int) – Number of basis functions per ActLayer (paper recommends N=4).

  • add_bias (bool) – Whether to add learnable bias after each ActLayer.

  • omega0 (float) – Frequency multiplier for input (paper’s Appendix D.1).

  • use_projections (bool) – Whether to use input/output linear projections.

  • train_basis (bool) – Whether the basis function parameters (omega and phase) are trainable.

  • seed (int) – Random seed used for initializations wherever necessary.

Example

>>> model = ActNet(layer_dims=[2, 5, 1], N=4, add_bias=True, train_basis=True, seed=42)
__call__(x)[source]

Equivalent to the network’s forward pass.

Parameters:

x (jnp.array) – Inputs for the first layer, shape (batch, layer_dims[0]).

Returns:

Network output, shape (batch, layer_dims[-1]).

Return type:

x (jnp.array)

Example

>>> model = ActNet(layer_dims=[2, 5, 1], N=4, add_bias=True, seed=42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> output = model(x_batch)
class jaxkan.models.KKAN.ChebyshevEmbedding(D_e)[source]

Bases: Module

Chebyshev polynomial embedding layer with trainable coefficients.

For an input x, computes: [c_0 * T_0(x), c_1 * T_1(x), …, c_{D_e} * T_{D_e}(x)] where T_n are Chebyshev polynomials of the first kind and c_n are trainable parameters.

D_e

Degree of Chebyshev polynomial expansion.

Type:

int

use_exact

Whether to use exact Chebyshev polynomials from Cb dictionary.

Type:

bool

C

Trainable coefficients for the Chebyshev polynomials.

Type:

nnx.Param

__init__(D_e)[source]

Initializes a ChebyshevEmbedding layer.

Parameters:

D_e (int) – Degree of Chebyshev polynomial expansion.

Example

>>> embedding = ChebyshevEmbedding(D_e=5)
__call__(x)[source]

Applies Chebyshev embedding to input.

Parameters:

x (jnp.array) – Input tensor, shape (batch, n_features) or (batch,).

Returns:

Chebyshev embedded tensor, shape (batch, n_features * (D_e + 1)).

Return type:

embedded (jnp.array)

Example

>>> embedding = ChebyshevEmbedding(D_e=5)
>>> x = jax.random.uniform(jax.random.key(0), (100, 1))
>>> y = embedding(x)  # shape: (100, 6)
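A functional sketch of the embedding, using the three-term recurrence T_0 = 1, T_1 = x, T_{n+1} = 2x T_n − T_{n−1}. The feature-major flattening order and the all-ones placeholder for the trainable coefficients C are assumptions. The use_exact path mentioned above is not modeled.

```python
import jax
import jax.numpy as jnp

def chebyshev_embed(x, C):
    """x: (batch, n_features) or (batch,); C: (D_e + 1,).
    Returns shape (batch, n_features * (D_e + 1))."""
    if x.ndim == 1:
        x = x[:, None]                          # treat (batch,) as one feature
    D_e = C.shape[0] - 1
    T = [jnp.ones_like(x), x]                   # T_0(x) = 1, T_1(x) = x
    for _ in range(2, D_e + 1):
        T.append(2 * x * T[-1] - T[-2])         # T_{n+1} = 2x T_n - T_{n-1}
    T = jnp.stack(T[:D_e + 1], axis=-1)         # (batch, n_features, D_e + 1)
    # Scale each degree by its coefficient, then flatten feature-major
    return (C * T).reshape(x.shape[0], -1)

C = jnp.ones((6,))                              # D_e = 5; placeholder for trainable C
x = jax.random.uniform(jax.random.key(0), (100, 1))
y = chebyshev_embed(x, C)                       # shape (100, 6)
```

On [-1, 1] the recurrence agrees with the closed form T_n(x) = cos(n · arccos x).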
class jaxkan.models.KKAN.InnerBlock(D_e=7, H=32, L=4, m=32, activation='tanh', seed=42)[source]

Bases: Module

Inner Block for KKAN architecture.

The Inner Block processes a single input dimension x_p through the following steps:
  1. Chebyshev embedding: x_p -> [c_0*T_0(x_p), …, c_{D_e}*T_{D_e}(x_p)] -> (D_e+1)-dim
  2. Input Dense layer: (D_e+1)-dim -> H-dim
  3. L hidden Dense layers: H-dim -> H-dim (each followed by activation)
  4. Output Chebyshev embedding: H-dim -> H*(D_e+1)-dim (flattened)
  5. Final Dense layer: H*(D_e+1)-dim -> m-dim

activation

Activation function.

Type:

callable

input_embedding

Chebyshev embedding layer for input.

Type:

ChebyshevEmbedding

input_layer

Dense layer after input embedding.

Type:

DenseLayer

hidden_layers

List of hidden Dense layers.

Type:

nnx.List

output_embedding

Chebyshev embedding layer for output.

Type:

ChebyshevEmbedding

output_layer

Final Dense layer.

Type:

DenseLayer

__init__(D_e=7, H=32, L=4, m=32, activation='tanh', seed=42)[source]

Initializes an InnerBlock.

Parameters:
  • D_e (int) – Degree of Chebyshev polynomial expansion.

  • H (int) – Hidden dimension for MLP layers.

  • L (int) – Number of hidden layers.

  • m (int) – Output dimension.

  • activation (str) – Name of the activation function (e.g., ‘tanh’).

  • seed (int) – Random seed.

Example

>>> inner_block = InnerBlock(D_e=5, H=32, L=2, m=10, activation='tanh', seed=42)
__call__(x_p)[source]

Forward pass through the inner block for a single input component.

Parameters:

x_p (jnp.array) – Single input component, shape (batch, 1).

Returns:

Output tensor, shape (batch, m).

Return type:

out (jnp.array)

Example

>>> inner_block = InnerBlock(D_e=5, H=32, L=2, m=10, seed=42)
>>> x_p = jax.random.uniform(jax.random.key(0), (100, 1))
>>> y = inner_block(x_p)  # shape: (100, 10)
class jaxkan.models.KKAN.OuterBlock(m, n_out, layer_type='sine', layer_params={'D': 7, 'init_scheme': {'type': 'glorot_fine'}}, seed=42)[source]

Bases: Module

Outer Block for KKAN architecture.

This is a wrapper around existing KAN layers from jaxkan.layers. It applies the selected KAN layer to map from m dimensions to n_out dimensions.

layer

The underlying KAN layer instance.

Type:

nnx.Module

__init__(m, n_out, layer_type='sine', layer_params={'D': 7, 'init_scheme': {'type': 'glorot_fine'}}, seed=42)[source]

Initializes an OuterBlock.

Parameters:
  • m (int) – Input dimension.

  • n_out (int) – Output dimension.

  • layer_type (str) – Type of KAN layer (‘chebyshev’, ‘legendre’, ‘rbf’, ‘sine’, ‘fourier’, etc.).

  • layer_params (dict, optional) – Additional parameters for the KAN layer (e.g., D, flavor, kernel).

  • seed (int) – Random seed.

Example

>>> outer_block = OuterBlock(m=10, n_out=1, layer_type='chebyshev',
...                          layer_params={'D': 5}, seed=42)
__call__(xi)[source]

Forward pass through the outer block.

Parameters:

xi (jnp.array) – Input from combination layer, shape (batch, m).

Returns:

Output tensor, shape (batch, n_out).

Return type:

y (jnp.array)

Example

>>> outer_block = OuterBlock(m=10, n_out=1, layer_type='chebyshev', seed=42)
>>> xi = jax.random.uniform(jax.random.key(0), (100, 10))
>>> y = outer_block(xi)  # shape: (100, 1)
class jaxkan.models.KKAN.KKAN(n_in, n_out, m=32, D_e=7, H=32, L=4, activation='tanh', outer_layer_type='sine', outer_layer_params={'D': 7, 'init_scheme': {'type': 'glorot_fine'}}, seed=42)[source]

Bases: Module

KKAN architecture based on: “KKANs: Kůrková-Kolmogorov-Arnold Networks and Their Learning Dynamics” by Juan Diego Toscano, Li-Lian Wang, and George Em Karniadakis

n_in

Input dimension (d).

Type:

int

inner_blocks

List of InnerBlock modules, one for each input dimension.

Type:

nnx.List

outer_block

The outer block that produces the final output.

Type:

OuterBlock

__init__(n_in, n_out, m=32, D_e=7, H=32, L=4, activation='tanh', outer_layer_type='sine', outer_layer_params={'D': 7, 'init_scheme': {'type': 'glorot_fine'}}, seed=42)[source]

Initializes a KKAN model.

Parameters:
  • n_in (int) – Input dimension.

  • n_out (int) – Output dimension.

  • m (int) – Intermediate dimension.

  • D_e (int) – Degree of Chebyshev expansion in inner blocks.

  • H (int) – Hidden dimension for inner block MLP.

  • L (int) – Number of hidden layers in inner block MLP.

  • activation (str) – Activation function (‘tanh’, ‘relu’, ‘silu’, ‘gelu’).

  • outer_layer_type (str) – Type of KAN layer for outer block (‘chebyshev’, ‘legendre’, ‘rbf’, ‘sine’, etc.).

  • outer_layer_params (dict, optional) – Additional parameters for outer block layer (e.g., {‘D’: 5, ‘flavor’: ‘exact’}).

  • seed (int) – Random seed.

Example

>>> model = KKAN(n_in=2, n_out=1, m=10, D_e=5, H=32, L=2,
...              outer_layer_type='chebyshev', seed=42)
__call__(x)[source]

Forward pass through the KKAN model.

Parameters:

x (jnp.array) – Input tensor, shape (batch, n_in).

Returns:

Output tensor, shape (batch, n_out).

Return type:

y (jnp.array)

Example

>>> model = KKAN(n_in=2, n_out=1, m=10, D_e=5, H=32, L=2, seed=42)
>>>
>>> key = jax.random.key(42)
>>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
>>>
>>> output = model(x_batch)
jaxkan.models.utils.get_activation(activation='tanh')[source]

Returns the corresponding activation function based on user input.

Parameters:

activation (str) – Name of the activation function. Options include:
  - ‘celu’: Continuously Differentiable ELU
  - ‘elu’: Exponential Linear Unit
  - ‘gelu’: Gaussian Error Linear Unit
  - ‘hard_sigmoid’: Hard sigmoid
  - ‘hard_silu’ / ‘hard_swish’: Hard SiLU
  - ‘hard_tanh’: Hard hyperbolic tangent
  - ‘identity’: Identity function (no activation)
  - ‘leaky_relu’: Leaky ReLU
  - ‘log_sigmoid’: Log-sigmoid function
  - ‘relu’: Rectified Linear Unit
  - ‘selu’: Scaled ELU
  - ‘sigmoid’: Sigmoid function
  - ‘silu’ / ‘swish’: Sigmoid Linear Unit
  - ‘soft_sign’: Soft sign function
  - ‘softplus’: Softplus function
  - ‘tanh’: Hyperbolic tangent (default)

Returns:

The activation function.

Return type:

callable

Example

>>> act_fn = get_activation('tanh')
>>> x = jnp.linspace(-1.0, 1.0, 5)
>>> y = act_fn(x)
class jaxkan.models.utils.PeriodEmbedder(period_axes)[source]

Bases: Module

Periodic embedding module that applies trigonometric transformations to specified input axes.

axes

Dictionary storing period values for each axis. Values can be trainable (nnx.Param) or fixed.

Type:

nnx.Dict

__init__(period_axes)[source]

Initializes a PeriodEmbedder module.

Parameters:

period_axes (dict) – Dictionary mapping input axis indices to (period, trainable) tuples. The key is the axis index (int), and the value is a tuple where:
  - period (float): The period value for the trigonometric transformation.
  - trainable (bool): If True, the period is stored as an nnx.Param and can be optimized during training.

Example

>>> # Fixed period on axis 0, trainable period on axis 1
>>> period_axes = {0: (2.0 * jnp.pi, False), 1: (jnp.pi, True)}
>>> embedder = PeriodEmbedder(period_axes)
__call__(x)[source]

Applies periodic embedding to the input.

Parameters:

x (jnp.array) – Input array, shape (batch, n_in).

Returns:

Embedded output. For each axis with a period, the original feature is replaced by cos(period * x) and sin(period * x). Non-periodic axes are passed through unchanged. Shape (batch, n_in + P), where P is the number of periodic axes, since each periodic axis is replaced by two features.

Return type:

y (jnp.array)

Example

>>> period_axes = {1: (jnp.pi, False)}
>>> embedder = PeriodEmbedder(period_axes)
>>> x = jnp.array([[1.0, 0.5], [2.0, 1.0]])
>>> y = embedder(x)  # Shape: (2, 3) - axis 0 unchanged, axis 1 → [cos, sin]
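A sketch of the embedding as described: each periodic axis becomes [cos(period * x), sin(period * x)] and other axes pass through. This assumes the pair is placed at the position of the original feature, which matches the example's column count but is not confirmed for ordering. Trainable periods are not modeled.

```python
import jax.numpy as jnp

def period_embed(x, periods):
    """x: (batch, n_in); periods: {axis: period}. Returns (batch, n_in + len(periods))."""
    cols = []
    for i in range(x.shape[1]):
        xi = x[:, i:i + 1]                           # keep a 2-D column
        if i in periods:
            p = periods[i]
            cols += [jnp.cos(p * xi), jnp.sin(p * xi)]
        else:
            cols.append(xi)                          # non-periodic axis unchanged
    return jnp.concatenate(cols, axis=-1)

x = jnp.array([[1.0, 0.5], [2.0, 1.0]])
y = period_embed(x, {1: jnp.pi})                     # shape (2, 3)
```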
class jaxkan.models.utils.RFFEmbedder(std=1.0, n_in=1, embed_dim=256, seed=42)[source]

Bases: Module

Random Fourier Features (RFF) embedding module for nonlinear feature transformation.

B

Random projection matrix, shape (n_in, embed_dim//2).

Type:

nnx.Param

__init__(std=1.0, n_in=1, embed_dim=256, seed=42)[source]

Initializes a RFFEmbedder module.

Parameters:
  • std (float) – Standard deviation for the normal distribution used to initialize the random projection matrix.

  • n_in (int) – Input dimension.

  • embed_dim (int) – Output embedding dimension. Must be even: the random projection matrix has embed_dim//2 columns, and the cosine and sine features are concatenated to give embed_dim outputs.

  • seed (int) – Random seed for reproducible initialization.

Example

>>> embedder = RFFEmbedder(std=1.0, n_in=2, embed_dim=256, seed=42)
__call__(x)[source]

Applies Random Fourier Features transformation to the input.

Parameters:

x (jnp.array) – Input array, shape (batch, n_in).

Returns:

Embedded output using random Fourier features: [cos(xB), sin(xB)]. Shape (batch, embed_dim).

Return type:

y (jnp.array)

Example

>>> embedder = RFFEmbedder(std=1.0, n_in=2, embed_dim=256, seed=42)
>>> x = jnp.array([[1.0, 0.5], [2.0, 1.0]])
>>> y = embedder(x)  # Shape: (2, 256)
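A functional sketch of the transformation [cos(xB), sin(xB)], assuming B is drawn from N(0, std²) with shape (n_in, embed_dim//2) as in the attribute description. The helper names are hypothetical.

```python
import jax
import jax.numpy as jnp

def init_rff(key, n_in, embed_dim=256, std=1.0):
    assert embed_dim % 2 == 0, "embed_dim must be even"
    # Random projection matrix B, shape (n_in, embed_dim // 2)
    return std * jax.random.normal(key, (n_in, embed_dim // 2))

def rff_embed(x, B):
    proj = x @ B                                        # (batch, embed_dim // 2)
    return jnp.concatenate([jnp.cos(proj), jnp.sin(proj)], axis=-1)

B = init_rff(jax.random.key(42), n_in=2, embed_dim=256, std=1.0)
x = jnp.array([[1.0, 0.5], [2.0, 1.0]])
y = rff_embed(x, B)                                     # shape (2, 256)
```

Larger std values give higher-frequency features, which is the usual knob for letting a network fit finer-scale structure.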
jaxkan.models.utils.count_params(model)[source]

Count the total number of trainable parameters in a model.

Parameters:

model (nnx.Module) – Flax model instance.

Returns:

Total number of trainable parameters in the model.

Return type:

total_params (int)

Example

>>> model = KAN([2,8,8,1], 'spline', {'k': 4, 'G': 3}, 42)
>>> num_params = count_params(model)
>>> print(f"Model has {num_params} parameters")
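Counting parameters amounts to summing the sizes of all array leaves in a pytree. The sketch below uses a plain dict. For an nnx model, passing nnx.state(model, nnx.Param) to the same helper is an assumption about how count_params might gather the trainable state, not a confirmed detail of its implementation.

```python
import jax
import jax.numpy as jnp

def count_leaves(tree):
    """Total number of scalar entries across all array leaves of a pytree."""
    return sum(leaf.size for leaf in jax.tree_util.tree_leaves(tree))

params = {
    "layer_0": {"w": jnp.zeros((2, 8)), "b": jnp.zeros((8,))},
    "layer_1": {"w": jnp.zeros((8, 1)), "b": jnp.zeros((1,))},
}
total = count_leaves(params)  # 16 + 8 + 8 + 1 = 33
```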
jaxkan.models.utils.get_frob(model, x)[source]

Compute the squared Frobenius norm of the model’s gradient at a given input point.

Parameters:
  • model (nnx.Module) – Flax model instance.

  • x (jnp.array) – Input point, shape (d,) or (1, d).

Returns:

Squared Frobenius norm of the gradient ||∇f(x)||²_F.

Return type:

fro_sq (float)

Example

>>> model = KAN([2,8,1], 'spline', {}, 42)
>>> x = jnp.array([0.5, 0.3])
>>> frob_norm_sq = get_frob(model, x)
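A sketch of the quantity ||∇f(x)||²_F for a single point, using jax.jacrev. It is written for a plain function f of one point with shape (d,). For a model that expects batched inputs, one could wrap it as lambda x: model(x[None, :])[0], which is an assumption about the calling convention.

```python
import jax
import jax.numpy as jnp

def frob_sq(f, x):
    """Squared Frobenius norm of the Jacobian of f at one point x, shape (d,)."""
    J = jax.jacrev(f)(x)    # (n_out, d) for vector outputs, (d,) for scalar outputs
    return jnp.sum(J ** 2)

A = jnp.array([[1.0, 2.0], [3.0, 4.0]])
f = lambda x: jnp.tanh(A @ x)
val = frob_sq(f, jnp.array([0.5, 0.3]))
```

For a linear map f(x) = A x the result is simply ||A||²_F (here 1 + 4 + 9 + 16 = 30). The tanh version is strictly smaller because the derivative of tanh is below 1.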
jaxkan.models.utils.batched_frob(*args, **kwargs)
jaxkan.models.utils.get_complexity(model, pde_collocs, bc_collocs=None)[source]

Compute model complexity as the average squared Frobenius norm of gradients over collocation points.

Parameters:
  • model (nnx.Module) – Flax model instance.

  • pde_collocs (jnp.array) – Collocation points for PDE/equation domain, shape (N, d).

  • bc_collocs (jnp.array, optional) – Initial/boundary condition collocation points, shape (M, d). If None, only pde_collocs are used.

Returns:

Average squared Frobenius norm of gradients: mean(||∇f(x)||²_F).

Return type:

complexity (float)

Example

>>> model = KAN([2,8,1], 'spline', {}, 42)
>>> pde_collocs = jnp.array([[0.5, 0.3], [0.2, 0.7]])
>>> bc_collocs = jnp.array([[0.0, 0.5]])
>>> complexity = get_complexity(model, pde_collocs, bc_collocs)
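The averaging step can be sketched by vectorizing a per-point squared Frobenius norm with `jax.vmap`. This sketch assumes that the two point sets are pooled into a single mean rather than averaged separately. It also treats the model as a plain function f of one point. `complexity` and `frob_sq` are hypothetical helpers.

```python
import jax
import jax.numpy as jnp


def frob_sq(f, x):
    """Squared Frobenius norm of the Jacobian of f at one point x of shape (d,)."""
    return jnp.sum(jax.jacrev(f)(x) ** 2)


def complexity(f, pde_collocs, bc_collocs=None):
    """Mean of ||grad f(x)||_F^2 over the PDE points and, if given, the IC/BC points."""
    if bc_collocs is None:
        points = pde_collocs
    else:
        points = jnp.concatenate([pde_collocs, bc_collocs], axis=0)
    per_point = jax.vmap(lambda x: frob_sq(f, x))(points)  # shape (N + M,)
    return jnp.mean(per_point)


# Stand-in model f(t, x) = t * x, so ||grad f||^2 = x^2 + t^2
f = lambda p: p[0] * p[1]
pde_collocs = jnp.array([[0.5, 0.3], [0.2, 0.7]])
bc_collocs = jnp.array([[0.0, 0.5]])
print(complexity(f, pde_collocs, bc_collocs))  # (0.34 + 0.53 + 0.25) / 3 ≈ 0.3733
```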
jaxkan.models.utils.get_adam(learning_rate=0.001, schedule_type=None, decay_steps=5000, decay_rate=0.9, warmup_steps=0, staircase=False, b1=0.9, b2=0.999, eps=1e-08, **schedule_kwargs)[source]

Create an Adam optimizer with optional learning rate scheduling and warmup.

Parameters:
  • learning_rate (float) – Base learning rate. Default is 1e-3.

  • schedule_type (str, optional) – Type of learning rate schedule. Options:

    • None: Constant learning rate (default).

    • ‘exponential’: Exponential decay schedule.

    • ‘cosine’: Cosine annealing schedule.

    • ‘polynomial’: Polynomial decay schedule.

    • ‘piecewise_constant’: Piecewise constant schedule (requires ‘boundaries’ and ‘values’ in schedule_kwargs).

  • decay_steps (int) – Number of steps for the learning rate decay schedule. Default is 5000. Used for exponential, cosine, and polynomial schedules.

  • decay_rate (float) – Decay rate for exponential schedule. Default is 0.9. For polynomial schedule, this is the ‘power’ parameter.

  • warmup_steps (int) – Number of warmup steps with linear learning rate increase from 0 to learning_rate. Default is 0 (no warmup).

  • staircase (bool) – If True, decay the learning rate at discrete intervals (staircase function). Default is False (smooth decay).

  • b1 (float) – Exponential decay rate for the first-moment estimates. Default is 0.9.

  • b2 (float) – Exponential decay rate for the second-moment estimates. Default is 0.999.

  • eps (float) – Small constant for numerical stability. Default is 1e-8.

  • **schedule_kwargs

    Additional keyword arguments for specific schedules. For piecewise_constant schedule:

    • boundaries (list): List of step boundaries

    • values (list): List of learning rate values (must be len(boundaries) + 1)

Returns:

Configured Adam optimizer with learning rate schedule.

Return type:

optax.GradientTransformation

Example

>>> # Adam with exponential decay and warmup
>>> optimizer = get_adam(
...     learning_rate=1e-3,
...     schedule_type='exponential',
...     decay_steps=5000,
...     decay_rate=0.9,
...     warmup_steps=1000,
...     b1=0.9,
...     b2=0.999
... )
>>> # Adam with cosine annealing
>>> optimizer = get_adam(
...     learning_rate=1e-3,
...     schedule_type='cosine',
...     decay_steps=10000,
...     warmup_steps=500
... )
>>> # Adam with constant learning rate
>>> optimizer = get_adam(learning_rate=1e-3)
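For the exponential case with warmup, the resulting learning-rate curve can be sketched in plain Python. The sketch assumes that decay starts at the end of the warmup phase, which is how optax's `warmup_exponential_decay_schedule` joins the two pieces. Whether `get_adam` builds its schedule exactly this way is an assumption, and `warmup_exponential_lr` is a hypothetical helper.

```python
import math


def warmup_exponential_lr(step, learning_rate=1e-3, warmup_steps=1000,
                          decay_steps=5000, decay_rate=0.9, staircase=False):
    """Learning rate at `step`: linear warmup from 0, then exponential decay."""
    if warmup_steps > 0 and step < warmup_steps:
        return learning_rate * step / warmup_steps
    # Exponential decay measured from the end of the warmup phase
    progress = (step - warmup_steps) / decay_steps
    if staircase:
        progress = math.floor(progress)  # decay only at whole multiples of decay_steps
    return learning_rate * decay_rate ** progress


for step in (0, 500, 1000, 6000, 11000):
    print(step, warmup_exponential_lr(step))
```

With the values from the first example above, the rate climbs to 1e-3 at step 1000. After that it is multiplied by 0.9 every 5000 steps, either smoothly or in discrete jumps when staircase=True.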
jaxkan.models.utils.get_lbfgs(learning_rate=None, memory_size=10, scale_init_precond=True, linesearch=None)[source]

Create an L-BFGS optimizer.

Note: L-BFGS requires special handling when used with Flax NNX. You must pass value, value_fn, and model to the optimizer’s update method. The value_fn should be a function that takes the model and returns the loss value.

Parameters:
  • learning_rate (float, optional) – Initial learning rate. If None, the optimizer uses its own line search to determine the step size. Default is None.

  • memory_size (int) – Number of past updates to keep in memory to approximate the Hessian inverse. Larger values require more memory but may lead to better convergence. Default is 10.

  • scale_init_precond (bool) – Whether to use a scaled identity as the initial preconditioner. Default is True.

  • linesearch (optax.GradientTransformation, optional) – Custom line search transformation. If None, uses the default zoom line search. Default is None.

Returns:

Configured L-BFGS optimizer.

Return type:

optax.GradientTransformationExtraArgs

PIKAN Training

jaxkan.pikan.pde.get_ac_res(D=0.0001, c=5.0)[source]

Returns the Allen-Cahn equation residual function (2D).

Parameters:
  • D (float) – Diffusion coefficient. Default is 1e-4.

  • c (float) – Reaction coefficient. Default is 5.0.

Returns:

Residual function for the Allen-Cahn equation.

Return type:

ac_res (function)

Example

>>> # Use default parameters
>>> ac_res = get_ac_res()
>>>
>>> # Use custom parameters
>>> ac_res = get_ac_res(D=0.001, c=10.0)
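A residual of this kind can be sketched with nested `jax.grad` calls. The sketch assumes the common Allen-Cahn benchmark form u_t - D·u_xx + c·(u³ - u) = 0, with inputs ordered (t, x). It represents the solution as a plain scalar function u(t, x). The equation form, the input order, and `make_ac_residual` are all assumptions, not jaxkan's verified implementation or call signature.

```python
import jax


def make_ac_residual(D=1e-4, c=5.0):
    """Return r(u, t, x) = u_t - D*u_xx - c*(u - u**3) for a scalar field u(t, x)."""
    def residual(u, t, x):
        u_t = jax.grad(u, argnums=0)(t, x)
        u_xx = jax.grad(jax.grad(u, argnums=1), argnums=1)(t, x)
        val = u(t, x)
        return u_t - D * u_xx - c * (val - val ** 3)
    return residual


res = make_ac_residual()
u_const = lambda t, x: 1.0 + 0.0 * t * x  # u = 1 is a steady state
u_lin = lambda t, x: x + 0.0 * t          # no time derivative and no curvature
print(res(u_const, 0.3, 0.2))  # 0.0
print(res(u_lin, 0.0, 0.5))    # -5 * (0.5 - 0.125) = -1.875
```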
jaxkan.pikan.pde.get_diffusion_res(D=0.25)[source]

Returns the diffusion equation residual function (2D).

Parameters:

D (float) – Diffusion coefficient. Default is 0.25.

Returns:

Residual function for the diffusion equation.

Return type:

diffusion_res (function)

Example

>>> # Use default parameters
>>> diffusion_res = get_diffusion_res()
>>>
>>> # Use custom parameters
>>> diffusion_res = get_diffusion_res(D=0.5)
jaxkan.pikan.pde.get_burgers_res(nu=0.003183098861837907)[source]

Returns the Burgers equation residual function (2D).

Parameters:

nu (float) – Viscosity coefficient. Default is 0.01/π.

Returns:

Residual function for the Burgers equation.

Return type:

burgers_res (function)

Example

>>> # Use default parameters
>>> burgers_res = get_burgers_res()
>>>
>>> # Use custom parameters
>>> burgers_res = get_burgers_res(nu=0.001)
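As a sketch, assume the viscous Burgers form u_t + u·u_x - ν·u_xx = 0 with inputs ordered (t, x). The residual can then be written with `jax.grad`. The scalar-function interface and `make_burgers_residual` are hypothetical. The returned function in jaxkan likely takes a model and collocation points instead.

```python
import math

import jax


def make_burgers_residual(nu=0.01 / math.pi):
    """Return r(u, t, x) = u_t + u*u_x - nu*u_xx for a scalar field u(t, x)."""
    def residual(u, t, x):
        u_t = jax.grad(u, argnums=0)(t, x)
        u_x_fn = jax.grad(u, argnums=1)
        u_x = u_x_fn(t, x)
        u_xx = jax.grad(u_x_fn, argnums=1)(t, x)
        return u_t + u(t, x) * u_x - nu * u_xx
    return residual


res = make_burgers_residual()
# u = x / (1 + t) solves u_t + u*u_x = 0 and has u_xx = 0, so the residual vanishes
u_exact = lambda t, x: x / (1.0 + t)
print(res(u_exact, 0.5, 0.8))
```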
jaxkan.pikan.pde.get_kdv_res(eta=1.0, mu=0.022)[source]

Returns the Korteweg-de Vries equation residual function (2D).

Parameters:
  • eta (float) – Nonlinearity coefficient. Default is 1.0.

  • mu (float) – Dispersion coefficient. Default is 0.022.

Returns:

Residual function for the Korteweg-de Vries equation.

Return type:

kdv_res (function)

Example

>>> # Use default parameters
>>> kdv_res = get_kdv_res()
>>>
>>> # Use custom parameters
>>> kdv_res = get_kdv_res(eta=2.0, mu=0.01)
jaxkan.pikan.pde.get_sg_res(D=1.0)[source]

Returns the sine-Gordon equation residual function (2D).

Parameters:

D (float) – Wave speed coefficient. Default is 1.0.

Returns:

Residual function for the sine-Gordon equation.

Return type:

sg_res (function)

Example

>>> # Use default parameters
>>> sg_res = get_sg_res()
>>>
>>> # Use custom parameters
>>> sg_res = get_sg_res(D=2.0)
jaxkan.pikan.pde.get_advection_res(c=20.0)[source]

Returns the advection equation residual function (2D).

Parameters:

c (float) – Wave speed coefficient. Default is 20.0.

Returns:

Residual function for the advection equation.

Return type:

advection_res (function)

Example

>>> # Use default parameters
>>> advection_res = get_advection_res()
>>>
>>> # Use custom parameters
>>> advection_res = get_advection_res(c=10.0)
jaxkan.pikan.pde.get_helmholtz_res(a1=1.0, a2=4.0, k=1.0)[source]

Returns the Helmholtz equation residual function (2D).

Parameters:
  • a1 (float) – First frequency parameter for source term. Default is 1.0.

  • a2 (float) – Second frequency parameter for source term. Default is 4.0.

  • k (float) – Wave number. Default is 1.0.

Returns:

Residual function for the Helmholtz equation.

Return type:

helmholtz_res (function)

Example

>>> # Use default parameters
>>> helmholtz_res = get_helmholtz_res()
>>>
>>> # Use custom parameters
>>> helmholtz_res = get_helmholtz_res(a1=2.0, a2=3.0, k=2.0)
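One common setup for this benchmark is u_xx + u_yy + k²·u = q(x, y). Here q is manufactured so that u = sin(a1·π·x)·sin(a2·π·y) is the exact solution. The sketch below assumes that setup and the input order (x, y). It checks that the residual vanishes on the manufactured solution. The source term and `make_helmholtz_residual` are assumptions, not jaxkan's confirmed definition.

```python
import jax
import jax.numpy as jnp


def make_helmholtz_residual(a1=1.0, a2=4.0, k=1.0):
    """Return r(u, x, y) = u_xx + u_yy + k^2 u - q(x, y) for a scalar field u(x, y)."""
    def source(x, y):
        s = jnp.sin(a1 * jnp.pi * x) * jnp.sin(a2 * jnp.pi * y)
        return (-(a1 * jnp.pi) ** 2 - (a2 * jnp.pi) ** 2 + k ** 2) * s

    def residual(u, x, y):
        u_xx = jax.grad(jax.grad(u, argnums=0), argnums=0)(x, y)
        u_yy = jax.grad(jax.grad(u, argnums=1), argnums=1)(x, y)
        return u_xx + u_yy + k ** 2 * u(x, y) - source(x, y)
    return residual


res = make_helmholtz_residual()
# Manufactured solution used to build q, so the residual should be close to 0
u_exact = lambda x, y: jnp.sin(1.0 * jnp.pi * x) * jnp.sin(4.0 * jnp.pi * y)
print(res(u_exact, 0.3, 0.7))
```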
jaxkan.pikan.pde.get_poisson_res(a1=4.0, a2=4.0)[source]

Returns the Poisson equation residual function (2D).

Parameters:
  • a1 (float) – First frequency parameter for source term. Default is 4.0.

  • a2 (float) – Second frequency parameter for source term. Default is 4.0.

Returns:

Residual function for the Poisson equation.

Return type:

poisson_res (function)

Example

>>> # Use default parameters
>>> poisson_res = get_poisson_res()
>>>
>>> # Use custom parameters
>>> poisson_res = get_poisson_res(a1=2.0, a2=3.0)
jaxkan.pikan.pde.get_wave_res(c=4.0)[source]

Returns the wave equation residual function (2D).

Parameters:

c (float) – Wave speed coefficient. Default is 4.0.

Returns:

Residual function for the wave equation.

Return type:

wave_res (function)

Example

>>> # Use default parameters
>>> wave_res = get_wave_res()
>>>
>>> # Use custom parameters
>>> wave_res = get_wave_res(c=2.0)
jaxkan.pikan.pde.get_ks_res(alpha=6.25, beta=0.390625, gamma=0.00152587890625)[source]

Returns the Kuramoto-Sivashinsky equation residual function (2D).

Parameters:
  • alpha (float) – Nonlinearity coefficient. Default is 100/16.

  • beta (float) – Second-order dispersion coefficient. Default is 100/(16²).

  • gamma (float) – Fourth-order dispersion coefficient. Default is 100/(16⁴).

Returns:

Residual function for the Kuramoto-Sivashinsky equation.

Return type:

ks_res (function)

Example

>>> # Use default parameters
>>> ks_res = get_ks_res()
>>>
>>> # Use custom parameters
>>> ks_res = get_ks_res(alpha=6.25, beta=0.390625, gamma=0.0009765625)
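This sketch assumes the form u_t + α·u·u_x + β·u_xx + γ·u_xxxx = 0, which matches the stated roles of the three coefficients, with inputs ordered (t, x). It shows how the fourth-order term can be built by repeated differentiation with `jax.grad`. `make_ks_residual` is a hypothetical helper.

```python
import jax


def make_ks_residual(alpha=100 / 16, beta=100 / 16**2, gamma=100 / 16**4):
    """Return r(u, t, x) = u_t + alpha*u*u_x + beta*u_xx + gamma*u_xxxx."""
    def residual(u, t, x):
        d_x = lambda f: jax.grad(f, argnums=1)  # derivative with respect to x
        u_t = jax.grad(u, argnums=0)(t, x)
        u_x = d_x(u)
        u_xx = d_x(u_x)
        u_xxxx = d_x(d_x(u_xx))
        return (u_t + alpha * u(t, x) * u_x(t, x)
                + beta * u_xx(t, x) + gamma * u_xxxx(t, x))
    return residual


res = make_ks_residual()
# For u = x^4: u_t = 0, u_x = 4x^3, u_xx = 12x^2, u_xxxx = 24
u_test = lambda t, x: x ** 4
print(res(u_test, 0.0, 1.0))  # alpha*4 + beta*12 + gamma*24
```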
jaxkan.pikan.sampling.get_collocs_grid(ranges)[source]

Generate grid-based collocation points across arbitrary dimensions.

Parameters:

ranges (list[tuple[float, float, int]]) – List of tuples where each tuple contains (lower_bound, upper_bound, sample_size) for each dimension.

Returns:

Array of shape (total_points, n_dims) containing all grid points.

Return type:

jnp.array

Example

>>> # 2D grid: x in [0, 1] with 10 points, t in [0, 2] with 20 points
>>> collocs = get_collocs_grid(ranges=[(0.0, 1.0, 10), (0.0, 2.0, 20)])
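A grid of this kind is the Cartesian product of one 1D grid per dimension. The sketch below assumes that both endpoints are included (as with `jnp.linspace`) and that the first dimension varies slowest (`indexing="ij"`). `collocs_grid` is a hypothetical helper, and jaxkan's point ordering may differ.

```python
import jax.numpy as jnp


def collocs_grid(ranges):
    """Cartesian-product grid from [(lower, upper, n_points), ...], shape (prod(n), n_dims)."""
    axes = [jnp.linspace(lo, hi, n) for lo, hi, n in ranges]
    mesh = jnp.meshgrid(*axes, indexing="ij")
    return jnp.stack([m.ravel() for m in mesh], axis=-1)


collocs = collocs_grid([(0.0, 1.0, 10), (0.0, 2.0, 20)])
print(collocs.shape)  # (200, 2)
```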
jaxkan.pikan.sampling.get_collocs_random(ranges, total_points, seed=42)[source]

Generate random collocation points across arbitrary dimensions.

Parameters:
  • ranges (list[tuple[float, float]]) – List of tuples where each tuple contains (lower_bound, upper_bound) for each dimension.

  • total_points (int) – Total number of random points to generate.

  • seed (int) – Random seed for reproducibility. Default is 42.

Returns:

Array of shape (total_points, n_dims) containing randomly sampled points.

Return type:

jnp.array

Example

>>> # 2D random samples: 100 points in [0, 1] x [0, 2]
>>> collocs = get_collocs_random(ranges=[(0.0, 1.0), (0.0, 2.0)], total_points=100, seed=42)
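Uniform sampling in a box can be sketched by drawing points in the unit cube and rescaling each dimension. The sketch assumes a uniform distribution and JAX's PRNG. `collocs_random` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def collocs_random(ranges, total_points, seed=42):
    """Uniform samples in the box [(lower, upper), ...], shape (total_points, n_dims)."""
    lower = jnp.array([lo for lo, _ in ranges])
    upper = jnp.array([hi for _, hi in ranges])
    key = jax.random.PRNGKey(seed)
    unit = jax.random.uniform(key, (total_points, len(ranges)))  # in [0, 1)
    return lower + (upper - lower) * unit  # rescale each column to its range


collocs = collocs_random([(0.0, 1.0), (0.0, 2.0)], total_points=100, seed=42)
print(collocs.shape)  # (100, 2)
```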
jaxkan.pikan.sampling.get_collocs_sobol(ranges, total_points, seed=42)[source]

Generate Sobol quasi-random collocation points across arbitrary dimensions.

Parameters:
  • ranges (list[tuple[float, float]]) – List of tuples where each tuple contains (lower_bound, upper_bound) for each dimension.

  • total_points (int) – Total number of Sobol points to generate. Should be a power of 2 for optimal coverage.

  • seed (int) – Random seed for reproducibility. Default is 42.

Returns:

Array of shape (total_points, n_dims) containing Sobol sampled points.

Return type:

jnp.array

Example

>>> # 2D Sobol samples: 1024 points in [0, 1] x [0, 2]
>>> collocs = get_collocs_sobol(ranges=[(0.0, 1.0), (0.0, 2.0)], total_points=1024, seed=42)
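SciPy's `scipy.stats.qmc.Sobol` is one way to generate such points. Whether jaxkan uses it, and whether it scrambles the sequence, is an assumption. SciPy warns when the sample size is not a power of 2, which relates to the recommendation above. `collocs_sobol` is a hypothetical helper that returns a NumPy array. Wrap the result in `jnp.asarray` if a JAX array is needed.

```python
from scipy.stats import qmc


def collocs_sobol(ranges, total_points, seed=42):
    """Scrambled Sobol points scaled to the box [(lower, upper), ...]."""
    sampler = qmc.Sobol(d=len(ranges), scramble=True, seed=seed)
    unit = sampler.random(total_points)  # points in [0, 1)^d
    lower = [lo for lo, _ in ranges]
    upper = [hi for _, hi in ranges]
    return qmc.scale(unit, lower, upper)


collocs = collocs_sobol([(0.0, 1.0), (0.0, 2.0)], total_points=1024, seed=42)
print(collocs.shape)  # (1024, 2)
```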
jaxkan.pikan.adaptive.get_colloc_indices(collocs_pool, batch_size, px, seed)[source]

Sample collocation point indices from a pool based on probability weights and sort by first coordinate (time).

Parameters:
  • collocs_pool (jnp.array) – Pool of collocation points, shape (N, n_dims).

  • batch_size (int) – Number of points to sample.

  • px (jnp.array) – Probability weights for sampling, shape (N,).

  • seed (int) – Random seed for reproducibility.

Returns:

Indices of sampled points sorted by first coordinate, shape (batch_size,).

Return type:

sorted_pool_indices (jnp.array)

Example

>>> collocs_pool = jnp.array([[0.5, 0.1], [0.2, 0.3], [0.8, 0.7]])
>>> px = jnp.array([0.5, 0.3, 0.2])
>>> indices = get_colloc_indices(collocs_pool, batch_size=2, px=px, seed=42)
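The two steps in the description can be sketched with `jax.random.choice` followed by an `argsort` on the first column. The sketch assumes sampling without replacement and assumes that px already sums to 1. `colloc_indices` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def colloc_indices(collocs_pool, batch_size, px, seed):
    """Draw batch_size distinct indices with probabilities px, then sort them by time."""
    key = jax.random.PRNGKey(seed)
    idx = jax.random.choice(key, collocs_pool.shape[0], shape=(batch_size,),
                            replace=False, p=px)
    order = jnp.argsort(collocs_pool[idx, 0])  # first coordinate is treated as time
    return idx[order]


collocs_pool = jnp.array([[0.5, 0.1], [0.2, 0.3], [0.8, 0.7]])
px = jnp.array([0.5, 0.3, 0.2])
indices = colloc_indices(collocs_pool, batch_size=2, px=px, seed=42)
print(collocs_pool[indices, 0])  # non-decreasing times
```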
jaxkan.pikan.adaptive.update_rba_weights(residuals, weights, gamma, eta, eps=1e-12)[source]

Update residual-based attention weights from the current residual magnitudes.

Parameters:
  • residuals (jnp.array) – Residual values, typically shape (N, 1).

  • weights (jnp.array) – Current RBA weights with the same shape as residuals.

  • gamma (float) – Exponential moving average coefficient.

  • eta (float) – Residual scaling coefficient.

  • eps (float) – Numerical stabilizer for normalization.

Returns:

Updated RBA weights.

Return type:

jnp.array
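A widely used residual-based attention update is λ ← γ·λ + η·|r| / max|r|. The sketch below follows that rule. Whether `update_rba_weights` normalizes by the maximum, the mean, or another statistic is an assumption, and `rba_update` is a hypothetical helper.

```python
import jax.numpy as jnp


def rba_update(residuals, weights, gamma, eta, eps=1e-12):
    """lambda <- gamma * lambda + eta * |r| / (max|r| + eps)."""
    r_abs = jnp.abs(residuals)
    return gamma * weights + eta * r_abs / (jnp.max(r_abs) + eps)


residuals = jnp.array([[1.0], [-2.0], [0.5]])
weights = jnp.zeros_like(residuals)
print(rba_update(residuals, weights, gamma=0.999, eta=0.01))  # values 0.005, 0.01, 0.0025
```

Points with the largest residual magnitude receive the largest increment. The γ term keeps a decaying memory of earlier iterations.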

jaxkan.pikan.adaptive.apply_rba_weights(residuals, weights)[source]

Apply RBA weights to residuals while keeping the weights outside the backward graph.

Parameters:
  • residuals (jnp.array) – Residual values.

  • weights (jnp.array) – RBA weights with matching shape.

Returns:

Weighted residuals.

Return type:

jnp.array
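Keeping the weights outside the backward graph can be done with `jax.lax.stop_gradient`. The sketch below assumes the weights multiply the residuals elementwise. It then checks that no gradient reaches the weights through a mean-squared loss. `rba_apply` and `loss` are hypothetical helpers.

```python
import jax
import jax.numpy as jnp


def rba_apply(residuals, weights):
    """Weight residuals; stop_gradient keeps the weights out of the backward pass."""
    return jax.lax.stop_gradient(weights) * residuals


def loss(residuals, weights):
    return jnp.mean(rba_apply(residuals, weights) ** 2)


r = jnp.array([1.0, 2.0])
w = jnp.array([0.5, 1.5])
g_r, g_w = jax.grad(loss, argnums=(0, 1))(r, w)
print(g_w)  # [0. 0.], no gradient flows into the weights
```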

jaxkan.pikan.adaptive.get_causal_weights(losses, causal_matrix, causal_tol)[source]

Compute causal training weights and stop their gradients.

Parameters:
  • losses (jnp.array) – Chunk-wise loss values.

  • causal_matrix (jnp.array) – Upper-triangular causal coupling matrix.

  • causal_tol (float) – Exponential weighting coefficient.

Returns:

Stopped-gradient causal weights.

Return type:

jnp.array
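In causal training for time-dependent problems, the weight of time chunk i typically decays with the accumulated loss of earlier chunks: w_i = exp(-causal_tol · Σ_{k<i} L_k). The sketch assumes that causal_matrix is the strictly upper-triangular matrix of ones, so that `losses @ causal_matrix` produces those partial sums. Both this construction and `causal_weights` are assumptions.

```python
import jax
import jax.numpy as jnp


def causal_weights(losses, causal_matrix, causal_tol):
    """w_i = exp(-causal_tol * sum_{k<i} L_k), with gradients stopped."""
    cumulative = losses @ causal_matrix  # entry i sums the losses of earlier chunks
    return jax.lax.stop_gradient(jnp.exp(-causal_tol * cumulative))


n_chunks = 3
# Strictly upper-triangular ones: entry (k, i) is 1 when chunk k precedes chunk i
causal_matrix = jnp.triu(jnp.ones((n_chunks, n_chunks)), k=1)
losses = jnp.array([1.0, 2.0, 3.0])
print(causal_weights(losses, causal_matrix, causal_tol=1.0))  # [1, e^-1, e^-3]
```

Later chunks stay down-weighted until the losses of earlier chunks have been driven down. This enforces the intended training order in time.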

jaxkan.pikan.adaptive.get_rad_probabilities(weighted_residuals, rad_a, rad_c, eps=1e-12)[source]

Convert weighted residuals into normalized RAD sampling probabilities.

Parameters:
  • weighted_residuals (jnp.array) – Residuals already multiplied by their current adaptive weights.

  • rad_a (float) – Density-control exponent.

  • rad_c (float) – Baseline offset added to the density.

  • eps (float) – Numerical stabilizer for normalization.

Returns:

Probability vector of shape (N,).

Return type:

jnp.array
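Residual-based adaptive distribution (RAD) sampling commonly uses a density p(x) ∝ |r|^a / mean(|r|^a) + c. With a = 0 this reduces to uniform sampling, and larger c flattens the distribution. The sketch assumes that form, with rad_a as the exponent and rad_c as the offset. `rad_probabilities` is a hypothetical helper.

```python
import jax.numpy as jnp


def rad_probabilities(weighted_residuals, rad_a, rad_c, eps=1e-12):
    """p_i proportional to |r_i|^a / mean(|r|^a) + c, normalized to sum to 1."""
    err = jnp.abs(weighted_residuals).reshape(-1) ** rad_a  # (N, 1) -> (N,)
    density = err / (jnp.mean(err) + eps) + rad_c
    return density / (jnp.sum(density) + eps)


r = jnp.array([[0.1], [1.0], [3.0]])
p = rad_probabilities(r, rad_a=1.0, rad_c=1.0)
print(p, p.sum())
```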

jaxkan.pikan.adaptive.get_rad_indices(collocs_pool, residuals, old_indices, batch_weights, pool_weights, batch_size, rad_a, rad_c, seed, eps=1e-12)[source]

Update pool weights and draw a new RAD-resampled batch of collocation indices.

Parameters:
  • collocs_pool (jnp.array) – Full collocation pool, shape (N, d).

  • residuals (jnp.array) – Residual values over the full pool, shape (N, 1).

  • old_indices (jnp.array) – Previously active batch indices.

  • batch_weights (jnp.array) – Current adaptive weights for the active batch.

  • pool_weights (jnp.array) – Full pool of adaptive weights.

  • batch_size (int) – Number of indices to sample.

  • rad_a (float) – Density-control exponent.

  • rad_c (float) – Baseline offset added to the density.

  • seed (int) – Sampling seed.

  • eps (float) – Numerical stabilizer.

Returns:

Sampled indices, updated full-pool weights, and normalized probabilities.

Return type:

tuple[jnp.array, jnp.array, jnp.array]

jaxkan.pikan.adaptive.lr_anneal(grads, lambdas, grad_mixing)[source]

Perform the learning rate annealing algorithm introduced in “Understanding and mitigating gradient pathologies in physics-informed neural networks”.

Parameters:
  • grads (tuple | list) – Tuple/list of gradient pytrees, one per loss term.

  • lambdas (tuple | list) – Tuple/list of current global loss weights, matching grads in order.

  • grad_mixing (float) – Exponential moving average coefficient.

Returns:

Updated global loss weights in the same logical order as the inputs.

Return type:

tuple

Example

>>> λs_new = lr_anneal((grads_E, grads_B, grads_aux), (1.0, 1.0, 1.0), 0.9)
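The paper's original rule compares the maximum gradient of the residual loss with the mean gradient of each other term. The sketch below instead uses a norm-balancing variant, λ̂_i = Σ_j ‖∇L_j‖ / ‖∇L_i‖, blended as grad_mixing·λ_i + (1 - grad_mixing)·λ̂_i. Both the formula and `lr_anneal_sketch` are assumptions, and jaxkan's exact update may differ.

```python
import jax
import jax.numpy as jnp


def tree_norm(tree):
    """Euclidean norm of all leaves of a gradient pytree, flattened together."""
    return jnp.sqrt(sum(jnp.sum(g ** 2) for g in jax.tree_util.tree_leaves(tree)))


def lr_anneal_sketch(grads, lambdas, grad_mixing):
    """Balance loss weights by gradient norms and smooth them with an EMA."""
    norms = [tree_norm(g) for g in grads]
    total = sum(norms)
    # Terms with small gradients get large target weights, and vice versa
    targets = [total / (n + 1e-12) for n in norms]
    return tuple(grad_mixing * lam + (1.0 - grad_mixing) * t
                 for lam, t in zip(lambdas, targets))


grads_E = {"w": jnp.array([1.0, 0.0])}  # norm 1
grads_B = {"w": jnp.array([0.0, 3.0])}  # norm 3
print(lr_anneal_sketch((grads_E, grads_B), (1.0, 1.0), 0.9))  # approx (1.3, 1.0333)
```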
jaxkan.pikan.utils.model_eval(model, coords, refsol)[source]

Compute the relative L2 error between model predictions and reference solution.

Parameters:
  • model (nnx.Module) – Flax model instance.

  • coords (jnp.array) – Input coordinates for evaluation, shape (N, n_dims).

  • refsol (jnp.array) – Reference solution values for comparison, with the same shape as the model output (e.g. (N, 1)).

Returns:

Relative L2 error: ||prediction - reference|| / ||reference||

Return type:

l2err (float)

Example

>>> model = KAN([2,8,1], 'spline', {}, 42)
>>> coords = jnp.array([[0.5, 0.5], [0.1, 0.2]])
>>> refsol = jnp.array([[0.25], [0.02]])
>>> error = model_eval(model, coords, refsol)
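The metric itself is a ratio of two norms over all entries. The sketch below treats the predictions as already computed. It uses a hypothetical stand-in u(x, y) = x·y, which happens to reproduce the example refsol. Predictions and reference must have the same shape. Otherwise, a (N,) array against a (N, 1) array silently broadcasts to (N, N) and the error is meaningless. `relative_l2` is a hypothetical helper.

```python
import jax.numpy as jnp


def relative_l2(pred, refsol):
    """||pred - refsol||_2 / ||refsol||_2 over all entries."""
    return jnp.linalg.norm(pred - refsol) / jnp.linalg.norm(refsol)


# Stand-in "model" u(x, y) = x * y evaluated on the example points above
coords = jnp.array([[0.5, 0.5], [0.1, 0.2]])
refsol = jnp.array([[0.25], [0.02]])
pred = (coords[:, 0] * coords[:, 1])[:, None]  # shape (2, 1), matches refsol
print(relative_l2(pred, refsol))  # approximately 0
```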