# Source code for jaxkan.layers.utils

import jax
import jax.numpy as jnp


@jax.jit
def solve_single_lstsq(A_single, B_single):
    """
    Least-squares solve of AX = B through the normal equations.

    Rather than calling linalg.lstsq, the problem is rewritten as
    (A^T A) X = A^T B and handed to a positive-definite linear solver,
    which is considerably faster.

    Args:
        A_single (jnp.array): Matrix A of AX = B, shape (M, N).
        B_single (jnp.array): Matrix B of AX = B, shape (M, K).

    Returns:
        single_solution (jnp.array): Matrix X of AX = B, shape (N, K).

    Example:
        >>> A = jnp.array([[2.0, 1.0], [1.0, 3.0]])
        >>> B = jnp.array([[1.0], [2.0]])
        >>>
        >>> solution = solve_single_lstsq(A, B)
    """
    a_t = A_single.T
    gram = a_t @ A_single        # A^T A, shape (N, N)
    projected = a_t @ B_single   # A^T B, shape (N, K)
    # A^T A is symmetric and, for a full-column-rank A, positive definite,
    # which is what the 'pos' solver assumes. Forming the normal equations
    # squares the condition number of A, and a rank-deficient A makes the
    # Gram matrix singular.
    return jax.scipy.linalg.solve(gram, projected, assume_a='pos')
@jax.jit
def solve_full_lstsq(A_full, B_full):
    """
    Batched least-squares solve of AX = B, one problem per leading index.

    JAX's lstsq only handles 2-D matrices (unlike, e.g., PyTorch's), so this
    function vectorises solve_single_lstsq over the leading batch axis as a
    workaround.

    Args:
        A_full (jnp.array): Matrix A of AX = B, shape (batch, M, N).
        B_full (jnp.array): Matrix B of AX = B, shape (batch, M, K).

    Returns:
        full_solution (jnp.array): Matrix X of AX = B, shape (batch, N, K).

    Example:
        >>> A = jnp.array([[[2.0, 1.0], [1.0, 3.0]], [[1.0, 2.0], [2.0, 1.0]]])
        >>> B = jnp.array([[[1.0], [2.0]], [[2.0], [3.0]]])
        >>>
        >>> solution = solve_full_lstsq(A, B)
    """
    # With the default in_axes=0, vmap maps over axis 0 of both A_full and
    # B_full, pairing A_full[i] with B_full[i].
    batched_solver = jax.vmap(solve_single_lstsq)
    return batched_solver(A_full, B_full)
def interpolate_moments(mu_old, nu_old, new_shape):
    """
    Linearly interpolate first- and second-order gradient moments of the c_i
    basis-function coefficients onto a new number of basis functions, as
    needed after grid extension.

    Interpolation runs independently along the last axis. All leading axes
    are preserved, so the layout can be (n_in*n_out, num_basis) for a 'base'
    layer, (n_out, n_in, num_basis) for a 'spline' layer, or any other
    number of leading axes (including none).

    Args:
        mu_old (jnp.array): First-order moments before extension, shape
            (..., num_basis), e.g. (n_in*n_out, num_basis) or
            (n_out, n_in, num_basis).
        nu_old (jnp.array): Second-order moments before extension, same
            layout as mu_old.
        new_shape (tuple): The new desired shape, e.g.
            (n_in*n_out, new_num_basis) or (n_out, n_in, new_num_basis).
            Only its last entry (new_num_basis) is used; the leading
            dimensions of the result are taken from mu_old.

    Returns:
        mu_new, nu_new (tuple): First- and second-order moments after
            extension, each of shape (*mu_old.shape[:-1], new_shape[-1]).

    Example:
        >>> mu_old = jnp.array([[1, 2, 3], [4, 5, 6]])
        >>> nu_old = jnp.array([[7, 8, 9], [10, 11, 12]])
        >>> new_shape = (2, 5)
        >>>
        >>> mu_new, nu_new = interpolate_moments(mu_old, nu_old, new_shape)
    """
    old_shape = mu_old.shape
    leading_dims = old_shape[:-1]
    old_j = old_shape[-1]   # axis along which interpolation occurs
    new_j = new_shape[-1]

    # Flatten every leading axis into a single batch axis. Letting reshape
    # infer the batch size (-1) keeps it a static Python int. The size is
    # never materialised as a JAX array, so this also works for 1-D inputs
    # and inside jax.jit.
    mu_old_2d = mu_old.reshape((-1, old_j))
    nu_old_2d = nu_old.reshape((-1, old_j))

    # Old basis functions sit at positions 0, 1, ..., old_j - 1. The new ones
    # are spread uniformly over that same range.
    old_indices = jnp.linspace(0, old_j - 1, old_j)
    new_indices = jnp.linspace(0, old_j - 1, new_j)

    # One 1-D interpolation per row, vectorised over the batch axis.
    interpolate_rows = jax.vmap(
        lambda old_row: jnp.interp(new_indices, old_indices, old_row)
    )
    mu_new_2d = interpolate_rows(mu_old_2d)
    nu_new_2d = interpolate_rows(nu_old_2d)

    # Restore the original leading dimensions with the new last dimension.
    mu_new = mu_new_2d.reshape((*leading_dims, new_j))
    nu_new = nu_new_2d.reshape((*leading_dims, new_j))
    return mu_new, nu_new
def adam_transition(old_state, model_state):
    """
    Carry the Adam optimizer state (with scheduler) across a grid extension.

    For every layer, the Adam first/second moments of the 'c_basis'
    coefficients are resampled with interpolate_moments to match the
    layer's new 'c_basis' shape, then written back into the existing state
    objects with set_value. The update is in-place and nothing is returned.

    Args:
        old_state (tuple): Collection of Adam state and scheduler state before
            extension. old_state[0] must expose `.mu` and `.nu` mappings with
            a 'layers' entry (presumably an optax Adam state; confirm against
            the optimizer setup).
        model_state (dict): Dict of KAN model state after split.

    Example:
        >>> old_state = optimizer.opt_state
        >>> _, model_state = nnx.split(model)
        >>> adam_transition(old_state, model_state)
    """
    adam_state = old_state[0]
    first_moment_layers = adam_state.mu['layers']
    second_moment_layers = adam_state.nu['layers']
    model_layers = model_state['layers']

    # Layers are addressed by integer keys 0 .. n_layers - 1.
    for idx in range(len(first_moment_layers)):
        mu_entry = first_moment_layers[idx]['c_basis']
        nu_entry = second_moment_layers[idx]['c_basis']

        # `[...]` reads the full stored array. The target shape comes from the
        # already-extended model coefficients.
        target_shape = model_layers[idx]['c_basis'][...].shape
        mu_new, nu_new = interpolate_moments(mu_entry[...], nu_entry[...], target_shape)

        # Write back in place so the optimizer's own state objects are updated.
        mu_entry.set_value(mu_new)
        nu_entry.set_value(nu_new)