Tutorial 4 - Classification Example

The previous tutorials worked through regression examples; this one turns to a binary classification problem.

[1]:
from jaxkan.models.KAN import KAN

import jax
import jax.numpy as jnp

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score

from flax import nnx
import optax

import matplotlib.pyplot as plt
import numpy as np

import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
[ ]:

Data Generation

We use the make_classification function from sklearn.datasets to generate a synthetic two-feature, two-class dataset for the classification problem.

[2]:
# Generate a reproducible synthetic binary-classification dataset:
# 1000 samples, 2 informative features, one cluster per class, class_sep=1.5
seed = 42

dataset_config = dict(
    n_samples=1000,
    n_features=2,
    n_informative=2,
    n_redundant=0,
    n_clusters_per_class=1,
    class_sep=1.5,
    random_state=seed,
)
X, y = make_classification(**dataset_config)
[3]:
# Scatter each class of the generated data with its own marker and colour.
# (class value, legend label, marker, colour)
class_styles = (
    (0, 'Class 0', 'x', '#a3630f'),
    (1, 'Class 1', 'o', '#25599c'),
)

plt.figure(figsize=(7, 4))
for cls, label, marker, color in class_styles:
    members = X[y == cls]
    plt.scatter(members[:, 0], members[:, 1], label=label, alpha=0.7, marker=marker, color=color)
plt.title("Synthetic Data for Binary Classification")
plt.xlabel("Feature 1")
plt.ylabel("Feature 2")
plt.legend()
plt.grid(alpha=0.3)
plt.show()
../_images/tutorials_Tutorial_4_-_Classification_Example_7_0.png
[ ]:

Preprocessing

We reshape the labels into a column vector and split the data into training (80%) and test (20%) sets.

[4]:
# Labels become a column vector of shape (n_samples, 1)
y = y.reshape(-1, 1)

# Hold out 20% of the samples for testing; reuse the seed for a reproducible split
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=seed
)

for caption, features in (("Training set size:", X_train), ("Test set size:", X_test)):
    print(caption, features.shape)
Training set size: (800, 2)
Test set size: (200, 2)
[ ]:

KAN Model

KAN model selection was covered in the first tutorial, so refer to it for more details. For this example, we use Fourier KAN layers.

[5]:
# Build a Fourier KAN: the input width comes from the features, the output
# width from the label column, with two equally sized hidden layers between.
n_in = X_train.shape[1]
n_out = y_train.shape[1]
n_hidden = 6

layer_dims = [n_in] + [n_hidden] * 2 + [n_out]
# Fourier-layer setting; D=10 appears as the last axis of c_cos/c_sin in the
# printed model (presumably the number of frequency terms — confirm in jaxkan docs)
req_params = {'D': 10}

model = KAN(
    layer_dims=layer_dims,
    layer_type='fourier',
    required_parameters=req_params,
    seed=seed,
)

print(model)
KAN( # Param: 1,093 (4.4 KB)
  layer_type='fourier',
  layers=List([
    FourierLayer( # Param: 246 (984 B)
      n_in=2,
      n_out=6,
      D=10,
      bias=Param( # 6 (24 B)
        value=Array(shape=(6,), dtype=dtype('float32'))
      ),
      c_cos=Param( # 120 (480 B)
        value=Array(shape=(6, 2, 10), dtype=dtype('float32'))
      ),
      c_sin=Param( # 120 (480 B)
        value=Array(shape=(6, 2, 10), dtype=dtype('float32'))
      )
    ),
    FourierLayer( # Param: 726 (2.9 KB)
      n_in=6,
      n_out=6,
      D=10,
      bias=Param( # 6 (24 B)
        value=Array(shape=(6,), dtype=dtype('float32'))
      ),
      c_cos=Param( # 360 (1.4 KB)
        value=Array(shape=(6, 6, 10), dtype=dtype('float32'))
      ),
      c_sin=Param( # 360 (1.4 KB)
        value=Array(shape=(6, 6, 10), dtype=dtype('float32'))
      )
    ),
    FourierLayer( # Param: 121 (484 B)
      n_in=6,
      n_out=1,
      D=10,
      bias=Param( # 1 (4 B)
        value=Array([0.], dtype=float32)
      ),
      c_cos=Param( # 60 (240 B)
        value=Array(shape=(1, 6, 10), dtype=dtype('float32'))
      ),
      c_sin=Param( # 60 (240 B)
        value=Array(shape=(1, 6, 10), dtype=dtype('float32'))
      )
    )
  ])
)
[ ]:

