Skip to content
3 changes: 3 additions & 0 deletions docs/source/_rst/_code.rst
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ Models
PirateNet <model/pirate_network.rst>
EquivariantGraphNeuralOperator <model/equivariant_graph_neural_operator.rst>
SINDy <model/sindy.rst>
Vectorized Spline <model/vectorized_spline.rst>
Kolmogorov-Arnold Network <model/kolmogorov_arnold_network.rst>

Blocks
-------------
Expand All @@ -128,6 +130,7 @@ Blocks
Continuous Convolution Block <model/block/convolution.rst>
Orthogonal Block <model/block/orthogonal.rst>
PirateNet Block <model/block/pirate_network_block.rst>
KAN Block <model/block/kan_block.rst>

Message Passing
-------------------
Expand Down
7 changes: 7 additions & 0 deletions docs/source/_rst/model/block/kan_block.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
KANBlock
=======================
.. currentmodule:: pina.model.block.kan_block

.. autoclass:: pina._src.model.block.kan_block.KANBlock
    :members:
    :show-inheritance:
7 changes: 7 additions & 0 deletions docs/source/_rst/model/kolmogorov_arnold_network.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
KolmogorovArnoldNetwork
===========================
.. currentmodule:: pina.model.kolmogorov_arnold_network

.. autoclass:: pina._src.model.kolmogorov_arnold_network.KolmogorovArnoldNetwork
    :members:
    :show-inheritance:
7 changes: 7 additions & 0 deletions docs/source/_rst/model/vectorized_spline.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
VectorizedSpline
=======================
.. currentmodule:: pina.model.vectorized_spline

.. autoclass:: pina._src.model.vectorized_spline.VectorizedSpline
    :members:
    :show-inheritance:
158 changes: 158 additions & 0 deletions pina/_src/model/block/kan_block.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
"""Module for the Kolmogorov-Arnold Network block."""

import torch
from pina._src.model.vectorized_spline import VectorizedSpline
from pina._src.core.utils import check_consistency, check_positive_integer


