Tutorial 8 - Adaptive PIKANs

In Tutorial 6 we got a taste of solving PDEs using KANs and in Tutorial 7 we started exploring adaptive training techniques. Building on this adaptive training idea, in this tutorial we will see how to adaptively train PIKANs, based on the findings of this and this paper.

[1]:
import os
# Silence TensorFlow/XLA C++ log messages (set before anything imports JAX)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

from jaxkan.models.KAN import KAN

import jax
import jax.numpy as jnp

from jaxkan.pikan.pde import get_ac_res
from jaxkan.pikan.sampling import get_collocs_grid
from jaxkan.pikan.adaptive import (
    apply_rba_weights,
    get_causal_weights,
    get_colloc_indices,
    get_rad_indices,
    lr_anneal,
    update_rba_weights,
)

from typing import Union, List

from flax import nnx
import optax

import matplotlib.pyplot as plt
import numpy as np

Data Generation

Burgers’ Equation was relatively easy to solve even without adaptive techniques, so in this case we will be solving the Allen-Cahn Equation,

\[\frac{\partial u}{\partial t} - D\frac{\partial^2 u}{\partial x^2} + 5 \left(u^3 - u\right) = 0,\]

for \(D = 10^{-4}\) in the domain \(\Omega = [0,1]\times [-1, 1]\) of \((t, x)\), subject to the initial condition

\[u\left(t=0, x\right) = x^2 \cos\left(\pi x\right)\]

and the boundary conditions

\[u\left(t, x=-1\right) = u\left(t, x=1\right) = -1.\]

First, we define the collocation points. This time we create a large pool of collocation points, from which training batches are drawn via residual-based adaptive distribution (RAD) resampling, so that points with a large PDE residual are sampled more often (see the two referenced papers for more information).
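To illustrate what RAD does, the NumPy sketch below draws a batch from a pool using the sampling density of Wu et al., \(p(\mathbf{x}) \propto |r(\mathbf{x})|^k / \mathbb{E}[|r|^k] + c\), where \(r\) is the PDE residual. The function `rad_sample` and its parameters `k` and `c` are hypothetical names; we assume jaxkan's `rad_a` and `rad_c` play similar roles, but `get_rad_indices` also carries the RBA weights along and may differ in its details.

```python
import numpy as np

def rad_sample(n_pool, residuals, batch_size, k=1.0, c=1.0, seed=42):
    """Hypothetical RAD sampler: pick batch_size distinct pool indices with
    probability p(x) proportional to |r(x)|^k / mean(|r|^k) + c."""
    err = np.abs(np.ravel(residuals)) ** k
    density = err / (err.mean() + 1e-12) + c  # c > 0 keeps every point reachable
    p = density / density.sum()
    rng = np.random.default_rng(seed)
    idx = rng.choice(n_pool, size=batch_size, replace=False, p=p)
    return np.sort(idx)  # keep the pool's original ordering within the batch

# Toy pool of 1000 points: the first half has much larger residuals
residuals = np.concatenate([np.full(500, 10.0), np.full(500, 0.1)])
idx = rad_sample(1000, residuals, batch_size=200)
print(len(np.unique(idx)), (idx < 500).sum())  # most samples come from the first half
```

Sorting the indices matters here because, later on, the batch is split into time chunks for causal training, which presumes the points keep the pool's ordering in \(t\).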

[2]:
seed = 42

# Generate Collocation points for PDE
collocs_pool = get_collocs_grid(ranges=[(0, 1, 2**7), (-1, 1, 2**7)])

# Generate Collocation points for IC
ic_collocs = get_collocs_grid(ranges=[(0, 0, 1), (-1, 1, 2**6)])
ic_data = ((ic_collocs[:,1]**2)*jnp.cos(jnp.pi*ic_collocs[:,1])).reshape(-1,1)
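For intuition about the shapes involved, here is a NumPy stand-in for `get_collocs_grid`, assuming (not confirmed from the jaxkan source) that it forms the Cartesian product of `linspace(start, stop, n)` along each axis. Under that assumption the pool holds \(2^7 \cdot 2^7 = 16384\) points with \(t\) varying slowest, and the snippet also checks that the initial condition is compatible with the boundary conditions, since \((\pm 1)^2\cos(\pm\pi) = -1\).

```python
import numpy as np

def collocs_grid(ranges):
    """Hypothetical stand-in for get_collocs_grid: Cartesian product of
    np.linspace(start, stop, n) per axis, returned as an (N, dim) array."""
    axes = [np.linspace(start, stop, n) for start, stop, n in ranges]
    mesh = np.meshgrid(*axes, indexing="ij")  # first axis (t) varies slowest
    return np.stack([m.ravel() for m in mesh], axis=1)

pool = collocs_grid([(0, 1, 2**7), (-1, 1, 2**7)])
ic = collocs_grid([(0, 0, 1), (-1, 1, 2**6)])
ic_vals = ic[:, 1]**2 * np.cos(np.pi * ic[:, 1])

print(pool.shape, ic.shape)     # (16384, 2) (64, 2)
print(ic_vals[0], ic_vals[-1])  # both -1 (up to rounding): IC matches the BC at x = ±1
```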

Custom KAN Model

As seen above, we have not defined any collocation points for the boundary conditions. This is because we intend to use what we learned in “DIY KANs” to define our own custom KAN model, which will directly enforce the boundary conditions through its architecture (see this paper for more details). In particular, we will define a wrapper class that inherits from KAN but adds an additional step in its forward pass.

[3]:
class KANWrapper(KAN):

    def __init__(self, layer_dims: List[int], layer_type: str = "base",
                 required_parameters: Union[None, dict] = None, seed: int = 42
                 ):

        # The underlying (unconstrained) KAN
        self.model = KAN(layer_dims, layer_type, required_parameters, seed)


    def __call__(self, x):

        y = self.model(x)

        # Impose BC u(t, -1) = u(t, 1) = -1:
        # at x = -1 or x = 1 the factor (1 - x**2) vanishes, cancelling the
        # model's output, and the -1 term gives u = -1, as required
        x_coord = x[:, 1:2]
        y = (1 - x_coord**2) * y - 1.0

        return y
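A quick way to convince yourself that the wrapper enforces the boundary conditions exactly, whatever the network outputs, is to apply the same transform to an arbitrary function. In this NumPy sketch, `raw_net` is a hypothetical stand-in for the KAN.

```python
import numpy as np

def raw_net(tx):
    # Hypothetical stand-in for the unconstrained KAN output N(t, x)
    t, x = tx[:, 0:1], tx[:, 1:2]
    return np.sin(3 * t) + x**3 + 0.5

def constrained(tx):
    # Same transform as KANWrapper.__call__: u = (1 - x^2) * N(t, x) - 1
    x = tx[:, 1:2]
    return (1 - x**2) * raw_net(tx) - 1.0

t = np.linspace(0, 1, 11)
left = np.stack([t, -np.ones_like(t)], axis=1)   # points (t, -1)
right = np.stack([t, np.ones_like(t)], axis=1)   # points (t, +1)
print(np.allclose(constrained(left), -1.0), np.allclose(constrained(right), -1.0))  # True True
```

Since the factor \(1 - x^2\) vanishes only at \(x = \pm 1\), the interior of the domain is left free for the network to fit.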
[4]:
# Initialize a KAN model
n_in = collocs_pool.shape[1]
n_out = 1
n_hidden = 12

layer_dims = [n_in, n_hidden, n_hidden, n_hidden, n_out]
req_params = {'D': 5, 'flavor': 'exact', 'residual': None, 'external_weights': False, 'init_scheme': {'type': 'glorot_fine'}, 'add_bias': True}

model = KANWrapper(layer_dims=layer_dims,
                   layer_type='chebyshev',
                   required_parameters=req_params,
                   seed=seed)

[5]:
# Adam optimizer with an exponentially decaying learning rate
lr_schedule = optax.exponential_decay(
    init_value=1e-3,
    transition_steps=1000,
    decay_rate=0.9,
    staircase=False
)

opt_type = optax.adam(learning_rate=lr_schedule, b1=0.9, b2=0.999, eps=1e-8)

optimizer = nnx.Optimizer(model, opt_type, wrt=nnx.Param)
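With `staircase=False`, optax's `exponential_decay` computes \(\eta(s) = \eta_0 \cdot r^{s/T}\) at step \(s\), with \(\eta_0\) = `init_value`, \(r\) = `decay_rate` and \(T\) = `transition_steps`. The plain-Python sketch below (`exp_decay` is our own helper, not an optax function) evaluates this formula and shows that, with one optimizer update per epoch, the learning rate falls from \(10^{-3}\) to about \(1.2\times 10^{-4}\) over the 20,000 epochs used later.

```python
def exp_decay(step, init_value=1e-3, transition_steps=1000, decay_rate=0.9):
    # Continuous (staircase=False) exponential decay, as configured above
    return init_value * decay_rate ** (step / transition_steps)

for step in (0, 1_000, 10_000, 20_000):
    print(f"step {step}: lr = {exp_decay(step):.3e}")
# step 0: lr = 1.000e-03
# step 1000: lr = 9.000e-04
# step 10000: lr = 3.487e-04
# step 20000: lr = 1.216e-04
```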

Adaptive Training

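Unlike Tutorial 6, where we trained the network with a simple MSE loss, here we combine several adaptive training methods: residual-based attention (RBA), which gives each collocation point a weight that grows with its residual; RAD resampling; learning-rate annealing, which periodically rebalances the global weights of the PDE and initial-condition losses; and causal training, which encourages earlier times to be fitted before later ones (for an in-depth look, read Training Deep Physics-Informed Kolmogorov–Arnold Networks).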
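To make the RBA idea concrete: each collocation point carries a weight \(\lambda_i\) updated as \(\lambda_i \leftarrow \gamma \lambda_i + \eta\, |r_i| / \max_j |r_j|\), and the loss uses \(\lambda_i r_i\) with \(\lambda_i\) treated as a constant. The NumPy sketch below follows this update rule from the original RBA formulation; we assume jaxkan's `update_rba_weights` does something equivalent, but have not checked its exact normalization. With the values used later in this tutorial (\(\gamma = 0.999\), \(\eta = 0.01\)), the weights converge to \(\eta/(1-\gamma) \cdot |r_i|/\max_j|r_j|\), so the largest one approaches 10.

```python
import numpy as np

def rba_update(residuals, weights, gamma=0.999, eta=0.01):
    # Residual-based attention: decay the old weights, add the normalized |residual|
    r = np.abs(residuals)
    return gamma * weights + eta * r / r.max()

r = np.array([[0.1], [0.5], [2.0]])  # fixed toy residuals, shape (batch, 1)
lam = np.ones_like(r)                # weights start at 1, as in this tutorial
for _ in range(20_000):
    lam = rba_update(r, lam)

# Fixed point: eta / (1 - gamma) * |r| / max|r| = 10 * [0.05, 0.25, 1.0]
print(lam.ravel())  # approximately [0.5, 2.5, 10.0]
```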

[6]:
@nnx.jit
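def get_rad_collocs(model, pde_collocs_pool, sorted_indices, l_pde, l_pde_pool):
    # Evaluate the PDE residual on the whole pool and draw a new batch via RAD.
    # Returns the new batch indices and the updated pool of RBA weights.
    # pde_res, batch_size, rad_a, rad_c and seed are globals defined in the next cells.
    resids_pool = pde_res(model, pde_collocs_pool)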
    new_indices, new_pool, _ = get_rad_indices(
        collocs_pool=pde_collocs_pool,
        residuals=resids_pool,
        old_indices=sorted_indices,
        batch_weights=l_pde,
        pool_weights=l_pde_pool,
        batch_size=batch_size,
        rad_a=rad_a,
        rad_c=rad_c,
        seed=seed,
    )
    return new_indices, new_pool
[7]:
# PDE Residual
pde_res = get_ac_res()

# PDE Loss
def pde_loss(model, l_E, collocs):

    residuals = pde_res(model, collocs) # shape (batch_size, 1)

    # Get new RBA weights
    l_E_new = update_rba_weights(residuals, l_E, gamma=RBA_gamma, eta=RBA_eta)

    # Multiply by RBA weights while keeping them out of the backward graph
    w_resids = apply_rba_weights(residuals, l_E_new)

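    # Split the residuals into num_chunks time windows for causal training
    # (this assumes the batch is ordered in t, so each row is one time window)
    chunked = w_resids.reshape(num_chunks, -1) # shape (num_chunks, points)

    # Get average loss per chunk
    loss = jnp.mean(chunked**2, axis=1)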

    # Get causal weights
    weights = get_causal_weights(loss, M, causal_tol)

    # Weighted loss
    weighted_loss = jnp.mean(weights * loss)

    return weighted_loss, l_E_new


def ic_loss(model, l_I, ic_collocs, ic_data):

    # Residual
    ic_res = model(ic_collocs) - ic_data

    # Get new RBA weights
    l_I_new = update_rba_weights(ic_res, l_I, gamma=RBA_gamma, eta=RBA_eta)

    # Multiply by RBA weights while keeping them out of the backward graph
    w_resids = apply_rba_weights(ic_res, l_I_new)

    # Loss
    loss = jnp.mean(w_resids**2)

    return loss, l_I_new


@nnx.jit(static_argnames=("compute_grads_sep",))
def train_step(model, optimizer, collocs, ic_collocs, ic_data, λ_E, λ_I, l_E, l_I, compute_grads_sep=False):

    def total_loss_fn(model):
        (loss_E, l_E_new) = pde_loss(model, l_E, collocs)
        (loss_I, l_I_new) = ic_loss(model, l_I, ic_collocs, ic_data)
        total = λ_E * loss_E + λ_I * loss_I
        return total, (loss_E, loss_I, l_E_new, l_I_new)

    (loss, aux), grads = nnx.value_and_grad(total_loss_fn, has_aux=True)(model)

    sep_grads = (None, None)
    if compute_grads_sep:
        grads_E = nnx.grad(lambda m: pde_loss(m, l_E, collocs)[0])(model)
        grads_I = nnx.grad(lambda m: ic_loss(m, l_I, ic_collocs, ic_data)[0])(model)
        sep_grads = (grads_E, grads_I)

    optimizer.update(model, grads)

    return loss, aux, sep_grads
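The causal weights follow the causal training scheme of Wang et al.: chunk \(i\) gets weight \(w_i = \exp\left(-\epsilon \sum_{k<i} \mathcal{L}_k\right)\), where \(\epsilon\) is `causal_tol` and \(\mathcal{L}_k\) is the mean loss of chunk \(k\), so later time windows only start to count once the earlier ones are fitted. The NumPy sketch below reproduces this with the same strictly lower-triangular matrix `M` that is built in the next cell; we assume (without having checked the source) that jaxkan's `get_causal_weights` computes the same quantity and stops gradients through it.

```python
import numpy as np

def causal_weights(chunk_losses, tol=1.0):
    # w_i = exp(-tol * sum_{k<i} L_k): a later chunk is down-weighted
    # until all earlier chunks have a small loss
    n = chunk_losses.shape[0]
    M = np.triu(np.ones((n, n)), k=1).T  # strictly lower-triangular
    return np.exp(-tol * (M @ chunk_losses))

losses = np.array([0.5, 0.2, 0.1, 0.05])  # toy per-chunk mean losses
w = causal_weights(losses)
print(w)  # first weight is 1, the rest decay: exp(-[0, 0.5, 0.7, 0.8])
```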
[8]:
num_epochs = 20_000

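# Define causal training parameters
causal_tol = 1.0
num_chunks = 32
# Strictly lower-triangular matrix: (M @ loss)[i] sums the losses of all chunks before chunk i
M = jnp.triu(jnp.ones((num_chunks, num_chunks)), k=1).T

# Define LR annealing parameters
grad_mixing = 0.9
f_grad_norm = 1000  # anneal every f_grad_norm epochs

# Define resampling parameters
batch_size = 2**12  # must be divisible by num_chunks
f_resample = 2000   # perform RAD every f_resample epochs
rad_a = 1.0
rad_c = 1.0

# Define RBA parameters (decay gamma and learning rate eta)
RBA_gamma = 0.999
RBA_eta = 0.01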
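The global weights \(\lambda_E\) and \(\lambda_I\) are rebalanced every `f_grad_norm` epochs from the norms of the separate loss gradients. A common rule, from Wang et al.'s work on gradient-based loss balancing for PINNs, sets \(\hat\lambda_i = \sum_j \|\nabla \mathcal{L}_j\| / \|\nabla \mathcal{L}_i\|\) and smooths it with a moving average \(\lambda_i \leftarrow \alpha\lambda_i + (1-\alpha)\hat\lambda_i\). The sketch below implements that rule with gradients given as lists of NumPy arrays; that `grad_mixing` plays the role of \(\alpha\) in jaxkan's `lr_anneal` is our assumption.

```python
import numpy as np

def global_norm(grads):
    # L2 norm over all parameter arrays (a flattened gradient pytree)
    return np.sqrt(sum(np.sum(g**2) for g in grads))

def anneal(grads_per_loss, lambdas, alpha=0.9):
    # Loss terms with smaller gradients receive larger global weights
    norms = [global_norm(g) for g in grads_per_loss]
    total = sum(norms)
    return [alpha * lam + (1 - alpha) * total / n for lam, n in zip(lambdas, norms)]

grads_E = [np.array([3.0, 4.0])]   # toy PDE-loss gradient, norm 5
grads_I = [np.array([0.6, 0.8])]   # toy IC-loss gradient, norm 1
lam_E, lam_I = anneal([grads_E, grads_I], [1.0, 1.0])
print(lam_E, lam_I)  # approximately 1.02 and 1.5
```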
[9]:
# Initialize RBA weights - full pool
l_E_pool = jnp.ones((collocs_pool.shape[0], 1))
# Also get RBAs for ICs
l_I = jnp.ones((ic_collocs.shape[0], 1))

# Get starting collocation points & RBA weights
sorted_indices = get_colloc_indices(collocs_pool=collocs_pool, batch_size=batch_size, px=None, seed=seed)

pde_collocs = collocs_pool[sorted_indices]
l_E = l_E_pool[sorted_indices]

# Define global loss weights (initialization)
λ_E = jnp.array(1.0, dtype=float)
λ_I = jnp.array(1.0, dtype=float)

Following this setup, we proceed to train the model.

[10]:
train_losses = jnp.zeros((num_epochs,))

# Start training
for epoch in range(num_epochs):

    do_anneal = (epoch != 0) and (epoch % f_grad_norm == 0)

    loss, aux, sep_grads = train_step(
        model, optimizer, pde_collocs, ic_collocs, ic_data, λ_E, λ_I, l_E, l_I,
        compute_grads_sep=do_anneal
    )

    loss_E, loss_I, l_E, l_I = aux

    # Perform lr annealing
    if do_anneal:

        print(f"Epoch No. {epoch}. Current loss: {loss:.2e}. Performing learning-rate annealing.")

        λ_E, λ_I = lr_anneal((sep_grads[0], sep_grads[1]), (λ_E, λ_I), grad_mixing)

    # Perform RAD
    if (epoch != 0) and (epoch % f_resample == 0):

        print(f"Epoch No. {epoch}. Current loss: {loss:.2e}. Performing RAD resampling.")

        sorted_indices, l_E_pool = get_rad_collocs(
            model, collocs_pool, sorted_indices, l_E, l_E_pool
        )

        # Set new batch of collocs and l_E
        pde_collocs = collocs_pool[sorted_indices]
        l_E = l_E_pool[sorted_indices]

    # Record the loss for this epoch
    train_losses = train_losses.at[epoch].set(loss)
Epoch No. 1000. Current loss: 1.80e-02. Performing learning-rate annealing.
Epoch No. 2000. Current loss: 3.39e-02. Performing learning-rate annealing.
Epoch No. 2000. Current loss: 3.39e-02. Performing RAD resampling.
Epoch No. 3000. Current loss: 4.19e-02. Performing learning-rate annealing.
Epoch No. 4000. Current loss: 4.89e-02. Performing learning-rate annealing.
Epoch No. 4000. Current loss: 4.89e-02. Performing RAD resampling.
Epoch No. 5000. Current loss: 4.17e-02. Performing learning-rate annealing.
Epoch No. 6000. Current loss: 4.21e-03. Performing learning-rate annealing.
Epoch No. 6000. Current loss: 4.21e-03. Performing RAD resampling.
Epoch No. 7000. Current loss: 3.18e-03. Performing learning-rate annealing.
Epoch No. 8000. Current loss: 3.07e-03. Performing learning-rate annealing.
Epoch No. 8000. Current loss: 3.07e-03. Performing RAD resampling.
Epoch No. 9000. Current loss: 2.66e-03. Performing learning-rate annealing.
Epoch No. 10000. Current loss: 2.53e-03. Performing learning-rate annealing.
Epoch No. 10000. Current loss: 2.53e-03. Performing RAD resampling.
Epoch No. 11000. Current loss: 2.85e-03. Performing learning-rate annealing.
Epoch No. 12000. Current loss: 2.42e-03. Performing learning-rate annealing.
Epoch No. 12000. Current loss: 2.42e-03. Performing RAD resampling.
Epoch No. 13000. Current loss: 2.58e-03. Performing learning-rate annealing.
Epoch No. 14000. Current loss: 2.26e-03. Performing learning-rate annealing.
Epoch No. 14000. Current loss: 2.26e-03. Performing RAD resampling.
Epoch No. 15000. Current loss: 1.80e-03. Performing learning-rate annealing.
Epoch No. 16000. Current loss: 1.40e-03. Performing learning-rate annealing.
Epoch No. 16000. Current loss: 1.40e-03. Performing RAD resampling.
Epoch No. 17000. Current loss: 1.26e-03. Performing learning-rate annealing.
Epoch No. 18000. Current loss: 1.15e-03. Performing learning-rate annealing.
Epoch No. 18000. Current loss: 1.15e-03. Performing RAD resampling.
Epoch No. 19000. Current loss: 1.02e-03. Performing learning-rate annealing.

Evaluation

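By visualizing the training loss curve, we see how the adaptive training techniques lead to a small training loss by the end of training: after an initial phase in which it rises to about \(5\times 10^{-2}\) (epoch 4000), the loss drops to roughly \(10^{-3}\) by epoch 19000. Note that this is the weighted loss (including the RBA, causal and global weights), so its value is not directly comparable to the plain MSE loss of Tutorial 6.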

[11]:
plt.figure(figsize=(7, 4))

plt.plot(np.array(train_losses), label='Train Loss', marker='o', color='#25599c', markersize=1)

plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training Loss Over Epochs')
plt.yscale('log')

plt.legend()
plt.grid(True, which='both', linestyle='--', linewidth=0.5)

plt.show()
[Figure: Training Loss Over Epochs (loss on a log scale vs. epochs)]

Additionally, we obtain a good approximation of the solution of the Allen-Cahn equation, which cannot be achieved without any of these adaptive techniques.

[12]:
N_t, N_x = 100, 256

t = np.linspace(0.0, 1.0, N_t)
x = np.linspace(-1.0, 1.0, N_x)
T, X = np.meshgrid(t, x, indexing='ij')
coords = np.stack([T.flatten(), X.flatten()], axis=1)

output = model(jnp.array(coords))
resplot = np.array(output).reshape(N_t, N_x)

plt.figure(figsize=(7, 4))
plt.pcolormesh(T, X, resplot, shading='auto', cmap='Spectral_r')
plt.colorbar()

plt.title('Solution of Allen-Cahn Equation')
plt.xlabel('t')

plt.ylabel('x')

plt.tight_layout()
plt.show()
[Figure: Solution of Allen-Cahn Equation (t on the horizontal axis, x on the vertical axis, color = u)]
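The agreement with the true solution can be quantified with the relative \(L^2\) error, \(\|u_\theta - u_{\text{ref}}\|_2 / \|u_{\text{ref}}\|_2\), if a reference solution on the same \((t, x)\) grid is available (for example from a high-accuracy numerical solver; no such reference is provided in this tutorial). A minimal sketch, where `rel_l2_error` is a hypothetical helper and a reference array `ref` of shape `(N_t, N_x)` would be compared against `resplot`:

```python
import numpy as np

def rel_l2_error(pred, ref):
    # ||pred - ref||_2 / ||ref||_2 over all grid points
    pred = np.asarray(pred, dtype=float).ravel()
    ref = np.asarray(ref, dtype=float).ravel()
    return np.linalg.norm(pred - ref) / np.linalg.norm(ref)

# Toy check: an error of 1 at one point against a reference of norm sqrt(14)
print(rel_l2_error([1.0, 2.0, 2.0], [1.0, 2.0, 3.0]))  # 1/sqrt(14), about 0.267
```

With a reference available, the call would be `rel_l2_error(resplot, ref)`.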