import jax.numpy as jnp
class SplineGrid:
    """
    SplineGrid class, corresponding to the grid of the SplineLayer class. It comprises an initialization as well as an update procedure.

    Attributes:
        n_nodes (int):
            Number of layer nodes.
        k (int):
            Order of the spline basis functions.
        G (int):
            Number of grid intervals. Overwritten by ``update``.
        grid_range (tuple):
            An initial range for the grid's ends, although adaptivity can completely change it.
            Only ``_initialize`` reads it; ``update`` does not modify it.
        grid_e (float):
            Parameter that defines if the grids are uniform (grid_e = 1.0) or sample-dependent (grid_e = 0.0). Intermediate values correspond to a linear mixing of the two cases.
        item (jnp.array):
            The actual grid array, shape (n_nodes, G + 2k + 1).
    """

    def __init__(self, n_nodes: int = 2, k: int = 3, G: int = 3, grid_range: tuple = (-1, 1), grid_e: float = 0.05):
        """
        Initializes a SplineGrid instance.

        Args:
            n_nodes (int):
                Number of layer nodes.
            k (int):
                Order of the spline basis functions.
            G (int):
                Number of grid intervals.
            grid_range (tuple):
                An initial range for the grid's ends, although adaptivity can completely change it.
            grid_e (float):
                Parameter that defines if the grids are uniform (grid_e = 1.0) or sample-dependent (grid_e = 0.0). Intermediate values correspond to a linear mixing of the two cases.

        Example:
            >>> grid_type = SplineGrid(n_nodes = 2, k = 3, G = 3, grid_range = (-1,1), grid_e = 0.05)
            >>> grid = grid_type.item
        """
        self.n_nodes = n_nodes
        self.k = k
        self.G = G
        self.grid_range = grid_range
        self.grid_e = grid_e
        # Initialize the grid, which is henceforth accessible as .item
        self.item = self._initialize()

    def _initialize(self):
        """
        Create and initialize the grid. Can also be used to reset a grid to the default value.

        The G + 1 equally spaced knots covering ``grid_range`` are extended by
        k knots with the same spacing on each side. The same knot vector is
        then repeated once per node.

        Returns:
            grid (jnp.array):
                Grid for the SplineLayer, shape (n_nodes, G + 2k + 1), dtype float32.

        Example:
            >>> grid = SplineGrid(n_nodes = 2, k = 3, G = 3, grid_range = (-1,1), grid_e = 0.05)
            >>> grid.item = grid._initialize()
        """
        # Knot spacing implied by the requested range and number of intervals
        h = (self.grid_range[1] - self.grid_range[0]) / self.G
        # Knot indices -k .. G + k give G + 2k + 1 points; indices 0 .. G span grid_range
        grid = jnp.arange(-self.k, self.G + self.k + 1, dtype=jnp.float32) * h + self.grid_range[0]
        # One identical row per node: shape (n_nodes, G + 2k + 1)
        grid = jnp.expand_dims(grid, axis=0).repeat(self.n_nodes, axis=0)
        return grid

    def update(self, x, G_new: int, margin: float = 0.01):
        """
        Update the grid based on input data and new grid size.

        For every node (column of ``x``) two grids of G_new + 1 points are built:
        - an adaptive grid, made of the sorted samples at ranks
          floor(batch * i / G_new) for i = 0 .. G_new - 1, plus the largest sample;
        - a uniform grid spanning [min - margin, max + margin].

        They are mixed linearly with weight ``grid_e`` on the uniform grid. The
        result is then extended by k knots on each side, spaced by the mean
        interval of the mixed grid. ``self.item`` and ``self.G`` are overwritten
        in place; ``self.grid_range`` is left unchanged.

        Args:
            x (jnp.ndarray):
                Input data, shape (batch, n_nodes).
            G_new (int):
                New grid size in terms of intervals. Must be at least 1.
            margin (float):
                Padding added below the minimum and above the maximum sample when
                building the uniform grid. Defaults to 0.01.

        Raises:
            ValueError:
                If G_new is smaller than 1, or if x contains no samples.

        Example:
            >>> import jax
            >>> key = jax.random.key(42)
            >>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
            >>>
            >>> grid = SplineGrid(n_nodes = 2, k = 3, G = 3, grid_range = (-1,1), grid_e = 0.05)
            >>> grid.update(x=x_batch, G_new=5)
        """
        if G_new < 1:
            raise ValueError(f"G_new must be a positive integer, got {G_new}.")

        batch = x.shape[0]
        if batch < 1:
            raise ValueError("Cannot update the grid from an empty batch of samples.")

        # Sort each node's samples independently
        x_sorted = jnp.sort(x, axis=0)

        # Adaptive grid of G_new + 1 points, sampled according to the data density.
        # Ranks are computed with exact Python integers: float32 cannot represent
        # every index once batch exceeds 2**24.
        ids = jnp.array([(batch * i) // G_new for i in range(G_new)] + [-1])
        grid_adaptive = x_sorted[ids]

        # Uniform grid of G_new + 1 points; only the sample extremes matter here
        x_min = x_sorted[0]
        x_max = x_sorted[-1]
        uniform_step = (x_max - x_min + 2 * margin) / G_new
        grid_uniform = (
            jnp.arange(G_new + 1, dtype=jnp.float32)[:, None] * uniform_step
            + x_min
            - margin
        )

        # Linear mixing of the two grid types, shape (G_new + 1, n_nodes)
        grid = self.grid_e * grid_uniform + (1.0 - self.grid_e) * grid_adaptive

        # Augmentation from G_new + 1 to G_new + 2k + 1 points, using the mean interval per node
        h = (grid[-1] - grid[0]) / G_new
        left = h * jnp.arange(self.k, 0, -1)[:, None]
        right = h * jnp.arange(1, self.k + 1)[:, None]
        grid = jnp.concatenate([grid[:1] - left, grid, grid[-1:] + right], axis=0)

        # Store one row of knots per node and record the new interval count
        self.item = grid.T
        self.G = G_new