# Source code for jaxkan.models.KAN

from jax import numpy as jnp

from flax import nnx

from ..layers import get_layer

from typing import Union, List


class KAN(nnx.Module):
    """
    KAN class, corresponding to a network of KAN Layers.

    The network is a plain sequential stack: the output of each layer is fed
    as the input of the next one.

    Attributes:
        layers (nnx.List):
            List of KAN layer instances, in the order they are applied.
    """

    def __init__(self,
                 layer_dims: List[int],
                 layer_type: str = "base",
                 required_parameters: Union[None, dict] = None,
                 seed: int = 42
                 ):
        """
        Initializes a KAN model.

        Args:
            layer_dims (List[int]):
                Defines the network in terms of nodes. E.g. [4,5,1] is a network with 2 layers:
                one with n_in=4 and n_out=5 and one with n_in=5 and n_out=1.
                With fewer than two entries no layer is created and the model
                returns its input unchanged.
            layer_type (str):
                Type of layer to use (e.g., 'base'). It is lower-cased before being
                passed to ``get_layer``.
            required_parameters (dict):
                Dictionary containing parameters required for the chosen layer type.
                Its entries are forwarded as keyword arguments to every layer.
            seed (int):
                Random key selection for initializations wherever necessary.
                The same seed is passed to every layer.

        Raises:
            ValueError:
                If ``required_parameters`` is None.

        Example:
            >>> req_params = {'k': 3, 'G': 5}
            >>> model = KAN(layer_dims = [2,5,1], layer_type='base', required_parameters=req_params, seed=42)
        """
        # Get the corresponding layer class based on layer_type
        LayerClass = get_layer(layer_type.lower())

        # None is only the signature default; a dictionary is mandatory
        if required_parameters is None:
            raise ValueError("required_parameters must be provided as a dictionary for the selected layer_type.")

        # Consecutive entries of layer_dims give the (n_in, n_out) of each layer
        self.layers = nnx.List([
            LayerClass(
                n_in=n_in,
                n_out=n_out,
                **required_parameters,
                seed=seed
            )
            for n_in, n_out in zip(layer_dims[:-1], layer_dims[1:])
        ])

    def update_grids(self, x, G_new):
        """
        Performs the grid update for each layer of the KAN architecture.

        Each layer's grid is updated using the activations that actually reach
        it, i.e. the batch ``x`` propagated through all preceding layers (whose
        grids have already been updated).

        Args:
            x (jnp.array):
                Inputs for the first layer, shape (batch, layer_dims[0]).
            G_new (int):
                Size of the new grid (in terms of intervals).

        Example:
            >>> import jax
            >>>
            >>> req_params = {'k': 3, 'G': 5}
            >>> model = KAN(layer_dims = [2,5,1], layer_type='base', required_parameters=req_params, seed=42)
            >>>
            >>> key = jax.random.key(42)
            >>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
            >>>
            >>> model.update_grids(x=x_batch, G_new=10)
        """
        previous = None

        # Loop over each layer
        for layer in self.layers:
            # Forward pass through the previous (already updated) layer to get
            # the input of the current one. Doing this lazily avoids a wasted
            # forward pass after the final layer, whose output nobody consumes.
            if previous is not None:
                x = previous(x)

            # Update the grid for the current layer
            layer.update_grid(x, G_new)
            previous = layer

    def __call__(self, x):
        """
        Equivalent to the network's forward pass.

        Args:
            x (jnp.array):
                Inputs for the first layer, shape (batch, layer_dims[0]).

        Returns:
            x (jnp.array):
                Network output, shape (batch, layer_dims[-1]).

        Example:
            >>> import jax
            >>>
            >>> req_params = {'k': 3, 'G': 5}
            >>> model = KAN(layer_dims = [2,5,1], layer_type='base', required_parameters=req_params, seed=42)
            >>>
            >>> key = jax.random.key(42)
            >>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
            >>>
            >>> output = model(x_batch)
        """
        # Pass through each layer of the KAN
        for layer in self.layers:
            x = layer(x)

        return x