Tutorial 11 - Optimizer Utilities
In this tutorial we will explore the optimizer utility functions provided in jaxKAN. We will cover two optimizers: Adam, combined with various learning rate schedules, and L-BFGS, which requires special handling in Flax NNX. Both are useful for training KANs and PIKANs, where the choice of optimizer and learning rate schedule can strongly affect convergence.
```python
import os
# Hide INFO and WARNING messages from the XLA/TensorFlow C++ logger.
# Set this before importing JAX so it is in place before XLA initializes.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

from jaxkan.models.KAN import KAN
from jaxkan.models.utils import get_adam, get_lbfgs
import jax
import jax.numpy as jnp
from sklearn.model_selection import train_test_split
from flax import nnx
import optax
import matplotlib.pyplot as plt
import numpy as np
```
Data Generation
We will use the same function fitting problem from Tutorial 1 to compare different optimizer configurations. Consider the function \(f(x, y) = x^2 + 2\exp(y)\), which we will fit using a KAN model with different optimization strategies.
```python
def f(x, y):
    return x**2 + 2*jnp.exp(y)

def generate_data(minval=-1, maxval=1, num_samples=1000, seed=42):
    key = jax.random.PRNGKey(seed)
    x_key, y_key = jax.random.split(key)

    x1 = jax.random.uniform(x_key, shape=(num_samples,), minval=minval, maxval=maxval)
    x2 = jax.random.uniform(y_key, shape=(num_samples,), minval=minval, maxval=maxval)

    y = f(x1, x2).reshape(-1, 1)
    X = jnp.stack([x1, x2], axis=1)

    return X, y

seed = 42

X, y = generate_data(minval=-1, maxval=1, num_samples=1000, seed=seed)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=seed)

print("Training set size:", X_train.shape)
print("Test set size:", X_test.shape)
```
Training set size: (800, 2)
Test set size: (200, 2)
Part 1: Adam Optimizer
The Adam optimizer is the most commonly used optimizer for training neural networks, including KANs. The get_adam function provides a convenient interface for creating Adam optimizers with various learning rate schedules and warmup strategies.
Experiment Setup for Adam
We define a function to encapsulate the training loop. This function will allow us to easily experiment with different Adam configurations.
```python
def run_adam_experiment(adam_config, num_epochs=2000, verbose=True):
    """
    Run a training experiment with Adam optimizer.

    Args:
        adam_config: Dictionary with Adam optimizer parameters
        num_epochs: Number of training epochs
        verbose: Whether to print progress

    Returns:
        train_losses: Array of training losses
        test_loss: Final test loss
    """
    # Initialize a KAN model
    n_in = X_train.shape[1]
    n_out = y_train.shape[1]
    n_hidden = 6
    layer_dims = [n_in, n_hidden, n_hidden, n_out]
    req_params = {'D': 5, 'flavor': 'exact'}

    model = KAN(layer_dims=layer_dims,
                layer_type='chebyshev',
                required_parameters=req_params,
                seed=42)

    # Get Adam optimizer using the utility function
    opt_type = get_adam(**adam_config)
    optimizer = nnx.Optimizer(model, opt_type, wrt=nnx.Param)

    # Define train step
    @nnx.jit
    def train_step(model, optimizer, X_train, y_train):
        def loss_fn(model):
            residual = model(X_train) - y_train
            loss = jnp.mean((residual)**2)
            return loss

        loss, grads = nnx.value_and_grad(loss_fn)(model)
        optimizer.update(model, grads)
        return loss

    # Training loop
    train_losses = jnp.zeros((num_epochs,))
    for epoch in range(num_epochs):
        loss = train_step(model, optimizer, X_train, y_train)
        train_losses = train_losses.at[epoch].set(loss)
        if verbose and (epoch + 1) % 500 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss:.6f}')

    # Evaluate on test set
    y_pred = model(X_test)
    test_loss = jnp.mean((y_pred - y_test)**2)
    if verbose:
        print(f'\nFinal Test Loss: {test_loss:.6f}')

    return train_losses, test_loss
```
Constant Learning Rate
We begin with the simplest case: a constant learning rate. This is the default behavior when no schedule is specified.
```python
config_constant = {
    'learning_rate': 1e-3
}

train_losses_constant, test_loss_constant = run_adam_experiment(config_constant)
```
Epoch [500/2000], Loss: 0.064715
Epoch [1000/2000], Loss: 0.016274
Epoch [1500/2000], Loss: 0.007237
Epoch [2000/2000], Loss: 0.003581
Final Test Loss: 0.005624
Exponential Decay Schedule
A common strategy is to use exponential decay, where the learning rate decreases exponentially over time. This can help the model converge to a better solution by taking smaller steps as training progresses.
The learning rate at step \(t\) is given by:
\[\text{lr}(t) = \text{lr}_0 \, \gamma^{t/T}\]
where \(\text{lr}_0\) is the initial learning rate, \(\gamma\) is the decay rate, and \(T\) is the number of decay steps.
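To get a feel for how strong this decay is, the short sketch below evaluates the formula for the configuration used in this section (\(\text{lr}_0 = 10^{-3}\), \(\gamma = 0.9\), \(T = 1000\)). It assumes the continuous (non-staircase) form written above. Over the 2000 training epochs the learning rate only falls to \(0.9^2 = 81\%\) of its initial value. This mild decay is consistent with the losses staying close to those of the constant-rate run.

```python
def exponential_lr(t, lr0=1e-3, gamma=0.9, decay_steps=1000):
    """lr(t) = lr0 * gamma ** (t / decay_steps), continuous (non-staircase) decay."""
    return lr0 * gamma ** (t / decay_steps)


for t in (0, 500, 1000, 2000):
    print(f"step {t:4d}: lr = {exponential_lr(t):.3e}")
```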
```python
config_exp_decay = {
    'learning_rate': 1e-3,
    'schedule_type': 'exponential',
    'decay_steps': 1000,
    'decay_rate': 0.9
}

train_losses_exp, test_loss_exp = run_adam_experiment(config_exp_decay)
```
Epoch [500/2000], Loss: 0.068539
Epoch [1000/2000], Loss: 0.017559
Epoch [1500/2000], Loss: 0.008215
Epoch [2000/2000], Loss: 0.004304
Final Test Loss: 0.006682
Cosine Annealing Schedule
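Cosine annealing lowers the learning rate from its initial value to zero (or to a minimum value) over decay_steps steps along half a cosine period. The rate therefore decays slowly at first, fastest in the middle, and slowly again near the end. It can improve convergence in some settings. In this run, however, it finishes with a higher loss than the constant learning rate, and the loss barely changes after epoch 1500.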
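A common form of cosine annealing, and the one implemented by optax.cosine_decay_schedule, is \(\text{lr}(t) = \text{lr}_0\left[(1-\alpha)\tfrac{1}{2}\left(1 + \cos\left(\pi \min(t, T)/T\right)\right) + \alpha\right]\). Here \(\alpha\) is the floor expressed as a fraction of \(\text{lr}_0\). This tutorial does not show whether get_adam exposes such a floor, so the sketch below treats \(\alpha\) as an assumption and defaults it to 0. With decay_steps = 2000 (equal to the number of epochs), the rate is below 15% of its initial value for the final 500 epochs.

```python
import math


def cosine_lr(t, lr0=1e-3, decay_steps=2000, alpha=0.0):
    """Cosine annealing from lr0 down to alpha * lr0 over decay_steps steps."""
    # Clip progress so the rate stays at its floor after decay_steps
    progress = min(t, decay_steps) / decay_steps
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return lr0 * ((1.0 - alpha) * cosine + alpha)


for t in (0, 500, 1000, 1500, 2000):
    print(f"step {t:4d}: lr = {cosine_lr(t):.3e}")
```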
```python
config_cosine = {
    'learning_rate': 1e-3,
    'schedule_type': 'cosine',
    'decay_steps': 2000
}

train_losses_cosine, test_loss_cosine = run_adam_experiment(config_cosine)
```
Epoch [500/2000], Loss: 0.072498
Epoch [1000/2000], Loss: 0.022263
Epoch [1500/2000], Loss: 0.015320
Epoch [2000/2000], Loss: 0.014358
Final Test Loss: 0.021979
Warmup Strategy
Warmup is a technique where the learning rate starts at zero and linearly increases to the target learning rate over a specified number of steps. This can help stabilize training in the early stages, especially for complex models or difficult optimization landscapes.
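After the warmup period, the learning rate follows the specified schedule (e.g., exponential decay). On this small problem, warmup mainly delays progress. The loss at epoch 500 is much higher than in the other runs, and the final test loss is higher than with exponential decay alone.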
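The combined schedule can be sketched as a piecewise function. The sketch assumes a linear ramp from zero and a decay clock that restarts at the end of warmup. That is how optax.warmup_exponential_decay_schedule behaves; whether get_adam uses it is not shown here. During the first 500 epochs the average learning rate is only half the peak, which fits the high loss at epoch 500 in this run.

```python
def warmup_exponential_lr(t, lr0=1e-3, warmup_steps=500, gamma=0.9, decay_steps=1000):
    """Linear warmup from 0 to lr0, then exponential decay restarted after warmup."""
    if t < warmup_steps:
        return lr0 * t / warmup_steps  # linear ramp
    return lr0 * gamma ** ((t - warmup_steps) / decay_steps)


for t in (0, 250, 500, 1000, 2000):
    print(f"step {t:4d}: lr = {warmup_exponential_lr(t):.3e}")
```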
```python
config_warmup = {
    'learning_rate': 1e-3,
    'schedule_type': 'exponential',
    'decay_steps': 1000,
    'decay_rate': 0.9,
    'warmup_steps': 500
}

train_losses_warmup, test_loss_warmup = run_adam_experiment(config_warmup)
```
Epoch [500/2000], Loss: 0.432106
Epoch [1000/2000], Loss: 0.046055
Epoch [1500/2000], Loss: 0.018493
Epoch [2000/2000], Loss: 0.008937
Final Test Loss: 0.015691
Part 2: L-BFGS Optimizer
L-BFGS (Limited-memory Broyden-Fletcher-Goldfarb-Shanno) is a quasi-Newton optimization method that can converge faster than first-order methods like Adam for smooth optimization problems. However, it requires special handling in Flax NNX due to its line search mechanism.
Key differences from Adam:
- L-BFGS uses a line search to choose the step size along each search direction, instead of a fixed or scheduled learning rate.
- The optimizer’s update() method requires value and value_fn arguments.
- Flax NNX automatically handles the split/merge operations needed for the line search.
Experiment Setup for L-BFGS
The training loop for L-BFGS differs from Adam in how we call the update() method. We must provide the current loss value and a function to evaluate the loss.
```python
def run_lbfgs_experiment(lbfgs_config, num_epochs=500, verbose=True):
    """
    Run a training experiment with L-BFGS optimizer.

    L-BFGS requires special handling: the update method needs both
    the current loss value and a value_fn to evaluate loss at different points.

    Args:
        lbfgs_config: Dictionary with L-BFGS optimizer parameters
        num_epochs: Number of training epochs
        verbose: Whether to print progress

    Returns:
        train_losses: Array of training losses
        test_loss: Final test loss
    """
    # Initialize a KAN model
    n_in = X_train.shape[1]
    n_out = y_train.shape[1]
    n_hidden = 6
    layer_dims = [n_in, n_hidden, n_hidden, n_out]
    req_params = {'D': 5, 'flavor': 'exact'}

    model = KAN(layer_dims=layer_dims,
                layer_type='chebyshev',
                required_parameters=req_params,
                seed=42)

    # Get L-BFGS optimizer using the utility function
    opt_type = get_lbfgs(**lbfgs_config)
    optimizer = nnx.Optimizer(model, opt_type, wrt=nnx.Param)

    # Define loss function that takes the model
    def loss_fn(model):
        residual = model(X_train) - y_train
        loss = jnp.mean((residual)**2)
        return loss

    # Define train step for L-BFGS
    @nnx.jit
    def train_step_lbfgs(model, optimizer):
        # Compute loss and gradients
        loss, grads = nnx.value_and_grad(loss_fn)(model)

        # Update with L-BFGS
        # IMPORTANT: Must pass value and value_fn for line search
        # Flax NNX automatically handles split/merge internally
        optimizer.update(model, grads, value=loss, value_fn=loss_fn)
        return loss

    # Training loop
    train_losses = jnp.zeros((num_epochs,))
    for epoch in range(num_epochs):
        loss = train_step_lbfgs(model, optimizer)
        train_losses = train_losses.at[epoch].set(loss)
        if verbose and (epoch + 1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss:.6f}')

    # Evaluate on test set
    y_pred = model(X_test)
    test_loss = jnp.mean((y_pred - y_test)**2)
    if verbose:
        print(f'\nFinal Test Loss: {test_loss:.6f}')

    return train_losses, test_loss
```
L-BFGS with Default Parameters
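Let’s train with L-BFGS, setting only the memory size and leaving the other parameters at their defaults. L-BFGS typically needs far fewer iterations than Adam. Here, 100 iterations reach a lower test loss than 2000 epochs of Adam with any of the schedules above. Keep in mind, though, that each L-BFGS iteration can evaluate the loss several times during the line search, so one iteration costs more than one Adam step.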
```python
config_lbfgs = {
    'memory_size': 10
}

train_losses_lbfgs, test_loss_lbfgs = run_lbfgs_experiment(config_lbfgs, num_epochs=100)
```
Epoch [100/100], Loss: 0.000343
Final Test Loss: 0.000483
L-BFGS with Larger Memory
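Increasing the memory size lets L-BFGS keep more of its most recent curvature pairs (parameter differences and the corresponding gradient differences). This can improve its approximation of the inverse Hessian, at the cost of more memory and more work per iteration. Here, doubling the memory from 10 to 20 gives a modest improvement in both the training and test loss.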
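To make "memory" concrete, here is a sketch of the textbook L-BFGS two-loop recursion in NumPy. It turns the stored pairs \(s_i = x_{i+1} - x_i\) and \(y_i = \nabla f(x_{i+1}) - \nabla f(x_i)\) into a search direction that approximates \(-H^{-1}\nabla f\). It illustrates the standard algorithm rather than the code inside get_lbfgs or optax, and the diagonal quadratic test problem is an illustrative assumption. The memory size is simply how many of the most recent pairs are passed in.

```python
import numpy as np


def lbfgs_direction(grad, s_hist, y_hist):
    """Two-loop recursion: approximate -H^{-1} @ grad from (s, y) pairs stored oldest first."""
    rhos = [1.0 / np.dot(y, s) for s, y in zip(s_hist, y_hist)]
    q = grad.copy()
    alphas = []
    # First loop: newest pair to oldest
    for s, y, rho in reversed(list(zip(s_hist, y_hist, rhos))):
        alpha = rho * np.dot(s, q)
        q = q - alpha * y
        alphas.append(alpha)
    # Initial inverse-Hessian guess gamma * I, scaled from the newest pair
    if s_hist:
        gamma = np.dot(s_hist[-1], y_hist[-1]) / np.dot(y_hist[-1], y_hist[-1])
    else:
        gamma = 1.0
    r = gamma * q
    # Second loop: oldest pair to newest (alphas were collected newest first)
    for (s, y, rho), alpha in zip(zip(s_hist, y_hist, rhos), reversed(alphas)):
        beta = rho * np.dot(y, r)
        r = r + (alpha - beta) * s
    return -r


# Illustrative quadratic f(x) = 0.5 * x^T A x with A = diag(2, 5, 10),
# whose exact Newton direction is -A^{-1} g.
A = np.diag([2.0, 5.0, 10.0])
s_hist = [np.eye(3)[i] for i in range(3)]  # three past steps along the coordinate axes
y_hist = [A @ s for s in s_hist]           # matching gradient differences (y = A s)
g = np.ones(3)

full = lbfgs_direction(g, s_hist, y_hist)             # memory of 3: all pairs
short = lbfgs_direction(g, s_hist[-1:], y_hist[-1:])  # memory of 1: newest pair only
newton = -np.linalg.solve(A, g)
print("memory 3:", full, " memory 1:", short, " Newton:", newton)
```

With all three pairs, the recursion reproduces the Newton direction exactly in this example, because the stored steps are A-conjugate and span the space. Keeping only the newest pair gives a much cruder direction. Real problems are not quadratic, so a longer memory only improves the approximation gradually, as in the comparison above.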
```python
config_lbfgs_large = {
    'memory_size': 20
}

train_losses_lbfgs_large, test_loss_lbfgs_large = run_lbfgs_experiment(config_lbfgs_large, num_epochs=100)
```
Epoch [100/100], Loss: 0.000259
Final Test Loss: 0.000384