class KANBlock(torch.nn.Module):
    """
    The inner block of the Kolmogorov-Arnold Network (KAN).

    The block applies a spline transformation to the input, optionally combined
    with a linear transformation of a base activation function. The output is
    aggregated across input dimensions to produce the final output.

    .. seealso::

        **Original reference**:
        Liu Z., Wang Y., Vaidya S., Ruehle F., Halverson J., Soljacic M.,
        Hou T., Tegmark M. (2025).
        *KAN: Kolmogorov-Arnold Networks*.
        DOI: `arXiv preprint arXiv:2404.19756.
        <https://arxiv.org/abs/2404.19756>`_
    """

    def __init__(
        self,
        input_dimensions,
        output_dimensions,
        spline_order=3,
        n_knots=10,
        grid_range=[0, 1],
        base_function=torch.nn.SiLU,
        use_base_linear=True,
        use_bias=True,
        init_scale_spline=1e-2,
        init_scale_base=1.0,
    ):
        """
        Initialization of the :class:`KANBlock` class.

        :param int input_dimensions: The number of input features.
        :param int output_dimensions: The number of output features.
        :param int spline_order: The order of each spline basis function.
            Default is 3 (cubic splines).
        :param int n_knots: The number of knots for each spline basis function.
            The knot vector always contains ``spline_order`` repeated knots at
            each end of ``grid_range``; when ``n_knots`` is smaller than
            ``2 * spline_order`` no internal knots are added and the knot
            vector has ``2 * spline_order`` entries. Default is 10.
        :param grid_range: The range for the spline knots. It must be either a
            list or a tuple of the form [min, max]. Default is [0, 1].
        :type grid_range: list | tuple
        :param torch.nn.Module base_function: The base activation function to be
            applied to the input before the linear transformation. It must be
            a class (not an instance), since it is instantiated without
            arguments. Default is :class:`torch.nn.SiLU`.
        :param bool use_base_linear: Whether to include a linear transformation
            of the base function output. Default is True.
        :param bool use_bias: Whether to include a bias term in the output.
            Default is True.
        :param init_scale_spline: The scale for initializing each spline
            control points. Default is 1e-2.
        :type init_scale_spline: float | int
        :param init_scale_base: The scale for initializing the base linear
            weights. Default is 1.0.
        :type init_scale_base: float | int
        :raises ValueError: If ``grid_range`` is not a list or tuple of
            length 2.
        """
        super().__init__()

        # Check consistency
        check_consistency(base_function, torch.nn.Module, subclass=True)
        check_positive_integer(input_dimensions, strict=True)
        check_positive_integer(output_dimensions, strict=True)
        check_positive_integer(spline_order, strict=True)
        check_positive_integer(n_knots, strict=True)
        check_consistency(use_base_linear, bool)
        check_consistency(use_bias, bool)
        check_consistency(init_scale_spline, (int, float))
        check_consistency(init_scale_base, (int, float))
        check_consistency(grid_range, (int, float))

        # Raise error if grid_range is not valid. The container type is checked
        # explicitly so that a scalar fails with the documented ValueError
        # instead of a TypeError raised by len().
        self._validate_grid_range(grid_range)

        # Knots for the spline basis functions, one row per input dimension
        knots = self._build_knots(
            input_dimensions, spline_order, n_knots, grid_range
        )

        # Define the control points for the spline basis functions.
        # NOTE: drawn before the base weights; keep this order so that seeded
        # initializations stay reproducible.
        control_points = (
            torch.randn(
                input_dimensions,
                output_dimensions,
                knots.shape[-1] - spline_order,
            )
            * init_scale_spline
        )

        # Define the vectorized spline module
        self.spline = VectorizedSpline(
            order=spline_order, knots=knots, control_points=control_points
        )

        # Initialize the base function
        self.base_function = base_function()

        # Initialize the base linear weights if needed
        if use_base_linear:
            self.base_weight = torch.nn.Parameter(
                torch.randn(output_dimensions, input_dimensions)
                * (init_scale_base / (input_dimensions**0.5))
            )
        else:
            self.register_parameter("base_weight", None)

        # Initialize the bias term if needed
        if use_bias:
            self.bias = torch.nn.Parameter(torch.zeros(output_dimensions))
        else:
            self.register_parameter("bias", None)

    @staticmethod
    def _validate_grid_range(grid_range):
        """
        Check that ``grid_range`` is a two-element list or tuple.

        :param grid_range: The candidate range for the spline knots.
        :type grid_range: list | tuple
        :raises ValueError: If ``grid_range`` is not a list or tuple of
            length 2.
        """
        if not isinstance(grid_range, (list, tuple)) or len(grid_range) != 2:
            raise ValueError("Grid must be a list or tuple with two elements.")

    @staticmethod
    def _build_knots(input_dimensions, spline_order, n_knots, grid_range):
        """
        Build the knot vector shared by every input dimension.

        The vector is made of ``spline_order`` copies of the lower bound,
        ``max(0, n_knots - 2 * spline_order)`` uniformly spaced internal knots
        strictly inside the range, and ``spline_order`` copies of the upper
        bound.

        :param int input_dimensions: The number of rows of the returned tensor.
        :param int spline_order: The number of repeated knots at each boundary.
        :param int n_knots: The requested total number of knots.
        :param grid_range: The [min, max] range of the knots.
        :type grid_range: list | tuple
        :return: A tensor of shape ``(input_dimensions, total_knots)``.
        :rtype: torch.Tensor
        """
        lower, upper = grid_range[0], grid_range[1]

        # Boundary knots (multiplying ones keeps a floating point dtype even
        # when the bounds are Python integers)
        initial_knots = torch.ones(spline_order) * lower
        final_knots = torch.ones(spline_order) * upper

        # Number of internal knots
        n_internal = max(0, n_knots - 2 * spline_order)

        # Internal knots are uniformly spaced in the grid range; the two
        # end points of the linspace are dropped because they coincide with
        # the boundary knots.
        internal_knots = torch.linspace(lower, upper, n_internal + 2)[1:-1]

        knots = torch.cat((initial_knots, internal_knots, final_knots))
        return knots.unsqueeze(0).repeat(input_dimensions, 1)

    def forward(self, x):
        """
        Forward pass of the Kolmogorov-Arnold block. The input is passed through
        the spline transformation, optionally combined with a linear
        transformation of the base function output, and then aggregated across
        input dimensions to produce the final output.

        :param x: The input tensor for the model. When the base linear term is
            enabled it must be two-dimensional, ``(batch, input_dimensions)``.
        :type x: torch.Tensor | LabelTensor
        :return: The output tensor of the model, obtained by summing the
            per-input contributions over dimension 1.
        :rtype: torch.Tensor | LabelTensor
        """
        # Per-input, per-output spline contributions; expected to be
        # broadcastable with the (batch, input, output) base term below.
        y = self.spline(x)

        if self.base_weight is not None:
            base_x = self.base_function(x)
            # Keep the input axis so that it can be summed together with the
            # spline contributions: out[b, i, o] = base_x[b, i] * w[o, i]
            base_out = torch.einsum("bi,oi->bio", base_x, self.base_weight)
            y = y + base_out

        # aggregate contributions from all input dimensions
        y = y.sum(dim=1)

        if self.bias is not None:
            y = y + self.bias

        return y
