Tutorial 1 - Getting Started

In this first tutorial notebook, we create a simple KAN model and train it to perform function fitting.

[1]:
import os
# Set the log level before importing JAX so that it actually takes effect
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

from jaxkan.models.KAN import KAN

import jax
import jax.numpy as jnp

from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

from flax import nnx
import optax

import matplotlib.pyplot as plt
import numpy as np

Data Generation

We begin by generating mock data. Consider the function \(f(x, y) = x^2 + 2\exp(y)\). We will fit a KAN model to synthetic samples of this function, with both inputs drawn uniformly from \([-1, 1]\), and then evaluate its performance.

[2]:
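def f(x, y):
    # Target function f(x, y) = x^2 + 2 exp(y)
    return x**2 + 2*jnp.exp(y)

def generate_data(minval=-1, maxval=1, num_samples=1000, seed=42):
    # Independent PRNG keys for the two input features
    key = jax.random.PRNGKey(seed)
    x_key, y_key = jax.random.split(key)

    # Sample both features uniformly from [minval, maxval)
    x1 = jax.random.uniform(x_key, shape=(num_samples,), minval=minval, maxval=maxval)
    x2 = jax.random.uniform(y_key, shape=(num_samples,), minval=minval, maxval=maxval)

    # Targets as a column vector (num_samples, 1), inputs as (num_samples, 2)
    y = f(x1, x2).reshape(-1, 1)
    X = jnp.stack([x1, x2], axis=1)

    return X, y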
[3]:
seed = 42

X, y = generate_data(minval=-1, maxval=1, num_samples=1000, seed=seed)
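Because both inputs are drawn from \([-1, 1]\), the targets are bounded: \(x^2 \in [0, 1]\) and \(2\exp(y) \in [2/e, 2e]\), so \(f \in [2/e, 1 + 2e] \approx [0.736, 6.437]\). This is handy for sanity-checking the data and, later, the predicted-vs-actual plot. The quick check below is a sketch that regenerates the inputs inline with the same seed rather than calling generate_data.

```python
import jax
import jax.numpy as jnp

def f(x, y):
    return x**2 + 2 * jnp.exp(y)

# Regenerate the synthetic inputs (uniform on [-1, 1), seed 42)
key = jax.random.PRNGKey(42)
x_key, y_key = jax.random.split(key)
x1 = jax.random.uniform(x_key, shape=(1000,), minval=-1, maxval=1)
x2 = jax.random.uniform(y_key, shape=(1000,), minval=-1, maxval=1)
y = f(x1, x2)

# Analytical bounds: x^2 in [0, 1], 2*exp(y) in [2/e, 2e]
lower = 2 * jnp.exp(-1.0)     # about 0.7358
upper = 1 + 2 * jnp.exp(1.0)  # about 6.4366

# Small tolerance to absorb float32 rounding
tol = 1e-5
in_bounds = bool(jnp.all(y >= lower - tol) & jnp.all(y <= upper + tol))
print(float(y.min()), float(y.max()), in_bounds)
```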

Preprocessing

With the data at hand, we can perform the usual train/test split, holding out 20% of the samples for testing.

[4]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=seed)

print("Training set size:", X_train.shape)
print("Test set size:", X_test.shape)
Training set size: (800, 2)
Test set size: (200, 2)

Note that the input arrays have shape (batch, input_dim), while the targets have shape (batch, 1).
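If you prefer to avoid the scikit-learn dependency, the same 80/20 split can be sketched in pure JAX by permuting the row indices with a PRNG key. The helper jax_train_test_split below is hypothetical (it is not part of jaxKAN or scikit-learn), it will not select the same rows as train_test_split with random_state=42, and the toy arrays are placeholders.

```python
import jax
import jax.numpy as jnp

def jax_train_test_split(X, y, test_size=0.2, seed=42):
    """Hypothetical helper: shuffle rows with a JAX PRNG key, then split them."""
    n = X.shape[0]
    n_test = int(round(n * test_size))
    perm = jax.random.permutation(jax.random.PRNGKey(seed), n)
    test_idx, train_idx = perm[:n_test], perm[n_test:]
    # Index X and y with the same permutation so input/target pairs stay aligned
    return X[train_idx], X[test_idx], y[train_idx], y[test_idx]

# Toy data: row i of X is [2i, 2i+1] and its target is the row sum
X = jnp.arange(2000.0).reshape(1000, 2)
y = X.sum(axis=1, keepdims=True)

X_train, X_test, y_train, y_test = jax_train_test_split(X, y)
print(X_train.shape, X_test.shape)  # (800, 2) (200, 2)
```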


KAN Model

Selecting and initializing a jaxKAN model is straightforward. Here we focus on the most basic model, KAN, which requires the following arguments, common to all KAN layer types:

  • layer_dims: A list with the number of nodes in each layer, from the input layer to the output layer (e.g., [2, 6, 6, 1]).

  • layer_type: The type of KAN layer to use, which defaults to base (other layer types are covered in the next tutorial).

  • seed: An integer for reproducibility.

Apart from these, the required_parameters argument is a dictionary whose keys correspond to the additional parameters of the chosen layer type.

We will use a Chebyshev layer for the present example, with the required parameters D = 5 and flavor = 'exact'. Note that not only other layer types but also other models can be used instead of KAN (examples can be found in subsequent tutorials).
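To give an idea of what a Chebyshev KAN layer computes, here is a minimal sketch in plain JAX. It rests on assumptions that are ours, not necessarily jaxKAN's exact implementation: inputs are squashed into \([-1, 1]\) with tanh, the basis consists of \(T_1, \dots, T_D\) built with the recurrence \(T_{k+1}(x) = 2x\,T_k(x) - T_{k-1}(x)\), the coefficients are stored with shape (n_out, n_in, D) like the c_basis parameter in the model printout, and the flavor option is not modelled. The names chebyshev_basis and chebyshev_kan_layer are hypothetical.

```python
import jax
import jax.numpy as jnp

def chebyshev_basis(x, D):
    """Return T_1(x), ..., T_D(x) stacked along a new last axis."""
    t_prev, t_curr = jnp.ones_like(x), x  # T_0 and T_1
    terms = [t_curr]
    for _ in range(D - 1):
        # Recurrence: T_{k+1} = 2 x T_k - T_{k-1}
        t_prev, t_curr = t_curr, 2 * x * t_curr - t_prev
        terms.append(t_curr)
    return jnp.stack(terms, axis=-1)  # (batch, n_in, D)

def chebyshev_kan_layer(x, c_basis, bias):
    """Sketch: y_j = sum_i sum_k c[j, i, k] * T_k(tanh(x_i)) + b_j."""
    x = jnp.tanh(x)  # assumption: squash inputs into [-1, 1]
    basis = chebyshev_basis(x, c_basis.shape[-1])  # (batch, n_in, D)
    return jnp.einsum('bik,oik->bo', basis, c_basis) + bias

key = jax.random.PRNGKey(0)
x_key, c_key = jax.random.split(key)
x = jax.random.uniform(x_key, (4, 2), minval=-1, maxval=1)   # (batch, n_in)
c_basis = 0.1 * jax.random.normal(c_key, (6, 2, 5))          # (n_out, n_in, D)
bias = jnp.zeros((6,))

out = chebyshev_kan_layer(x, c_basis, bias)
print(out.shape)  # (4, 6)
```

Each edge (input i, output j) thus carries its own learnable univariate function, a weighted sum of Chebyshev polynomials, which is the defining feature of a KAN layer.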

[5]:
# Initialize a KAN model
n_in = X_train.shape[1]
n_out = y_train.shape[1]
n_hidden = 6

layer_dims = [n_in, n_hidden, n_hidden, n_out]
# Parameters specific to the Chebyshev layer
req_params = {'D': 5, 'flavor': 'exact'}

model = KAN(layer_dims=layer_dims,
            layer_type='chebyshev',
            required_parameters=req_params,
            seed=42
           )

print(model)
KAN( # Param: 283 (1.1 KB), RngState: 6 (36 B), Total: 289 (1.2 KB)
  layer_type='chebyshev',
  layers=List([
    ChebyshevLayer( # Param: 66 (264 B), RngState: 2 (12 B), Total: 68 (276 B)
      n_in=2,
      n_out=6,
      D=5,
      flavor='exact',
      residual=None,
      rngs=Rngs( # RngState: 2 (12 B)
        default=RngStream( # RngState: 2 (12 B)
          tag='default',
          key=RngKey( # 1 (8 B)
            value=Array((), dtype=key<fry>) overlaying:
            [ 0 42],
            tag='default'
          ),
          count=RngCount( # 1 (4 B)
            value=Array(1, dtype=uint32),
            tag='default'
          )
        )
      ),
      bias=Param( # 6 (24 B)
        value=Array(shape=(6,), dtype=dtype('float32'))
      ),
      c_ext=None,
      c_basis=Param( # 60 (240 B)
        value=Array(shape=(6, 2, 5), dtype=dtype('float32'))
      )
    ),
    ChebyshevLayer( # Param: 186 (744 B), RngState: 2 (12 B), Total: 188 (756 B)
      n_in=6,
      n_out=6,
      D=5,
      flavor='exact',
      residual=None,
      rngs=Rngs( # RngState: 2 (12 B)
        default=RngStream( # RngState: 2 (12 B)
          tag='default',
          key=RngKey( # 1 (8 B)
            value=Array((), dtype=key<fry>) overlaying:
            [ 0 42],
            tag='default'
          ),
          count=RngCount( # 1 (4 B)
            value=Array(1, dtype=uint32),
            tag='default'
          )
        )
      ),
      bias=Param( # 6 (24 B)
        value=Array(shape=(6,), dtype=dtype('float32'))
      ),
      c_ext=None,
      c_basis=Param( # 180 (720 B)
        value=Array(shape=(6, 6, 5), dtype=dtype('float32'))
      )
    ),
    ChebyshevLayer( # Param: 31 (124 B), RngState: 2 (12 B), Total: 33 (136 B)
      n_in=6,
      n_out=1,
      D=5,
      flavor='exact',
      residual=None,
      rngs=Rngs( # RngState: 2 (12 B)
        default=RngStream( # RngState: 2 (12 B)
          tag='default',
          key=RngKey( # 1 (8 B)
            value=Array((), dtype=key<fry>) overlaying:
            [ 0 42],
            tag='default'
          ),
          count=RngCount( # 1 (4 B)
            value=Array(1, dtype=uint32),
            tag='default'
          )
        )
      ),
      bias=Param( # 1 (4 B)
        value=Array([0.], dtype=float32)
      ),
      c_ext=None,
      c_basis=Param( # 30 (120 B)
        value=Array(shape=(1, 6, 5), dtype=dtype('float32'))
      )
    )
  ])
)
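The parameter counts in the printout can be reproduced by hand: each ChebyshevLayer holds a c_basis tensor of shape (n_out, n_in, D) and a bias of shape (n_out,), so it has n_out · n_in · D + n_out parameters. The helper name chebyshev_layer_params is ours, used only for illustration.

```python
def chebyshev_layer_params(n_in, n_out, D):
    # c_basis: (n_out, n_in, D) entries; bias: (n_out,) entries
    return n_out * n_in * D + n_out

layer_dims = [2, 6, 6, 1]
D = 5
counts = [chebyshev_layer_params(i, o, D)
          for i, o in zip(layer_dims[:-1], layer_dims[1:])]
print(counts, sum(counts))  # [66, 186, 31] 283
```

This matches the per-layer Param counts (66, 186, 31) and the model total of 283 reported above.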

Training

To train the model on the data, we first define an optimizer using the optax library. For this example, we use a simple Adam optimizer with a learning rate of 0.001 and wrap it, together with the model, in an nnx.Optimizer object; wrt=nnx.Param specifies that the optimizer updates the model's nnx.Param variables.

[7]:
opt_type = optax.adam(learning_rate=0.001)

optimizer = nnx.Optimizer(model, opt_type, wrt=nnx.Param)

Next, we define the training step and decorate it with @nnx.jit, so that it is compiled once on the first call and subsequent calls run much faster. The loss function is the mean squared error (MSE).

[10]:
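# Define a single training step, compiled with nnx.jit
@nnx.jit
def train_step(model, optimizer, X_train, y_train):

    def loss_fn(model):
        # Mean squared error between predictions and targets
        residual = model(X_train) - y_train
        loss = jnp.mean(residual**2)

        return loss

    # Compute the loss and its gradients w.r.t. the model parameters
    loss, grads = nnx.value_and_grad(loss_fn)(model)
    # Update the model parameters in place
    optimizer.update(model, grads)

    return loss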
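The loss inside train_step is \(L = \frac{1}{N}\sum_n (\hat{y}_n - y_n)^2\). As a sanity check of what the gradient computation returns, the sketch below swaps the KAN for a plain linear model \(\hat{y} = Xw\), chosen only because its gradient has the closed form \(\nabla_w L = \frac{2}{N} X^\top (Xw - y)\), and compares jax.value_and_grad against that formula on random toy data.

```python
import jax
import jax.numpy as jnp

def mse_loss(w, X, y):
    # Same MSE as in train_step, but for a linear stand-in model
    residual = X @ w - y
    return jnp.mean(residual ** 2)

key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
X = jax.random.normal(k1, (50, 2))
y = jax.random.normal(k2, (50, 1))
w = jax.random.normal(k3, (2, 1))

# Automatic differentiation vs. the closed-form gradient
loss, grad_auto = jax.value_and_grad(mse_loss)(w, X, y)
grad_manual = 2.0 / X.shape[0] * X.T @ (X @ w - y)
print(jnp.allclose(grad_auto, grad_manual, atol=1e-4))
```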

We can now train the model while logging the training loss at every epoch. Since each call to train_step uses the entire training set, one step corresponds to one epoch.

[11]:
# Preallocate an array for the per-epoch training losses
num_epochs = 2000
train_losses = jnp.zeros((num_epochs,))

for epoch in range(num_epochs):
    # Perform one (full-batch) optimization step and get the loss
    loss = train_step(model, optimizer, X_train, y_train)

    # Record the loss for this epoch
    train_losses = train_losses.at[epoch].set(loss)
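The loop above performs full-batch training. For larger datasets, one would typically iterate over shuffled mini-batches instead and call train_step once per batch. A minimal sketch of such an iterator follows; iterate_minibatches is a hypothetical helper, and the toy arrays only mimic the shape of X_train and y_train. It drops the last incomplete batch so that every batch has the same shape, which avoids re-tracing the jitted train_step.

```python
import jax
import jax.numpy as jnp

def iterate_minibatches(key, X, y, batch_size):
    """Hypothetical helper: yield shuffled (X_batch, y_batch) pairs of equal size."""
    n = X.shape[0]
    perm = jax.random.permutation(key, n)
    # Stop before an incomplete final batch to keep shapes static under jit
    for start in range(0, n - batch_size + 1, batch_size):
        idx = perm[start:start + batch_size]
        yield X[idx], y[idx]

# Toy data shaped like the training set: row i of X is [2i, 2i+1], target is i
X = jnp.arange(1600.0).reshape(800, 2)
y = jnp.arange(800.0).reshape(800, 1)

key = jax.random.PRNGKey(0)
for epoch in range(2):
    key, subkey = jax.random.split(key)  # fresh shuffle every epoch
    batches = list(iterate_minibatches(subkey, X, y, batch_size=128))
    print(epoch, len(batches), batches[0][0].shape)  # e.g. 0 6 (128, 2)
```

With 800 samples and a batch size of 128, each epoch yields 6 full batches and the remaining 32 samples are skipped for that epoch; since the permutation changes every epoch, different samples are skipped each time.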

Evaluation

First, we plot the training loss to verify that it decreases, i.e., that the optimization works as expected.

[12]:
plt.figure(figsize=(7, 4))

plt.plot(np.array(train_losses), label='Train Loss', marker='o', color='#25599c', markersize=1)

plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training Loss Over Epochs')
plt.yscale('log')

plt.legend()
plt.grid(True, which='both', linestyle='--', linewidth=0.5)

plt.show()
[Figure: training loss over epochs, log-scale y-axis]

Indeed, the loss decreases, confirming that training works as intended. The final step is to evaluate the model's performance on the test set.

[13]:
y_pred = model(X_test)
mse = mean_squared_error(y_test, y_pred)

print(f"The MSE of the fit is {mse:.5f}")
The MSE of the fit is 0.00562
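mean_squared_error computes \(\frac{1}{N}\sum_n (y_n - \hat{y}_n)^2\). The same quantity can be computed directly in JAX, and the coefficient of determination \(R^2 = 1 - SS_{res}/SS_{tot}\) is a useful scale-free complement, since the MSE alone depends on the range of the targets. The sketch below uses small made-up arrays rather than the tutorial's test set.

```python
import jax.numpy as jnp

def mse(y_true, y_pred):
    # Mean of squared residuals
    return jnp.mean((y_true - y_pred) ** 2)

def r2_score(y_true, y_pred):
    # 1 - (residual sum of squares) / (total sum of squares around the mean)
    ss_res = jnp.sum((y_true - y_pred) ** 2)
    ss_tot = jnp.sum((y_true - jnp.mean(y_true)) ** 2)
    return 1.0 - ss_res / ss_tot

y_true = jnp.array([[1.0], [2.0], [3.0], [4.0]])
y_hat = jnp.array([[1.1], [1.9], [3.2], [3.8]])
print(f"MSE = {float(mse(y_true, y_hat)):.4f}, R^2 = {float(r2_score(y_true, y_hat)):.4f}")
# MSE = 0.0250, R^2 = 0.9800
```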
[14]:
plt.figure(figsize=(7, 4))
plt.scatter(y_test, y_pred, alpha=0.7, color='#a3630f', marker='x', label='Predicted vs Actual')
plt.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], color='#25599c', linestyle='--', label='Perfect Fit')
plt.xlabel('Actual Values')
plt.ylabel('Predicted Values')
plt.title('Model Predictions vs Ground Truth')
plt.legend()
plt.grid(alpha=0.3)
plt.show()
[Figure: predicted vs. actual values on the test set, with the perfect-fit line]