Kolmogorov-Arnold Networks: the latest advance in Neural Networks, simply explained

Author: Murphy  |  View: 28119  |  Time: 2025-03-22 21:43:03
Comparison of the MLP and the KAN. Image from paper.

In April 2024, a paper titled KAN: Kolmogorov–Arnold Networks appeared on arXiv. The tweet announcing it got about 5k likes, which is pretty viral for a paper announcement. The accompanying GitHub repository already has 7.6k stars and counting.

The Kolmogorov-Arnold Network (KAN) is a brand-new class of Neural Network building block. It aims to be more expressive, less prone to overfitting and more interpretable than the Multi-Layer Perceptron (MLP). MLPs are ubiquitous in deep learning models. For example, every transformer block in models such as GPT-2, GPT-3 and (probably) GPT-4 contains an MLP as its feed-forward sublayer. Making an improvement on the MLP would have wide ramifications for the machine learning world.

MLPs in a nutshell

The MLP is actually a very old architecture: its roots go back to the perceptron of the late 1950s. The idea was to replicate the structure of the brain, with lots of interconnected neurons that pass information forward, hence the alternative name feed-forward network.

A representation of an MLP from https://arxiv.org/pdf/2101.03541

MLPs are often shown with diagrams like the one above. This is quite useful for laypeople but, in my opinion, it doesn't convey much understanding of what is really going on. It is much easier to represent with mathematics.

Let's assume we have some inputs x and some outputs y. A 2-layer MLP would be the following:

$$h = f(W_1 x + b_1), \qquad y = W_2 h + b_2$$

Here, the Ws are matrices of learnable weights and the bs are vectors of biases. The function f is a non-linear function. Looking at these equations, it becomes obvious that an MLP is a series of linear regression models with non-linearities in between. It is a very basic set-up.
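To make the "linear regression, non-linearity, linear regression" structure concrete, here is a minimal NumPy sketch of the 2-layer MLP forward pass. The choice of tanh for f and the random weights are assumptions for illustration; the article does not fix a particular non-linearity.

```python
import numpy as np

def mlp_forward(x, W1, b1, W2, b2, f=np.tanh):
    """2-layer MLP: a linear map, a fixed non-linearity f, then another linear map."""
    h = f(W1 @ x + b1)   # hidden layer: linear regression followed by the non-linearity
    return W2 @ h + b2   # output layer: a second linear regression on the hidden features

rng = np.random.default_rng(0)
n_in, n_hidden, n_out = 2, 5, 1
W1 = rng.normal(size=(n_hidden, n_in))
b1 = np.zeros(n_hidden)
W2 = rng.normal(size=(n_out, n_hidden))
b2 = np.zeros(n_out)

y = mlp_forward(np.array([0.5, -1.0]), W1, b1, W2, b2)
print(y.shape)  # (1,)
```

Note that f is fixed: training only changes the Ws and bs.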

Despite being basic, it is extremely expressive. The universal approximation theorem guarantees that a sufficiently wide MLP can approximate any continuous function on a bounded domain to arbitrary accuracy, in the same spirit as smooth functions can be approximated by polynomials such as a truncated Taylor series.

To train the weights of the model, we use backpropagation, with gradients computed by automatic differentiation (autodiff). I won't go much into it here, but it is important to note that autodiff works on any differentiable function; this will be important later.
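To illustrate why autodiff works on any differentiable function, here is a toy forward-mode autodiff based on dual numbers. This is a swapped-in technique for illustration: deep learning frameworks train with reverse-mode autodiff (backpropagation), not this. The `Dual` class, `sin` and `derivative` helpers are hypothetical names. The point is that any composition of differentiable primitives gets an exact derivative without a hand-written formula.

```python
import math

class Dual:
    """A number value + deriv * eps with eps**2 == 0; deriv carries d/dx."""
    def __init__(self, value, deriv=0.0):
        self.value = value
        self.deriv = deriv

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.value + other.value, self.deriv + other.deriv)

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # Product rule: (uv)' = u'v + uv'
        return Dual(self.value * other.value,
                    self.deriv * other.value + self.value * other.deriv)

    __rmul__ = __mul__

def sin(x):
    # Chain rule: d/dx sin(u) = cos(u) * u'
    return Dual(math.sin(x.value), math.cos(x.value) * x.deriv)

def derivative(fn, x0):
    """Evaluate fn at x0 while propagating the derivative (seeded with 1)."""
    return fn(Dual(x0, 1.0)).deriv

g = lambda x: x * sin(x) + 3 * x
print(derivative(g, 2.0))  # equals sin(2) + 2*cos(2) + 3
```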

The Problem with MLPs

MLPs are used in a wide variety of use cases, but have some serious drawbacks.

  1. Because they are extremely flexible, they can fit almost any data very closely. As a result, they are prone to overfitting.
  2. There are often so many weights in the model that it becomes very difficult to interpret them and draw conclusions from the data. We often say deep learning models are "black boxes".
  3. Having lots of weights also makes them slow to train; roughly two-thirds of GPT-3's parameters are in its MLP layers.
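The GPT-3 figure can be checked with back-of-the-envelope arithmetic. This sketch assumes the published configuration of the largest GPT-3 model (96 layers, d_model = 12288, feed-forward width 4 × d_model, 50257-token vocabulary, 2048-token context) and ignores biases and layer norms. Attention has four d_model × d_model projections per layer, while the MLP has two d_model × 4·d_model matrices.

```python
# Rough parameter count for the largest GPT-3 model (biases and layer norms ignored).
n_layers = 96
d_model = 12288
d_ff = 4 * d_model
vocab_size = 50257
n_ctx = 2048

attention_per_layer = 4 * d_model * d_model      # Q, K, V and output projections
mlp_per_layer = 2 * d_model * d_ff               # up-projection and down-projection
embeddings = vocab_size * d_model + n_ctx * d_model

attention_total = n_layers * attention_per_layer
mlp_total = n_layers * mlp_per_layer
total = attention_total + mlp_total + embeddings

print(f"total ≈ {total / 1e9:.1f}B parameters")
print(f"MLP share ≈ {mlp_total / total:.0%}")
```

This lands at about 175B parameters in total, with about two-thirds of them in the MLP layers.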

Kolmogorov-Arnold Networks

Kolmogorov-Arnold Representation theorem

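The Kolmogorov-Arnold Representation Theorem is similar in goal to the Universal Approximation Theorem that underpins MLPs, but has a different premise. It states that any continuous multivariate function on a bounded domain can be written using only additions and compositions of continuous 1-dimensional functions:

$$f(x_1, \dots, x_n) = \sum_{q=0}^{2n} \Phi_q\left(\sum_{p=1}^{n} \phi_{q,p}(x_p)\right)$$

For example, dividing the components of a vector v = (x1, x2), with x1, x2 > 0, can be rewritten with logarithms and an exponentiation:

$$\frac{x_1}{x_2} = \exp\left(\log x_1 - \log x_2\right)$$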
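The division example is easy to check numerically. The sketch below (plain Python, assuming x1, x2 > 0 so the logarithms are defined; `divide_ka` is a hypothetical name) only ever applies 1-dimensional functions and addition: two inner functions, log x1 and -log x2, are summed, and one outer function, exp, is applied to the result.

```python
import math

def divide_ka(x1, x2):
    """Compute x1 / x2 using only 1-D functions and addition (requires x1, x2 > 0)."""
    inner_sum = math.log(x1) + (-math.log(x2))  # sum of univariate inner functions
    return math.exp(inner_sum)                  # univariate outer function

print(divide_ka(6.0, 3.0))  # 2.0 up to floating-point rounding
```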

Now, why exactly is this useful? What does this actually achieve?

This gives us a different but simple paradigm for building a neural network architecture. The authors claim that this architecture is more interpretable, more parameter-efficient and generalises better than MLPs. In MLPs, the non-linear functions are fixed and never change during training. In KANs, there are no weight matrices or biases in the MLP sense; instead, each connection carries a 1-dimensional non-linearity that is learned from the data. The outputs of these non-linearities are then added together at each node, and we can stack more and more layers to create more complex functions.
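The wiring of a single KAN layer can be sketched in a few lines of plain Python: every edge (input i to output j) has its own 1-D function φ_{j,i}, and each output node just sums its incoming edges, with no weight matrix or bias. In a real KAN the edge functions are learnable splines; here they are fixed, hand-picked functions (a hypothetical choice) purely to show the structure, and `kan_layer` is an illustrative name.

```python
import math

def kan_layer(x, edge_fns):
    """One KAN layer: out[j] = sum_i edge_fns[j][i](x[i])."""
    return [sum(phi(xi) for phi, xi in zip(row, x)) for row in edge_fns]

# Hypothetical 2-input, 1-output layer; in a trained KAN each phi would be a fitted spline.
edges = [[lambda t: math.sin(math.pi * t), lambda t: t ** 2]]
hidden = kan_layer([0.5, 2.0], edges)   # sin(pi/2) + 2**2 = 5
out = kan_layer(hidden, [[math.exp]])   # a second layer with a single edge: exp
print(hidden, out)
```

Stacking the two layers computes exp(sin(πx1) + x2²), an instance of the nested "sum of 1-D functions" form.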

B-splines

There is something important to note in the way non-linearities are represented in KANs. Unlike in MLPs, where they are fixed functions such as ReLU, tanh or SiLU, in KANs the authors use splines. These are essentially piecewise polynomials. They come from the field of computer graphics, where overparametrisation is not something to be concerned about.

Splines solve the problem of interpolating smoothly between multiple points. If you are familiar with machine learning theory, you will know that to interpolate perfectly between n data points, you need a polynomial of order n-1. The problem is that high-order polynomials can oscillate wildly between the points (Runge's phenomenon) and don't look smooth.

10 data points perfectly fitted by an order 9 polynomial. Image by author.
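The set-up in the figure can be reproduced with NumPy: 10 points and a degree-9 polynomial, which has exactly enough coefficients (10) to pass through every point. The data below, noisy samples of a sine wave, is an assumption for illustration and not the author's data.

```python
import numpy as np

rng = np.random.default_rng(0)
x = np.linspace(0.0, 1.0, 10)
y = np.sin(2 * np.pi * x) + 0.1 * rng.normal(size=x.size)  # illustrative noisy data

poly = np.polynomial.Polynomial.fit(x, y, deg=9)  # 10 coefficients for 10 points

x_fine = np.linspace(0.0, 1.0, 500)
print(np.max(np.abs(poly(x) - y)))              # essentially 0: the curve hits every point
print(poly(x_fine).min(), poly(x_fine).max())   # inspect how far the curve strays between points
```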

Splines tackle this issue by fitting piecewise polynomial functions to the sections between data points. Here we use cubic splines.

The cubic spline interpolates much better, but can't generalise. Image by author.

For cubic splines (one of the types of splines), constraints are set on the first and second derivatives at the data points (or knots) to ensure smoothness: the pieces of curve on either side of a data point must have matching first and second derivatives at that point.
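Those constraints turn into a small linear system for the second derivatives M at the knots. The NumPy sketch below builds a natural cubic spline, which assumes zero second derivative at both ends (the article does not say which end conditions were used), and `natural_cubic_spline` is a hypothetical helper name. The usage checks that the curve passes through the points and that the slope matches on both sides of a knot.

```python
import numpy as np

def natural_cubic_spline(x, y):
    """Return a function evaluating the natural cubic spline through (x, y).

    x must be strictly increasing; 'natural' means the second derivative is 0 at both ends.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    n = x.size - 1                  # number of intervals
    h = np.diff(x)

    # Solve for the second derivatives M at the knots; each interior row enforces
    # matching first derivatives on both sides of knot i.
    A = np.zeros((n + 1, n + 1))
    rhs = np.zeros(n + 1)
    A[0, 0] = A[n, n] = 1.0         # natural boundary conditions: M_0 = M_n = 0
    for i in range(1, n):
        A[i, i - 1] = h[i - 1]
        A[i, i] = 2.0 * (h[i - 1] + h[i])
        A[i, i + 1] = h[i]
        rhs[i] = 6.0 * ((y[i + 1] - y[i]) / h[i] - (y[i] - y[i - 1]) / h[i - 1])
    M = np.linalg.solve(A, rhs)

    def evaluate(t):
        t = np.asarray(t, dtype=float)
        i = np.clip(np.searchsorted(x, t) - 1, 0, n - 1)  # interval index for each t
        dl = t - x[i]                # distance from the left knot
        dr = x[i + 1] - t            # distance to the right knot
        hi = h[i]
        return (M[i] * dr**3 / (6 * hi) + M[i + 1] * dl**3 / (6 * hi)
                + (y[i] / hi - M[i] * hi / 6) * dr
                + (y[i + 1] / hi - M[i + 1] * hi / 6) * dl)

    return evaluate

xs = np.arange(10.0)
ys = np.sin(xs)                      # illustrative data
spline = natural_cubic_spline(xs, ys)
print(np.max(np.abs(spline(xs) - ys)))  # essentially 0: passes through every point
```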

KANs use B-splines, another type of spline which is local (moving one control point only changes the curve near that point, not its overall shape) and, in the cubic case, has continuous second derivatives (also called C2 continuity). This comes at the cost of not actually passing through the points (except at the extremities).

3 B-splines for 5 data points. From Wikipedia. Notice how the curves do not pass through the data points.
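A B-spline curve is a weighted sum of basis functions, each of which is non-zero only over k+1 neighbouring knot intervals; that is where the locality comes from. The sketch below evaluates the basis with the Cox-de Boor recursion, using a clamped knot vector (end knots repeated) so that the curve starts and ends on the first and last control points, as in the figure. The control values are made up for illustration and `bspline_basis` is a hypothetical helper name.

```python
import numpy as np

def bspline_basis(t, knots, k):
    """Evaluate all degree-k B-spline basis functions at points t (Cox-de Boor recursion).

    Returns an array of shape (len(t), len(knots) - k - 1). 0/0 terms are treated as 0.
    """
    t = np.asarray(t, dtype=float)[:, None]
    knots = np.asarray(knots, dtype=float)
    # Degree 0: indicator of each half-open knot interval [knots[i], knots[i+1]).
    B = ((t >= knots[:-1]) & (t < knots[1:])).astype(float)
    # Make the right end point belong to the last non-empty interval.
    last = np.nonzero(knots[:-1] < knots[1:])[0][-1]
    B[t[:, 0] == knots[-1], last] = 1.0
    for d in range(1, k + 1):
        left_den = knots[d:-1] - knots[:-d - 1]
        right_den = knots[d + 1:] - knots[1:-d]
        left = np.divide(t - knots[:-d - 1], left_den,
                         out=np.zeros((t.shape[0], left_den.size)), where=left_den > 0)
        right = np.divide(knots[d + 1:] - t, right_den,
                          out=np.zeros((t.shape[0], right_den.size)), where=right_den > 0)
        B = left * B[:, :-1] + right * B[:, 1:]
    return B

k = 3                                                   # cubic
control = np.array([0.0, 2.0, -1.0, 3.0, 1.0])          # illustrative control values
n = control.size
# Clamped knot vector: end knots repeated so the curve starts/ends on the end control points.
knots = np.concatenate([np.zeros(k), np.linspace(0.0, 1.0, n - k + 1), np.ones(k)])
t = np.linspace(0.0, 1.0, 201)
B = bspline_basis(t, knots, k)      # shape (201, 5): one column per control point
curve = B @ control
print(curve[0], curve[-1])          # starts at control[0], ends at control[-1]
```

In the interior the curve does not pass through the control points, matching the figure.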

In machine learning, especially when applied to physics, not passing through every data point is fine since we expect measurements to be noisy.

This is what happens on every edge of a KAN's computational graph: the 1-dimensional function on that edge is a B-spline curve whose coefficients are fitted to the data.

Into the KAN

So now we have a piecewise parametric curve on each edge of the computational graph. At each node, these get summed: we saw earlier that we can approximate any function by doing this.

KAN[2, 5, 1] (2 inputs, 5 hidden nodes and 1 output). Image from paper.

To train such a model, we can use standard backpropagation. In this case, the authors use L-BFGS, a quasi-Newton method that uses approximate second-order (curvature) information, whereas Adam is a first-order method. Another detail to note: on each edge, the 1-dimensional function is not just the B-spline; the authors also add a fixed SiLU function to it.

The explanation for this isn't clear, but it is most likely due to vanishing gradients (at least that's my guess).
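Putting the pieces of this section together, a forward pass through a KAN[2, 5, 1] can be sketched in NumPy. Each edge computes φ(x) = w·(silu(x) + Σ_c coef_c·B_c(x)), following the edge parameterisation described in the paper, and each node sums its incoming edges. Everything else (the uniform grid on [-1, 1], the initialisation scale, the class name `KANLayerSketch`) is an assumption for illustration; this is not pykan's implementation, which adds features such as grid updates.

```python
import numpy as np

def silu(x):
    """SiLU (a.k.a. swish): x * sigmoid(x)."""
    return x / (1.0 + np.exp(-x))

def bspline_basis(x, knots, k):
    """Cox-de Boor recursion on a uniform knot vector; shape (len(x), len(knots) - k - 1)."""
    x = np.asarray(x, dtype=float)[:, None]
    B = ((x >= knots[:-1]) & (x < knots[1:])).astype(float)
    for d in range(1, k + 1):
        B = ((x - knots[:-d - 1]) / (knots[d:-1] - knots[:-d - 1]) * B[:, :-1]
             + (knots[d + 1:] - x) / (knots[d + 1:] - knots[1:-d]) * B[:, 1:])
    return B

class KANLayerSketch:
    """Hypothetical n_in -> n_out KAN layer: edge (j, i) computes
    w[j, i] * (silu(x_i) + sum_c coef[j, i, c] * B_c(x_i)), and node j sums its edges."""

    def __init__(self, n_in, n_out, grid=5, k=3, rng=None):
        rng = np.random.default_rng(0) if rng is None else rng
        h = 2.0 / grid                                        # grid spacing on [-1, 1]
        # Uniform knots on [-1, 1], extended by k knots on each side -> grid + k basis functions.
        self.knots = np.linspace(-1 - k * h, 1 + k * h, grid + 2 * k + 1)
        self.k = k
        self.coef = 0.1 * rng.normal(size=(n_out, n_in, grid + k))  # spline coefficients
        self.w = np.ones((n_out, n_in))                             # per-edge scale

    def __call__(self, x):
        # basis: (batch, n_in, grid + k). Inputs outside the knot range get a zero spline term.
        basis = np.stack([bspline_basis(x[:, i], self.knots, self.k)
                          for i in range(x.shape[1])], axis=1)
        spline = np.einsum("bic,oic->boi", basis, self.coef)       # (batch, n_out, n_in)
        phi = self.w * (silu(x)[:, None, :] + spline)              # one activation per edge
        return phi.sum(axis=2)                                     # sum incoming edges per node

rng = np.random.default_rng(0)
layers = [KANLayerSketch(2, 5, rng=rng), KANLayerSketch(5, 1, rng=rng)]   # KAN[2, 5, 1]
x = rng.uniform(-1, 1, size=(4, 2))
for layer in layers:
    x = layer(x)
print(x.shape)  # (4, 1)
```

The SiLU term means every edge still passes a gradient even where the spline part is flat or zero, which is consistent with (but does not prove) the vanishing-gradient guess above.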

Let's use it

I am going to use the code that the authors provide (the pykan library). It works fantastically well and there are lots of examples to help you understand it better.

They use synthetic data, which they generate from the following function:

$$f(x_1, x_2) = \exp\left(\sin(\pi x_1) + x_2^2\right)$$
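In pykan this step is usually done with the `create_dataset` helper from `kan.utils`, which returns PyTorch tensors. As a dependency-light sketch of the same idea, the example function used in the pykan tutorials, f(x1, x2) = exp(sin(πx1) + x2²), can be sampled with NumPy. The sample counts and the [-1, 1] input range are assumptions meant to mirror that helper's defaults, and `make_dataset` is a hypothetical name.

```python
import numpy as np

def target(x):
    """f(x1, x2) = exp(sin(pi * x1) + x2**2), evaluated row-wise on an (n, 2) array."""
    return np.exp(np.sin(np.pi * x[:, 0]) + x[:, 1] ** 2)

def make_dataset(n_train=1000, n_test=1000, seed=0):
    """Sample inputs uniformly from [-1, 1]^2 and label them with the target function."""
    rng = np.random.default_rng(seed)
    x_train = rng.uniform(-1, 1, size=(n_train, 2))
    x_test = rng.uniform(-1, 1, size=(n_test, 2))
    # Same dictionary keys as pykan's datasets (train_input, train_label, ...).
    return {
        "train_input": x_train, "train_label": target(x_train)[:, None],
        "test_input": x_test, "test_label": target(x_test)[:, None],
    }

dataset = make_dataset()
print(dataset["train_input"].shape, dataset["train_label"].shape)  # (1000, 2) (1000, 1)
```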

Defining the model

from kan import KAN
model = KAN(width=[2,5,1], grid=5, k=3, seed=0)

This has four parameters defined:

  • width, which is defined similarly to how MLPs are defined: a list where each element corresponds to a layer and the element value is the width of that layer. In this case, there are 3 layers of nodes: the input dimension is 2, there are 5 hidden dimensions and the output dimension is 1.
  • grid, which relates to the B-splines: it is the number of grid intervals each spline is defined over. Increasing it gives each edge more flexibility and can create more wiggly functions.
  • k, the polynomial degree of the B-splines (k=3 means cubic). Cubic is a good choice in general, as cubic curves have nice properties for splines.
  • seed, the random seed: the spline coefficients are randomly initialised with Gaussian noise (just like the weights of a regular MLP).
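These settings determine the model size. A B-spline of degree k on a grid of G intervals has G + k basis functions, so each edge carries G + k spline coefficients, and consecutive layers are fully connected. The hypothetical helper below counts only the spline coefficients; implementations typically also keep a few extra per-edge scale parameters, which this sketch ignores.

```python
def kan_spline_params(width, grid, k):
    """Number of spline coefficients in a KAN with the given layer widths."""
    n_edges = sum(a * b for a, b in zip(width[:-1], width[1:]))  # fully connected between layers
    return n_edges * (grid + k)

print(kan_spline_params([2, 5, 1], grid=5, k=3))  # 15 edges x 8 coefficients = 120
```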

Training

model.train(dataset, opt="LBFGS", steps=20, lamb=0.01, lamb_entropy=10.0)

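The library's API is very intuitive: we can see that we are using the LBFGS optimiser for 20 steps of training. The next two parameters control regularisation of the network: lamb sets the overall strength of a sparsity penalty, and lamb_entropy weights the entropy part of that penalty.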
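As the paper describes it, the penalty is computed from each edge's average activation magnitude |φ|₁: an L1 term sums those magnitudes, and an entropy term is small when only a few edges carry most of the magnitude. The NumPy sketch below follows that description for a single layer; the function name, input layout and the `lamb_l1` weight are assumptions, and pykan's actual implementation differs in details.

```python
import numpy as np

def kan_layer_regulariser(edge_outputs, lamb=0.01, lamb_l1=1.0, lamb_entropy=10.0):
    """Sparsity penalty for one layer, following the paper's description.

    edge_outputs: array (batch, n_out, n_in) holding phi_{j,i}(x_i) for every sample.
    """
    l1_per_edge = np.mean(np.abs(edge_outputs), axis=0)   # |phi|_1, averaged over samples
    l1_layer = l1_per_edge.sum()                          # |Phi|_1 for the whole layer
    p = l1_per_edge / l1_layer                            # share of magnitude per edge
    entropy = -np.sum(p * np.log(p + 1e-12))              # low when few edges dominate
    return lamb * (lamb_l1 * l1_layer + lamb_entropy * entropy)

rng = np.random.default_rng(0)
dense = rng.normal(size=(100, 5, 2))     # every edge active
sparse = np.zeros_like(dense)
sparse[:, 0, 0] = dense[:, 0, 0]         # only one edge active
print(kan_layer_regulariser(dense), kan_layer_regulariser(sparse))
```

The sparse layer gets a much smaller penalty, which is what pushes training towards a few important edges that are later easy to prune and interpret.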

The next step after training is to prune the model: this removes edges and nodes whose relevance falls below a threshold. After pruning, it is recommended to retrain a bit. Then each spline edge is replaced by a symbolic function (log, exp, sin…), either manually or automatically. The library provides a fantastic tool for seeing what is going on inside the model: the model.plot() method.
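The pruning rule described in the paper works on the same per-edge magnitudes: a hidden node is kept only if both its strongest incoming edge and its strongest outgoing edge exceed a threshold (1e-2 in the paper's description). Below is a minimal NumPy sketch for the hidden layer of a KAN[2, 5, 1]; the magnitudes are made up and `important_hidden_nodes` is a hypothetical name.

```python
import numpy as np

def important_hidden_nodes(l1_in, l1_out, threshold=1e-2):
    """Return a boolean mask over hidden nodes.

    l1_in:  (n_hidden, n_in)  average |phi| of the edges entering each hidden node
    l1_out: (n_out, n_hidden) average |phi| of the edges leaving each hidden node
    """
    incoming = l1_in.max(axis=1)    # strongest incoming edge per hidden node
    outgoing = l1_out.max(axis=0)   # strongest outgoing edge per hidden node
    return (incoming > threshold) & (outgoing > threshold)

# Made-up magnitudes for KAN[2, 5, 1]: only hidden node 0 matters in both directions.
l1_in = np.array([[0.8, 0.5], [0.001, 0.002], [0.3, 0.0], [0.004, 0.009], [0.2, 0.1]])
l1_out = np.array([[0.9, 0.5, 0.003, 0.2, 0.005]])
print(important_hidden_nodes(l1_in, l1_out))
```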

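# Code to fit symbolic functions to the fitted splines
# (assumes `model` is a trained and pruned KAN)
mode = "auto"  # or "manual"
if mode == "manual":
    # manual mode: fix_symbolic(layer, input node, output node, function name)
    model.fix_symbolic(0, 0, 0, "sin")
    model.fix_symbolic(0, 1, 0, "x^2")
    model.fix_symbolic(1, 0, 0, "exp")
elif mode == "auto":
    # automatic mode: pick the best-fitting function from the library for every edge
    lib = ["x", "x^2", "x^3", "x^4", "exp", "log", "sqrt", "sin", "abs"]
    model.auto_symbolic(lib=lib)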

Once a symbolic function is in place on each edge, a final round of training fits the affine parameters of each edge (the scales and shifts applied inside and outside each symbolic function) so that they take sensible values.
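Fitting a symbolic function to a spline edge means finding affine parameters such that φ(x) ≈ c·g(a·x + b) + d for a candidate g. One simple way to do this, roughly what the paper describes, is a grid search over (a, b) combined with a closed-form linear regression for (c, d), scored by R². The grid ranges below are arbitrary assumptions, `fit_affine_symbolic` is a hypothetical name, and the "spline" is simulated by a known function plus noise.

```python
import numpy as np

def fit_affine_symbolic(x, y, g, a_grid, b_grid):
    """Fit y ≈ c * g(a * x + b) + d; return (r2, a, b, c, d) for the best (a, b) pair."""
    best = None
    ss_tot = np.sum((y - y.mean()) ** 2)
    for a in a_grid:
        for b in b_grid:
            X = np.column_stack([g(a * x + b), np.ones_like(x)])   # linear regression for c, d
            (c, d), *_ = np.linalg.lstsq(X, y, rcond=None)
            r2 = 1.0 - np.sum((y - X @ np.array([c, d])) ** 2) / ss_tot
            if best is None or r2 > best[0]:
                best = (r2, a, b, c, d)
    return best

rng = np.random.default_rng(0)
x = np.linspace(-1, 1, 200)
y = 2.0 * np.sin(3.0 * x + 0.5) - 1.0 + 0.01 * rng.normal(size=x.size)  # stand-in for a spline
a_grid = np.linspace(-5, 5, 101)
b_grid = np.linspace(-2, 2, 41)
candidates = {"sin": np.sin, "x^2": np.square, "exp": np.exp}
results = {name: fit_affine_symbolic(x, y, g, a_grid, b_grid) for name, g in candidates.items()}
for name, (r2, a, b, c, d) in results.items():
    print(f"{name}: R^2={r2:.4f}  a={a:.2f} b={b:.2f} c={c:.2f} d={d:.2f}")
```

The candidate with the highest R² (here sin) would be the one chosen for that edge.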

This whole training process is summarised in the diagram below.

An example of how to do symbolic regression with KAN. Image from paper.

Full training code looks something like this:

import torch
from kan import KAN
from kan.utils import create_dataset

# Synthetic data: f(x1, x2) = exp(sin(pi * x1) + x2^2)
f = lambda x: torch.exp(torch.sin(torch.pi * x[:, [0]]) + x[:, [1]] ** 2)
dataset = create_dataset(f, n_var=2)

# Define the model
model = KAN(width=[2, 5, 1], grid=5, k=3, seed=0)
# First training, with sparsity regularisation
# (early pykan API; newer releases rename model.train to model.fit)
model.train(dataset, opt="LBFGS", steps=20, lamb=0.01, lamb_entropy=10.0)
# Prune edges that have low importance
model = model.prune()
# Retrain the pruned model with no regularisation
model.train(dataset, opt="LBFGS", steps=50)
# Find the symbolic functions
model.auto_symbolic(lib=["x", "x^2", "x^3", "x^4", "exp", "log", "sqrt", "sin", "abs"])
# Find the affine parameters of the fitted functions without regularisation
model.train(dataset, opt="LBFGS", steps=50)
# Display the resulting equation
print(model.symbolic_formula()[0][0])

Some thoughts

Flakiness

There are quite a few hyperparameters that can be tuned in the model, and they can give very different results. In the example above, for instance, changing the number of hidden neurons from 5 to 6 means the KAN no longer finds the correct function.

Resulting function found by the KAN[2,6,1]. Image by author.

This flakiness is expected as the architecture is brand new. It took many decades for people to find the best way to tune MLP hyperparameters (such as learning rates, batch sizes, initialisations, etc).

Why not symbolic regression?

KANs aim to compete with symbolic regression as well as with MLPs. Comparing KANs to the PySR symbolic regression library, I found that PySR can also recover the correct functional form of the training data. However, it is even more sensitive to its settings than KANs are, a point the authors also make in their paper. In my case, changing the random seed of the PySR model made the resulting equation somewhat nonsensical. This was not the case with KANs, which were less sensitive to the random seed.

You can find all my experiments on my GitHub repo.

Conclusion

The MLP has been around for a long time and is long overdue for an upgrade. We know such a change is possible: around 6 years ago, LSTMs, which were ubiquitous in sequence modelling, were dethroned by transformers as the standard building block of language models. It would be exciting to see the same kind of change happen to MLPs. On the other hand, this architecture is still flaky and doesn't yet work reliably well. Time will tell whether the community finds a way around this flakiness and unlocks the true potential of KANs, or whether the KAN is forgotten and becomes a piece of machine learning trivia.

I am very excited by this new architecture, but I am also sceptical.

I lean more towards the excited side, as I work in the AI for science world, and the KAN was built for exactly the kinds of tasks I am interested in. Beyond science, it's quite possible that a few AI labs are already trying to replace the MLP layers in their LLMs with KANs.

References

https://arxiv.org/abs/2404.19756

https://en.wikipedia.org/wiki/Multilayer_perceptron

https://en.wikipedia.org/wiki/Universal_approximation_theorem

https://en.wikipedia.org/wiki/B-spline

https://www.youtube.com/watch?v=YMl25iCCRew&list=PLWfDJ5nla8UpwShx-lzLJqcp575fKpsSO
