# Source code for jaxkan.layers.__init__

from .Spline import BaseLayer, SplineLayer
from .Chebyshev import ChebyshevLayer
from .Legendre import LegendreLayer
from .Fourier import FourierLayer
from .RBF import RBFLayer
from .Sine import SineLayer


def get_layer(layer_type: str) -> type:
    """Look up the KAN layer class registered under a type code.

    Args:
        layer_type (str): Code of the layer to use. One of ``"base"``,
            ``"spline"``, ``"chebyshev"``, ``"legendre"``, ``"fourier"``,
            ``"rbf"`` or ``"sine"``. Matching is exact (case-sensitive).

    Returns:
        type: The layer *class* itself (not an instance), e.g.
        ``SplineLayer``. Instantiate it to obtain a layer that can serve
        as a building block of a KAN.

    Raises:
        ValueError: If ``layer_type`` is not one of the codes above. The
            message lists the available codes.

    Example:
        >>> LayerClass = get_layer("base")
    """
    layer_map = {
        "base": BaseLayer,
        "spline": SplineLayer,
        "chebyshev": ChebyshevLayer,
        "legendre": LegendreLayer,
        "fourier": FourierLayer,
        "rbf": RBFLayer,
        "sine": SineLayer,
    }

    try:
        return layer_map[layer_type]
    except KeyError:
        # Drop the KeyError context: the ValueError message already names
        # the bad code and lists the valid ones.
        raise ValueError(
            f"Unknown layer type: {layer_type}. Available types: {list(layer_map.keys())}"
        ) from None