Source code for jaxkan.pikan.adaptive

import jax
import jax.numpy as jnp

from flax import nnx

import optax


DEFAULT_EPS = 1e-12


[docs] def get_colloc_indices(collocs_pool, batch_size, px, seed): """ Sample collocation point indices from a pool based on probability weights and sort by first coordinate (time). Args: collocs_pool (jnp.array): Pool of collocation points, shape (N, n_dims). batch_size (int): Number of points to sample. px (jnp.array): Probability weights for sampling, shape (N,). seed (int): Random seed for reproducibility. Returns: sorted_pool_indices (jnp.array): Indices of sampled points sorted by first coordinate, shape (batch_size,). Example: >>> collocs_pool = jnp.array([[0.5, 0.1], [0.2, 0.3], [0.8, 0.7]]) >>> px = jnp.array([0.5, 0.3, 0.2]) >>> indices = get_colloc_indices(collocs_pool, batch_size=2, px=px, seed=42) """ collocs_key = jax.random.PRNGKey(seed) X_ids = jax.random.choice(key=collocs_key, a=collocs_pool.shape[0], shape=(batch_size,), replace=False, p=px) sorted_batch_order = jnp.argsort(collocs_pool[X_ids, 0]) sorted_pool_indices = X_ids[sorted_batch_order] return sorted_pool_indices
[docs] @nnx.jit def update_rba_weights(residuals, weights, gamma, eta, eps=DEFAULT_EPS): """ Update residual-based attention weights from the current residual magnitudes. Args: residuals (jnp.array): Residual values, typically shape (N, 1). weights (jnp.array): Current RBA weights with the same shape as ``residuals``. gamma (float): Exponential moving average coefficient. eta (float): Residual scaling coefficient. eps (float): Numerical stabilizer for normalization. Returns: jnp.array: Updated RBA weights. """ abs_res = jnp.abs(residuals) scale = jnp.max(abs_res) return (gamma * weights) + (eta * abs_res / (scale + eps))
[docs] @nnx.jit def apply_rba_weights(residuals, weights): """ Apply RBA weights to residuals while keeping the weights outside the backward graph. Args: residuals (jnp.array): Residual values. weights (jnp.array): RBA weights with matching shape. Returns: jnp.array: Weighted residuals. """ return jax.lax.stop_gradient(weights) * residuals
[docs] @nnx.jit def get_causal_weights(losses, causal_matrix, causal_tol): """ Compute causal training weights and stop their gradients. Args: losses (jnp.array): Chunk-wise loss values. causal_matrix (jnp.array): Upper-triangular causal coupling matrix. causal_tol (float): Exponential weighting coefficient. Returns: jnp.array: Stopped-gradient causal weights. """ return jax.lax.stop_gradient(jnp.exp(-causal_tol * (causal_matrix @ losses)))
[docs] @nnx.jit def get_rad_probabilities(weighted_residuals, rad_a, rad_c, eps=DEFAULT_EPS): """ Convert weighted residuals into normalized RAD sampling probabilities. Args: weighted_residuals (jnp.array): Residuals already multiplied by their current adaptive weights. rad_a (float): Density-control exponent. rad_c (float): Baseline offset added to the density. eps (float): Numerical stabilizer for normalization. Returns: jnp.array: Probability vector of shape (N,). """ abs_weighted = jnp.abs(weighted_residuals) density = jnp.power(abs_weighted, rad_a) probs = (density / (jnp.mean(density) + eps)) + rad_c return (probs / jnp.sum(probs))[:, 0]
[docs] @nnx.jit(static_argnames=("batch_size",)) def get_rad_indices(collocs_pool, residuals, old_indices, batch_weights, pool_weights, batch_size, rad_a, rad_c, seed, eps=DEFAULT_EPS): """ Update pool weights and draw a new RAD-resampled batch of collocation indices. Args: collocs_pool (jnp.array): Full collocation pool, shape (N, d). residuals (jnp.array): Residual values over the full pool, shape (N, 1). old_indices (jnp.array): Previously active batch indices. batch_weights (jnp.array): Current adaptive weights for the active batch. pool_weights (jnp.array): Full pool of adaptive weights. batch_size (int): Number of indices to sample. rad_a (float): Density-control exponent. rad_c (float): Baseline offset added to the density. seed (int): Sampling seed. eps (float): Numerical stabilizer. Returns: tuple[jnp.array, jnp.array, jnp.array]: Sampled indices, updated full-pool weights, and normalized probabilities. """ updated_pool = pool_weights.at[old_indices].set(batch_weights) weighted_residuals = updated_pool * residuals px = get_rad_probabilities(weighted_residuals, rad_a, rad_c, eps) sorted_indices = get_colloc_indices(collocs_pool=collocs_pool, batch_size=batch_size, px=px, seed=seed) return sorted_indices, updated_pool, px
@nnx.jit def _lr_anneal_impl(grads, lambdas, grad_mixing): norms = jnp.stack(tuple(optax.tree.norm(grad) for grad in grads)) norm_sum = jnp.sum(norms) hats = norm_sum / (norms + 1e-5 * norm_sum) return tuple( grad_mixing * lambdas[i] + (1.0 - grad_mixing) * hats[i] for i in range(len(lambdas)) )
[docs] def lr_anneal(grads, lambdas, grad_mixing): """ Perform the learning rate annealing algorithm introduced in "Understanding and mitigating gradient pathologies in physics-informed neural networks" Args: grads (tuple | list): Tuple/list of gradient pytrees, one per loss term. lambdas (tuple | list): Tuple/list of current global loss weights, matching ``grads`` in order. grad_mixing (float): Exponential moving average coefficient. Returns: tuple: Updated global loss weights in the same logical order as the inputs. Example: >>> λs_new = lr_anneal((grads_E, grads_B, grads_aux), (1.0, 1.0, 1.0), 0.9) """ grads = tuple(grads) lambdas = tuple(lambdas) if len(grads) != len(lambdas): raise ValueError("grads and lambdas must have the same number of terms.") if len(grads) == 0: raise ValueError("lr_anneal requires at least one gradient/weight pair.") return _lr_anneal_impl(grads, lambdas, grad_mixing)