About

jaxKAN is a Python package for training Kolmogorov-Arnold Networks (KANs) with the JAX framework. Built on Flax’s NNX module, jaxKAN provides a collection of KAN layers that serve as building blocks for various KAN architectures, such as the EfficientKAN and the ChebyKAN. In addition to the standard initialization and forward-pass methods, the KAN class in jaxKAN provides an extend_grids method, which extends the grids of all layers in the network, regardless of how those grids are defined. For ChebyKAN, for instance, which has no grid in the traditional sense, the method instead increases the order of the Chebyshev polynomials used in the model.
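To make the ChebyKAN case concrete, the following is a minimal JAX sketch of what "extending the order" of a Chebyshev layer can look like. It is not jaxKAN's implementation or API: the names chebyshev_basis, cheby_layer and extend_order are hypothetical, the tanh squashing of inputs into [-1, 1] is an assumption borrowed from common ChebyKAN formulations, and zero-padding the new coefficients is just one simple choice that leaves the layer's output unchanged at the moment of extension.

```python
import jax
import jax.numpy as jnp


def chebyshev_basis(x, degree):
    """Evaluate T_0..T_degree at x (expected in [-1, 1]).

    Returns an array of shape x.shape + (degree + 1,).
    """
    terms = [jnp.ones_like(x)]
    if degree >= 1:
        terms.append(x)
    # Three-term recurrence: T_k(x) = 2 x T_{k-1}(x) - T_{k-2}(x)
    for _ in range(2, degree + 1):
        terms.append(2.0 * x * terms[-1] - terms[-2])
    return jnp.stack(terms, axis=-1)


def cheby_layer(coeffs, x):
    """Hypothetical Chebyshev layer (illustration only).

    coeffs: (n_in, n_out, degree + 1), x: (batch, n_in) -> (batch, n_out).
    """
    degree = coeffs.shape[-1] - 1
    # Assumption: inputs are squashed into [-1, 1] with tanh.
    basis = chebyshev_basis(jnp.tanh(x), degree)  # (batch, n_in, degree + 1)
    return jnp.einsum("bid,iod->bo", basis, coeffs)


def extend_order(coeffs, new_degree):
    """Zero-pad the coefficients so the layer uses polynomials up to new_degree."""
    extra = new_degree + 1 - coeffs.shape[-1]
    if extra < 0:
        raise ValueError("new_degree must not be smaller than the current degree")
    return jnp.pad(coeffs, ((0, 0), (0, 0), (0, extra)))


# Usage: a degree-4 layer with 3 inputs and 2 outputs, extended to degree 8.
k_coeffs, k_x = jax.random.split(jax.random.PRNGKey(0))
coeffs = jax.random.normal(k_coeffs, (3, 2, 5))
x = jax.random.normal(k_x, (8, 3))
extended = extend_order(coeffs, new_degree=8)
```

Because the added coefficients start at zero, the extended layer reproduces the original outputs while exposing higher-order terms that further training can use.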

Although KANs implemented in jaxKAN can be applied across a wide range of problem domains as an alternative to Multilayer Perceptrons (MLPs), the package places a strong emphasis on their use as Physics-Informed Kolmogorov-Arnold Networks (PIKANs). To support this focus, jaxKAN includes specialized utilities and tutorials for solving forward and inverse PDE problems.
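For readers new to the physics-informed setting, here is a minimal sketch of the kind of loss a forward problem leads to, written in plain JAX. It does not use jaxKAN's utilities: model is a hypothetical stand-in for a KAN (a small tanh expansion), and the toy problem u''(x) + u(x) = 0 on [0, π/2] with u(0) = 0 and u(π/2) = 1 is chosen purely for illustration.

```python
import jax
import jax.numpy as jnp


def model(params, x):
    # Hypothetical stand-in for a KAN: maps a scalar x to a scalar u(x).
    return jnp.sum(params["v"] * jnp.tanh(params["w"] * x + params["b"]))


def pde_residual(u_fn, x):
    # Residual of the toy equation u''(x) + u(x) = 0, via automatic differentiation.
    u_xx = jax.grad(jax.grad(u_fn))(x)
    return u_xx + u_fn(x)


def pikan_loss(params, x_col, x_bc, u_bc):
    u_fn = lambda z: model(params, z)
    # Mean squared PDE residual at the collocation points.
    res = jax.vmap(lambda z: pde_residual(u_fn, z))(x_col)
    # Mean squared mismatch at the boundary points.
    bc = jax.vmap(u_fn)(x_bc) - u_bc
    return jnp.mean(res ** 2) + jnp.mean(bc ** 2)


# Usage: evaluate the loss and its gradients for randomly initialized parameters.
kw, kb, kv = jax.random.split(jax.random.PRNGKey(0), 3)
params = {
    "w": jax.random.normal(kw, (16,)),
    "b": jax.random.normal(kb, (16,)),
    "v": 0.1 * jax.random.normal(kv, (16,)),
}
x_col = jnp.linspace(0.0, jnp.pi / 2, 32)  # collocation points
x_bc = jnp.array([0.0, jnp.pi / 2])        # boundary points
u_bc = jnp.array([0.0, 1.0])               # boundary values
loss, grads = jax.value_and_grad(pikan_loss)(params, x_col, x_bc, u_bc)
```

The gradients returned here could be passed to any optimizer. In an inverse problem, unknown PDE coefficients would typically be added to the trainable parameters alongside the network's own.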

The source code for jaxKAN is available in the jaxKAN GitHub repository.

Research

If you have used jaxKAN in your research, we’d love to hear from you! Below, you can find a list of academic publications that have used jaxKAN.