Training

[6]:
# Adam with a fixed step size of 1e-3; the optimizer tracks only nnx.Param variables
opt_type = optax.adam(learning_rate=1e-3)
optimizer = nnx.Optimizer(model, opt_type, wrt=nnx.Param)

For classification, the loss is the binary cross-entropy between the labels and the model's predicted probabilities (the sigmoid of its outputs).

[7]:
# Define train loop
@nnx.jit
def train_step(model, optimizer, X_train, y_train):
    """Run one full-batch optimisation step on the binary cross-entropy loss.

    Args:
        model: the KAN model; ``optimizer.update`` modifies its parameters.
        optimizer: ``nnx.Optimizer`` wrapping ``model``.
        X_train: input features, presumably shape (n_samples, n_in).
        y_train: 0/1 labels, shape (n_samples, 1) like the model output.

    Returns:
        The scalar mean loss, evaluated before the parameter update.
    """

    def loss_fn(model):
        logits = model(X_train)
        # Compute BCE directly from logits. optax uses log-sigmoid internally, so
        # the loss stays finite and keeps a useful gradient even when
        # sigmoid(logits) saturates to 0 or 1 in float32. The previous
        # `log(sigmoid(x) + 1e-8)` form capped the loss and zeroed those gradients.
        labels = y_train.astype(logits.dtype)
        return jnp.mean(optax.sigmoid_binary_cross_entropy(logits, labels))

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(model, grads)

    return loss
[8]:
# Initialize train_losses
num_epochs = 2000

# Collect per-epoch losses in a Python list. Calling `.at[epoch].set(...)` on a
# JAX array outside of jit copies the whole buffer every epoch, which makes the
# loop O(num_epochs^2) in copying.
epoch_losses = []
for epoch in range(num_epochs):
    # One full-batch optimisation step; returns the loss as a JAX scalar
    loss = train_step(model, optimizer, X_train, y_train)
    epoch_losses.append(loss)

# Convert once into a 1-D array of length num_epochs (an empty array if num_epochs == 0)
train_losses = jnp.array(epoch_losses)
[ ]:

Evaluation

[9]:
# Plot the per-epoch training loss on a logarithmic y-axis
loss_history = np.array(train_losses)

plt.figure(figsize=(7, 4))
plt.plot(
    loss_history,
    label='Train Loss',
    marker='o',
    color='#25599c',
    markersize=1,
)

plt.title('Training Loss Over Epochs')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.yscale('log')

plt.legend()
plt.grid(True, which='both', linestyle='--', linewidth=0.5)

plt.show()
../_images/tutorials_Tutorial_4_-_Classification_Example_24_0.png

Finally, the model is evaluated on the test set in terms of its F1 score. Predictions are obtained by thresholding the predicted probabilities at 0.5.

[10]:
# Predict on the held-out set: sigmoid(logit) > 0.5 is labelled class 1
logits = model(X_test)
y_pred = np.array((nnx.sigmoid(logits) > 0.5).astype(int))

# sklearn's signature is f1_score(y_true, y_pred). Pass the arguments in that
# order, and flatten the (n, 1) column vectors to the 1-D label arrays sklearn
# expects. Binary F1 is symmetric in its arguments, so the value is unchanged,
# but the correct order matters for warnings and for other metrics swapped in here.
score = f1_score(y_test.ravel(), y_pred.ravel())

print(f"The F1-Score of the fit is {100*score:.3f}%")
The F1-Score of the fit is 93.750%
[ ]: