from jax import numpy as jnp
# (Sphinx "[docs]" link artifact from the HTML source listing removed here.)
class BaseGrid:
    """
    BaseGrid class, corresponding to the grid of the BaseLayer class.

    It comprises an initialization as well as a data-driven update procedure.

    Attributes:
        n_in (int):
            Number of layer's incoming nodes.
        n_out (int):
            Number of layer's outgoing nodes.
        k (int):
            Order of the spline basis functions.
        G (int):
            Number of grid intervals.
        grid_range (tuple):
            An initial range for the grid's ends, although adaptivity can completely change it.
        grid_e (float):
            Parameter that defines if the grids are uniform (grid_e = 1.0) or sample-dependent
            (grid_e = 0.0). Intermediate values correspond to a linear mixing of the two cases.
        item (jnp.array):
            The actual grid array, shape (n_in*n_out, G + 2k + 1).
    """

    def __init__(self, n_in: int = 2, n_out: int = 5, k: int = 3,
                 G: int = 3, grid_range: tuple = (-1, 1), grid_e: float = 0.05
                 ):
        """
        Initializes a BaseGrid instance.

        Args:
            n_in (int):
                Number of layer's incoming nodes.
            n_out (int):
                Number of layer's outgoing nodes.
            k (int):
                Order of the spline basis functions.
            G (int):
                Number of grid intervals.
            grid_range (tuple):
                An initial range for the grid's ends, although adaptivity can completely change it.
            grid_e (float):
                Parameter that defines if the grids are uniform (grid_e = 1.0) or sample-dependent
                (grid_e = 0.0). Intermediate values correspond to a linear mixing of the two cases.

        Raises:
            ValueError:
                If G is smaller than 1 or k is negative.

        Example:
            >>> grid_type = BaseGrid(n_in = 2, n_out = 5, k = 3, G = 3, grid_range = (-1,1), grid_e = 0.05)
            >>> grid = grid_type.item
        """
        # The knot spacing is (range width) / G, so G must be a positive count
        if G < 1:
            raise ValueError(f"G must be at least 1, got {G}")
        if k < 0:
            raise ValueError(f"k must be non-negative, got {k}")

        self.n_in = n_in
        self.n_out = n_out
        self.k = k
        self.G = G
        self.grid_range = grid_range
        self.grid_e = grid_e

        # Initialize the grid, which is henceforth callable as .item
        self.item = self._initialize()

    def _initialize(self):
        """
        Create and initialize the grid. Can also be used to reset a grid to the default value.

        The grid is built from the current values of grid_range, G, k, n_in and n_out.

        Returns:
            grid (jnp.array):
                Grid for the BaseLayer, shape (n_in*n_out, G + 2k + 1). All rows are identical.

        Example:
            >>> grid = BaseGrid(n_in = 2, n_out = 5, k = 3, G = 3, grid_range = (-1,1), grid_e = 0.05)
            >>> grid.item = grid._initialize()
        """
        # Calculate the step size for the knot vector based on its end values
        h = (self.grid_range[1] - self.grid_range[0]) / self.G

        # Create the initial knot vector and perform augmentation:
        # it is expanded from G+1 points to G+1 + 2k points, because k points are appended at each of its ends
        grid = jnp.arange(-self.k, self.G + self.k + 1, dtype=jnp.float32) * h + self.grid_range[0]

        # Expand for broadcasting - the shape becomes (n_in*n_out, G + 2k + 1), so that the grid
        # can be passed in all n_in*n_out spline basis functions simultaneously
        grid = jnp.expand_dims(grid, axis=0)
        grid = jnp.tile(grid, (self.n_in * self.n_out, 1))

        return grid

    def update(self, x, G_new):
        """
        Update the grid based on input data and new grid size.

        The new grid is a linear mix (weighted by grid_e) of a uniform grid spanning the data range
        and an adaptive grid sampled from the sorted data. It is then extended by k knots on each side.
        The result is stored in .item and G is set to the new number of intervals.

        Args:
            x (jnp.ndarray):
                Input data, shape (batch, n_in).
            G_new (int):
                New grid size in terms of intervals.

        Raises:
            ValueError:
                If x is not of shape (batch, n_in) with batch >= 1, or if G_new is not a
                positive integer.

        Example:
            >>> import jax
            >>> key = jax.random.key(42)
            >>> x_batch = jax.random.uniform(key, shape=(100, 2), minval=-1.0, maxval=1.0)
            >>>
            >>> grid = BaseGrid(n_in = 2, n_out = 5, k = 3, G = 3, grid_range = (-1,1), grid_e = 0.05)
            >>> grid.update(x=x_batch, G_new=5)
        """
        x = jnp.asarray(x)
        if x.ndim != 2 or x.shape[1] != self.n_in:
            raise ValueError(
                f"x must have shape (batch, {self.n_in}), got {tuple(x.shape)}"
            )
        batch = x.shape[0]
        if batch < 1:
            raise ValueError("x must contain at least one sample")

        n_intervals = int(G_new)
        if n_intervals != G_new or n_intervals < 1:
            raise ValueError(f"G_new must be a positive integer, got {G_new}")

        # Extend to shape (batch, n_in*n_out); column o*n_in + i holds input feature i
        x_ext = jnp.einsum('ij,k->ikj', x, jnp.ones(self.n_out,)).reshape((batch, self.n_in * self.n_out))
        # Transpose to shape (n_in*n_out, batch)
        x_ext = jnp.transpose(x_ext, (1, 0))
        # Sort inputs along the batch axis
        x_sorted = jnp.sort(x_ext, axis=1)

        # Get an adaptive grid of size G_new + 1
        # Essentially we sample points from x, based on their density.
        # The indices floor(batch * j / G_new) are computed with exact integer arithmetic,
        # so they do not depend on floating-point rounding; the last sample is appended explicitly.
        sample_ids = [(batch * j) // n_intervals for j in range(n_intervals)]
        sample_ids.append(batch - 1)
        ids = jnp.array(sample_ids, dtype=int)
        grid_adaptive = x_sorted[:, ids]

        # Get a uniform grid of size G_new + 1
        # Essentially we only consider the maximum and minimum values of x (padded by a small margin)
        margin = 0.01
        uniform_step = (x_sorted[:, -1] - x_sorted[:, 0] + 2 * margin) / n_intervals
        grid_uniform = (
            jnp.arange(n_intervals + 1, dtype=jnp.float32)
            * uniform_step[:, None]
            + x_sorted[:, 0][:, None]
            - margin
        )

        # Perform a linear mixing of the two grid types
        grid = self.grid_e * grid_uniform + (1.0 - self.grid_e) * grid_adaptive

        # Perform grid augmentation, so that the grid is extended from G_new + 1 to G_new + 2k + 1 points
        # First get a new step vector, shape (n_in*n_out, 1)
        h = (grid[:, [-1]] - grid[:, [0]]) / n_intervals
        # Then calculate the left and right additions in terms of h, each of shape (n_in*n_out, k).
        # The left offsets run k*h, ..., 1*h so that the extended knots stay in ascending order.
        left = jnp.arange(self.k, 0, -1) * h
        right = jnp.arange(1, self.k + 1) * h
        # Finally, concatenate left and right
        grid = jnp.concatenate(
            [
                grid[:, [0]] - left,
                grid,
                grid[:, [-1]] + right,
            ],
            axis=1,
        )

        # Update the grid value and size
        self.item = grid
        self.G = n_intervals