105 changes: 105 additions & 0 deletions pina/_src/model/kolmogorov_arnold_network.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import torch
from pina._src.model.block.kan_block import KANBlock
from pina._src.core.utils import check_consistency


class KolmogorovArnoldNetwork(torch.nn.Module):
    """
    Implementation of Kolmogorov-Arnold Network (KAN).

    The model consists of a sequence of KAN blocks, where each block applies a
    spline transformation to the input, optionally combined with a linear
    transformation of a base activation function.

    .. seealso::

        **Original reference**:
        Liu Z., Wang Y., Vaidya S., Ruehle F., Halverson J., Soljacic M.,
        Hou T., Tegmark M. (2025).
        *KAN: Kolmogorov-Arnold Networks*.
        DOI: `arXiv preprint arXiv:2404.19756.
        <https://arxiv.org/abs/2404.19756>`_
    """

    def __init__(
        self,
        layers,
        spline_order=3,
        n_knots=10,
        grid_range=[-1, 1],
        base_function=torch.nn.SiLU,
        use_base_linear=True,
        use_bias=True,
        init_scale_spline=1e-2,
        init_scale_base=1.0,
    ):
        """
        Initialization of the :class:`KolmogorovArnoldNetwork` class.

        :param layers: A list of integers specifying the sizes of each layer,
            including input and output dimensions. One :class:`KANBlock` is
            created for every pair of consecutive entries.
        :type layers: list | tuple
        :param int spline_order: The order of each spline basis function.
            Default is 3 (cubic splines).
        :param int n_knots: The number of knots for each spline basis function.
            Default is 10.
        :param grid_range: The range for the spline knots. It must be either a
            list or a tuple of the form [min, max]. Default is [-1, 1].
        :type grid_range: list | tuple
        :param torch.nn.Module base_function: The base activation function to be
            applied to the input before the linear transformation. Default is
            :class:`torch.nn.SiLU`.
        :param bool use_base_linear: Whether to include a linear transformation
            of the base function output. Default is True.
        :param bool use_bias: Whether to include a bias term in the output.
            Default is True.
        :param init_scale_spline: The scale for initializing each spline
            control points. Default is 1e-2.
        :type init_scale_spline: float | int
        :param init_scale_base: The scale for initializing the base linear
            weights. Default is 1.0.
        :type init_scale_base: float | int
        :raises ValueError: If ``layers`` is not a list or tuple with at least
            two elements.
        :raises ValueError: If ``grid_range`` is not of length 2.
        """
        super().__init__()

        # Check consistency -- all other checks are performed in KANBlock
        check_consistency(layers, int)
        self._validate_layers(layers)

        # Initialize KAN blocks, one per pair of consecutive layer sizes
        self.kan_layers = torch.nn.ModuleList(
            [
                KANBlock(
                    input_dimensions=n_in,
                    output_dimensions=n_out,
                    spline_order=spline_order,
                    n_knots=n_knots,
                    grid_range=grid_range,
                    base_function=base_function,
                    use_base_linear=use_base_linear,
                    use_bias=use_bias,
                    init_scale_spline=init_scale_spline,
                    init_scale_base=init_scale_base,
                )
                for n_in, n_out in zip(layers[:-1], layers[1:])
            ]
        )

    @staticmethod
    def _validate_layers(layers):
        """
        Check that ``layers`` is a list or tuple with at least two entries.

        :param layers: The candidate layer sizes.
        :type layers: list | tuple
        :raises ValueError: If ``layers`` is not a list or tuple, or if it has
            fewer than two elements.
        """
        # A scalar would otherwise fail in len() with a TypeError.
        if not isinstance(layers, (list, tuple)):
            raise ValueError("layers must be a list or tuple of integers.")
        if len(layers) < 2:
            raise ValueError(
                "Provide at least two elements for layers (input and output)."
            )

    def forward(self, x):
        """
        Forward pass of the KolmogorovArnoldNetwork model. It passes the input
        through each KAN block in the network and returns the final output.

        :param x: The input tensor for the model.
        :type x: torch.Tensor | LabelTensor
        :return: The output tensor of the model.
        :rtype: torch.Tensor | LabelTensor
        """
        # Blocks are applied in the order they were created
        for layer in self.kan_layers:
            x = layer(x)

        return x
Loading
Loading