Tutorial 11 - Optimizer Utilities

In this tutorial we explore the optimizer utility functions provided in jaxKAN. We cover two optimizers: Adam combined with various learning rate schedules, and L-BFGS, which requires special handling in Flax NNX. Both are useful for training KANs and PIKANs, where the choice of optimizer and learning rate schedule can significantly affect convergence.

[1]:
import os
# Set before importing JAX so the C++ log level is in place when XLA initializes
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

from jaxkan.models.KAN import KAN
from jaxkan.models.utils import get_adam, get_lbfgs

import jax
import jax.numpy as jnp

from sklearn.model_selection import train_test_split

from flax import nnx
import optax

import matplotlib.pyplot as plt
import numpy as np

Data Generation

We use the same function-fitting problem as in Tutorial 1 to compare optimizer configurations. Consider the function \(f(x, y) = x^2 + 2\exp(y)\), which we fit with a KAN model under each optimization strategy.

[2]:
def f(x, y):
    return x**2 + 2*jnp.exp(y)

def generate_data(minval=-1, maxval=1, num_samples=1000, seed=42):
    key = jax.random.PRNGKey(seed)
    x_key, y_key = jax.random.split(key)

    x1 = jax.random.uniform(x_key, shape=(num_samples,), minval=minval, maxval=maxval)
    x2 = jax.random.uniform(y_key, shape=(num_samples,), minval=minval, maxval=maxval)

    y = f(x1, x2).reshape(-1, 1)
    X = jnp.stack([x1, x2], axis=1)

    return X, y

seed = 42

X, y = generate_data(minval=-1, maxval=1, num_samples=1000, seed=seed)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=seed)

print("Training set size:", X_train.shape)
print("Test set size:", X_test.shape)
Training set size: (800, 2)
Test set size: (200, 2)

Part 1: Adam Optimizer

Adam is one of the most widely used optimizers for training neural networks, including KANs. The get_adam function provides a convenient interface for creating Adam optimizers with different learning rate schedules and an optional warmup phase.

Experiment Setup for Adam

We wrap model construction, the training loop, and test evaluation in a single function so that every Adam configuration is trained under identical conditions (same architecture, initialization seed, and data).

[3]:
def run_adam_experiment(adam_config, num_epochs=2000, verbose=True):
    """
    Run a training experiment with Adam optimizer.

    Args:
        adam_config: Dictionary with Adam optimizer parameters
        num_epochs: Number of training epochs
        verbose: Whether to print progress

    Returns:
        train_losses: Array of training losses
        test_loss: Final test loss
    """
    # Initialize a KAN model
    n_in = X_train.shape[1]
    n_out = y_train.shape[1]
    n_hidden = 6

    layer_dims = [n_in, n_hidden, n_hidden, n_out]
    req_params = {'D': 5, 'flavor': 'exact'}

    model = KAN(layer_dims=layer_dims,
                layer_type='chebyshev',
                required_parameters=req_params,
                seed=42)

    # Get Adam optimizer using the utility function
    opt_type = get_adam(**adam_config)
    optimizer = nnx.Optimizer(model, opt_type, wrt=nnx.Param)

    # Define train step
    @nnx.jit
    def train_step(model, optimizer, X_train, y_train):
        def loss_fn(model):
            residual = model(X_train) - y_train
            loss = jnp.mean((residual)**2)
            return loss

        loss, grads = nnx.value_and_grad(loss_fn)(model)
        optimizer.update(model, grads)

        return loss

    # Training loop: collect per-epoch losses in a list and stack them once at the end,
    # instead of copying the whole array with .at[].set on every epoch
    losses = []

    for epoch in range(num_epochs):
        loss = train_step(model, optimizer, X_train, y_train)
        losses.append(loss)

        if verbose and (epoch + 1) % 500 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss:.6f}')

    train_losses = jnp.stack(losses)

    # Evaluate on test set
    y_pred = model(X_test)
    test_loss = jnp.mean((y_pred - y_test)**2)

    if verbose:
        print(f'\nFinal Test Loss: {test_loss:.6f}')

    return train_losses, test_loss

Constant Learning Rate

We begin with the simplest case: a constant learning rate. This is the default behavior when no schedule is specified.

[4]:
config_constant = {
    'learning_rate': 1e-3
}

train_losses_constant, test_loss_constant = run_adam_experiment(config_constant)
Epoch [500/2000], Loss: 0.064715
Epoch [1000/2000], Loss: 0.016274
Epoch [1500/2000], Loss: 0.007237
Epoch [2000/2000], Loss: 0.003581

Final Test Loss: 0.005624

Exponential Decay Schedule

A common strategy is to use exponential decay, where the learning rate decreases exponentially over time. This can help the model converge to a better solution by taking smaller steps as training progresses.

The learning rate at step \(t\) is given by:

\[\text{lr}(t) = \text{lr}_0 \cdot \gamma^{t / T}\]

where \(\text{lr}_0\) is the initial learning rate, \(\gamma\) is the decay rate, and \(T\) is the number of decay steps.

[5]:
config_exp_decay = {
    'learning_rate': 1e-3,
    'schedule_type': 'exponential',
    'decay_steps': 1000,
    'decay_rate': 0.9
}

train_losses_exp, test_loss_exp = run_adam_experiment(config_exp_decay)
Epoch [500/2000], Loss: 0.068539
Epoch [1000/2000], Loss: 0.017559
Epoch [1500/2000], Loss: 0.008215
Epoch [2000/2000], Loss: 0.004304

Final Test Loss: 0.006682

Cosine Annealing Schedule

Cosine annealing decays the learning rate smoothly along a half cosine curve: it starts at the initial learning rate and decreases to zero (or a chosen minimum value) after decay_steps steps. The decay is slow at the beginning and end and fastest in the middle, which can improve convergence in some cases.

[6]:
config_cosine = {
    'learning_rate': 1e-3,
    'schedule_type': 'cosine',
    'decay_steps': 2000
}

train_losses_cosine, test_loss_cosine = run_adam_experiment(config_cosine)
Epoch [500/2000], Loss: 0.072498
Epoch [1000/2000], Loss: 0.022263
Epoch [1500/2000], Loss: 0.015320
Epoch [2000/2000], Loss: 0.014358

Final Test Loss: 0.021979

Warmup Strategy

Warmup is a technique where the learning rate starts at zero and linearly increases to the target learning rate over a specified number of steps. This can help stabilize training in the early stages, especially for complex models or difficult optimization landscapes.

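After the warmup period, the learning rate follows the specified schedule (e.g., exponential decay). Because the learning rate is still ramping up during the first 500 steps, training progresses more slowly at first: the loss at epoch 500 is several times higher than in the runs without warmup.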

[7]:
config_warmup = {
    'learning_rate': 1e-3,
    'schedule_type': 'exponential',
    'decay_steps': 1000,
    'decay_rate': 0.9,
    'warmup_steps': 500
}

train_losses_warmup, test_loss_warmup = run_adam_experiment(config_warmup)
Epoch [500/2000], Loss: 0.432106
Epoch [1000/2000], Loss: 0.046055
Epoch [1500/2000], Loss: 0.018493
Epoch [2000/2000], Loss: 0.008937

Final Test Loss: 0.015691

Part 2: L-BFGS Optimizer

L-BFGS (Limited-memory Broyden-Fletcher-Goldfarb-Shanno) is a quasi-Newton optimization method that can converge faster than first-order methods like Adam for smooth optimization problems. However, it requires special handling in Flax NNX due to its line search mechanism.

Key differences from Adam:

  1. L-BFGS uses a line search to choose the step size at each iteration instead of relying on a fixed learning rate

  2. The optimizer’s update() method requires value and value_fn arguments in addition to the gradients

  3. Flax NNX automatically handles the split/merge operations needed for the line search

Experiment Setup for L-BFGS

The training loop for L-BFGS differs from Adam in how we call the update() method: we must pass the current loss value and a function that evaluates the loss, so that the line search can re-evaluate the loss at trial step sizes.

[8]:
def run_lbfgs_experiment(lbfgs_config, num_epochs=500, verbose=True):
    """
    Run a training experiment with L-BFGS optimizer.

    L-BFGS requires special handling: the update method needs both
    the current loss value and a value_fn to evaluate loss at different points.

    Args:
        lbfgs_config: Dictionary with L-BFGS optimizer parameters
        num_epochs: Number of training epochs
        verbose: Whether to print progress

    Returns:
        train_losses: Array of training losses
        test_loss: Final test loss
    """
    # Initialize a KAN model
    n_in = X_train.shape[1]
    n_out = y_train.shape[1]
    n_hidden = 6

    layer_dims = [n_in, n_hidden, n_hidden, n_out]
    req_params = {'D': 5, 'flavor': 'exact'}

    model = KAN(layer_dims=layer_dims,
                layer_type='chebyshev',
                required_parameters=req_params,
                seed=42)

    # Get L-BFGS optimizer using the utility function
    opt_type = get_lbfgs(**lbfgs_config)
    optimizer = nnx.Optimizer(model, opt_type, wrt=nnx.Param)

    # Define loss function that takes the model
    def loss_fn(model):
        residual = model(X_train) - y_train
        loss = jnp.mean((residual)**2)
        return loss

    # Define train step for L-BFGS
    @nnx.jit
    def train_step_lbfgs(model, optimizer):
        # Compute loss and gradients
        loss, grads = nnx.value_and_grad(loss_fn)(model)

        # Update with L-BFGS
        # IMPORTANT: Must pass value and value_fn for line search
        # Flax NNX automatically handles split/merge internally
        optimizer.update(model, grads, value=loss, value_fn=loss_fn)

        return loss

    # Training loop: collect per-iteration losses in a list and stack them once at the end
    losses = []

    for epoch in range(num_epochs):
        loss = train_step_lbfgs(model, optimizer)
        losses.append(loss)

        if verbose and (epoch + 1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss:.6f}')

    train_losses = jnp.stack(losses)

    # Evaluate on test set
    y_pred = model(X_test)
    test_loss = jnp.mean((y_pred - y_test)**2)

    if verbose:
        print(f'\nFinal Test Loss: {test_loss:.6f}')

    return train_losses, test_loss

L-BFGS with Default Parameters

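Let’s train with L-BFGS using a memory size of 10 (the default memory_size of optax.lbfgs). L-BFGS typically converges in far fewer iterations than Adam, although each iteration is more expensive because the line search may evaluate the loss several times.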

[9]:
config_lbfgs = {
    'memory_size': 10
}

train_losses_lbfgs, test_loss_lbfgs = run_lbfgs_experiment(config_lbfgs, num_epochs=100)
Epoch [100/100], Loss: 0.000343

Final Test Loss: 0.000483

L-BFGS with Larger Memory

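Increasing memory_size lets L-BFGS keep more of its most recent curvature pairs (differences between successive parameter vectors and between successive gradients). A longer history can improve the inverse-Hessian approximation, at the cost of more memory and more work per iteration.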
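
To see what memory_size controls, here is a textbook sketch of the L-BFGS two-loop recursion, which applies the inverse-Hessian approximation built from the stored pairs to a gradient. It is not optax's internal implementation; the helper name lbfgs_direction and the toy quadratic with Hessian \(A = \mathrm{diag}(1, 10, 100)\) are illustrative.

```python
import jax.numpy as jnp


def lbfgs_direction(grad, s_hist, y_hist):
    """Textbook two-loop recursion: approximates H^{-1} @ grad from stored pairs.

    s_hist[i] = x_{i+1} - x_i and y_hist[i] = g_{i+1} - g_i, oldest first.
    len(s_hist) plays the role of memory_size. The search direction is -result.
    """
    q = grad
    rhos = [1.0 / jnp.dot(y, s) for s, y in zip(s_hist, y_hist)]
    alphas = []
    # First loop: newest pair to oldest
    for s, y, rho in reversed(list(zip(s_hist, y_hist, rhos))):
        alpha = rho * jnp.dot(s, q)
        q = q - alpha * y
        alphas.append(alpha)
    # Initial inverse Hessian gamma * I, scaled from the newest pair
    s_new, y_new = s_hist[-1], y_hist[-1]
    r = (jnp.dot(s_new, y_new) / jnp.dot(y_new, y_new)) * q
    # Second loop: oldest pair to newest (alphas were collected newest first)
    for (s, y, rho), alpha in zip(zip(s_hist, y_hist, rhos), reversed(alphas)):
        beta = rho * jnp.dot(y, r)
        r = r + s * (alpha - beta)
    return r


# Quadratic with Hessian A = diag(1, 10, 100); steps taken along the coordinate axes
A = jnp.diag(jnp.array([1.0, 10.0, 100.0]))
s_hist = [jnp.eye(3)[i] for i in range(3)]
y_hist = [A @ s for s in s_hist]  # exact gradient differences for a quadratic

g = jnp.array([1.0, 1.0, 1.0])
print(lbfgs_direction(g, s_hist, y_hist))            # ≈ [1.0, 0.1, 0.01], i.e. A^{-1} g
print(lbfgs_direction(g, s_hist[-1:], y_hist[-1:]))  # ≈ [0.01, 0.01, 0.01] with one pair
```

With all three pairs the recursion recovers the exact Newton direction \(A^{-1} g\) for this quadratic (the axis-aligned steps are conjugate with respect to \(A\)), while keeping only the newest pair gives a much cruder, uniformly scaled direction.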

[10]:
config_lbfgs_large = {
    'memory_size': 20
}

train_losses_lbfgs_large, test_loss_lbfgs_large = run_lbfgs_experiment(config_lbfgs_large, num_epochs=100)
Epoch [100/100], Loss: 0.000259

Final Test Loss: 0.000384