Saturday, April 05, 2025

Some experiments to help me understand Neural Nets better, post 2 of N

In a previous post, I explained the intuition behind my "origami view of NNs" (also called the "polytope lens" in some circles). In this post, I will go a little bit into the mathematical details of my current thinking.

The standard textbook explanation of a layer of a neural network looks something like this: 

σ(Wx+b)

where σ: ℝ → ℝ is a nonlinearity (either the sigmoid or the ReLU or something like it), W is the matrix of weights attached to the edges coming into the neurons, and b is the vector of "biases". Personally, I find this notation somewhat cumbersome, and I prefer to pull the bias vector into the weight matrices, so that I can think of an NN as "matrix multiplications alternating with applying a nonlinearity".

I really don't like to think about NNs with nonlinearities other than ReLU and leaky ReLU - perhaps over time I will have to accept that these are a thing, but for now all NNs that I think about are either ReLU or leaky ReLU. For the moment, we also assume that the network outputs a real vector in the end, so it is not (yet) a classifier.
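To make the notation concrete, here is a minimal NumPy sketch of a single layer σ(Wx+b) with the two nonlinearities used throughout this post (an illustration of mine, not code from the actual experiments):

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def leaky_relu(z, slope=0.01):
    return np.where(z > 0, z, slope * z)

def layer(W, b, x, sigma=relu):
    """One layer of the network: sigma(Wx + b)."""
    return sigma(W @ x + b)
```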

Assume we have a network with k layers, and the number of neurons in each layer are n1,,nk. The network maps between real vector spaces (or an approximation thereof) of dimension i and o.
$$NN\colon \mathbb{R}^i \to \mathbb{R}^o$$
I would like to begin by pulling the bias vector into the matrix multiplications, because it greatly simplifies notation. So the input vector x gets augmented by appending a 1, and the bias vector b gets appended to W:
$$\overline{W} = \begin{bmatrix} W & b \end{bmatrix}, \qquad \overline{x} = \begin{bmatrix} x \\ 1 \end{bmatrix}$$
Instead of $\sigma(Wx+b)$ we can write $\sigma(\overline{W}\overline{x})$.
In our case, σ is always ReLU or leaky ReLU, so a "1" will be mapped to a "1" again. To be able to compose things nicely later, I would also like the output of $\sigma(\overline{W}\overline{x})$ to have a 1 as its last component, just like our augmented input vector $\overline{x}$. To achieve this, I append a row of all zeroes terminated by a 1 to $\overline{W}$. Finally we have:
$$\overline{W} = \begin{bmatrix} W & b \\ \mathbf{0} & 1 \end{bmatrix}, \qquad \overline{x} = \begin{bmatrix} x \\ 1 \end{bmatrix}$$
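In code, pulling the bias into the weight matrix looks roughly like this, a sketch of my own (assuming NumPy) that just checks the augmented form computes the same thing as the original layer:

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def augment(W, b):
    """Append b as an extra column and add the [0 ... 0 1] row so the
    trailing "1" of the augmented input survives the layer."""
    n_out, n_in = W.shape
    top = np.hstack([W, b.reshape(-1, 1)])   # [W | b]
    bottom = np.zeros((1, n_in + 1))
    bottom[0, -1] = 1.0                      # [0 ... 0 1]
    return np.vstack([top, bottom])

W = np.random.randn(3, 2)
b = np.random.randn(3)
x = np.random.randn(2)

W_bar = augment(W, b)
x_bar = np.append(x, 1.0)                    # x with a 1 appended

# The augmented layer agrees with the original one (dropping the trailing 1).
assert np.allclose(relu(W @ x + b), relu(W_bar @ x_bar)[:-1])
```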
The previous post explained why the NN divides the input space into polytopes on which the approximated function will be entirely linear. Consider a data point $x_1$. If you evaluate the NN on $x_1$, some of the ReLUs will light up (because their incoming data sums to more than 0) and some will not. So for a given $x_1$, there are k boolean vectors representing the activation (or non-activation) of each ReLU in the NN. This means we have a function which, for a given input vector, layer, and neuron index within the layer, returns either 0 or 1 in the ReLU case, or 0.01 or 1 in the leaky ReLU case.

We call this function a. We could make it a function with three arguments (layer, neuron index, input vector), but I prefer to move the layer and the neuron index into indices, so we have:
$$a_{l,n}\colon \mathbb{R}^i \to \{0, 1\} \quad \text{for ReLU}$$
and
$$a_{l,n}\colon \mathbb{R}^i \to \{0.01, 1\} \quad \text{for leaky ReLU}$$
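As a sketch (mine, not the code behind the experiments below), the activation pattern of a point can be read off during an ordinary forward pass through the augmented matrices; the resulting per-layer vectors also double as the "polytope identifier" discussed further down:

```python
import numpy as np

def activation_pattern(weights, x, leaky=False):
    """weights: list of augmented matrices W̄_l. Returns one vector per layer
    whose entries are a_{l,n}(x): 1 for active neurons, 0 (or 0.01 for
    leaky ReLU) for inactive ones."""
    off = 0.01 if leaky else 0.0
    v = np.append(x, 1.0)
    patterns = []
    for W in weights:
        z = W @ v
        a = np.where(z > 0, 1.0, off)
        a[-1] = 1.0                 # the appended constant "1" is never clamped
        patterns.append(a[:-1])     # pattern for the actual neurons of this layer
        v = a * z                   # apply the (leaky) ReLU and continue
    return patterns

def polytope_id(weights, x, leaky=False):
    """Hashable identifier: two points with the same id lie on the same polytope."""
    return tuple(tuple(p) for p in activation_pattern(weights, x, leaky))
```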
This gives us a very linear-algebra-ish expression for the entire network:
$$NN(x) \;=\; A_k \overline{W}_k \cdots A_1 \overline{W}_1 \, \overline{x} \;=\; \left( \prod_{l=k}^{1} A_l \overline{W}_l \right) \overline{x}$$
where the $A_l$ are of the form
$$A_l = \begin{pmatrix} a_{l,1}(x) & 0 & \cdots & 0 & 0 \\ 0 & a_{l,2}(x) & \cdots & 0 & 0 \\ \vdots & & \ddots & & \vdots \\ 0 & 0 & \cdots & a_{l,n_l}(x) & 0 \\ 0 & 0 & \cdots & 0 & 1 \end{pmatrix}$$
So we can now see very clearly that the moment the activation pattern is determined, the entire function becomes linear: just a series of matrix multiplications in which every second matrix is a diagonal matrix carrying the activation pattern on its diagonal.
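A small self-contained check of this claim (again a sketch under my own conventions, with random augmented matrices and plain ReLU): freeze the activation pattern of a point into the diagonal matrices A_l, multiply everything out, and verify that the resulting single matrix reproduces the forward pass.

```python
import numpy as np

def make_augmented(n_in, n_out):
    """Random augmented matrix [W | b] with the extra [0 ... 0 1] row."""
    W = np.random.randn(n_out, n_in + 1)
    W = np.vstack([W, np.zeros(n_in + 1)])
    W[-1, -1] = 1.0
    return W

def forward(weights, x):
    """Plain ReLU forward pass through the augmented matrices."""
    v = np.append(x, 1.0)
    for W in weights:
        v = np.maximum(W @ v, 0.0)   # the constant entry stays 1 via the [0 ... 0 1] row
    return v[:-1]

def frozen_linear_map(weights, x):
    """The single matrix A_k W̄_k ... A_1 W̄_1 for the activation pattern of x."""
    v = np.append(x, 1.0)
    M = np.eye(len(v))
    for W in weights:
        z = W @ v
        a = (z > 0).astype(float)    # ReLU activation pattern at this layer
        a[-1] = 1.0                  # keep the appended constant
        M = np.diag(a) @ W @ M       # left-multiply A_l W̄_l
        v = a * z
    return M

weights = [make_augmented(2, 5), make_augmented(5, 5), make_augmented(5, 1)]
x = np.random.randn(2)
M = frozen_linear_map(weights, x)
assert np.allclose(forward(weights, x), (M @ np.append(x, 1.0))[:-1])
```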

This representation shows us that the function remains identical (and linear) as long as the activation pattern does not change - points on the same polytope have an identical activation pattern, and we can hence use the activation pattern as a "polytope identifier": for any input point x I can run it through the network, and if a second point x' produces the same pattern, I know it lives on the same polytope.

From this, I can take the sort of movies that were created for single-layer NNs in part 1 - take an arbitrary 2-dimensional image as the unknown distribution we wish to learn, and then visualize the training dynamics: show how the input space is cut up into different polytopes on which the function is then linearly approximated, and how this partition and approximation evolve through the training process for differently-shaped networks.

We take input images of size 1024x1024, so one megabyte of byte-sized values, and sample 5000 data points from them - a small fraction, roughly 0.5% of the overall points in the image. We specify a shape for the MLP and train it for 6000 steps, visualizing progress.

For simplicity, we try to learn a black ring on a white background, with sharply delineated edges - first with a network that has 14 neurons per layer and is 6 layers deep.
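The post does not include the experiment code, so the following is only a rough sketch of how such a setup could look; the helper make_mlp, the Adam optimizer, the MSE loss, the learning rate, and the file name "ring.png" are all assumptions for illustration, and the polytope visualization itself is omitted.

```python
import numpy as np
import torch
import torch.nn as nn
from PIL import Image

def make_mlp(widths, leaky=True):
    """Build an MLP that maps (x, y) pixel coordinates to a grayscale value."""
    layers, d = [], 2
    for width in widths:
        layers.append(nn.Linear(d, width))
        layers.append(nn.LeakyReLU(0.01) if leaky else nn.ReLU())
        d = width
    layers.append(nn.Linear(d, 1))          # real-valued output, not a classifier
    return nn.Sequential(*layers)

# Load the target image and sample 5000 of its ~1M pixels as training data.
img = np.asarray(Image.open("ring.png").convert("L"), dtype=np.float32) / 255.0
h, w = img.shape                            # 1024 x 1024 in the post
idx = np.random.choice(h * w, 5000, replace=False)
ys, xs = np.unravel_index(idx, img.shape)
coords = torch.tensor(np.stack([xs / w, ys / h], axis=1), dtype=torch.float32)
targets = torch.tensor(img[ys, xs]).unsqueeze(1)

net = make_mlp([14] * 6)                    # 14 neurons per layer, 6 layers deep
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
for step in range(6000):
    opt.zero_grad()
    loss = nn.functional.mse_loss(net(coords), targets)
    loss.backward()
    opt.step()
```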

On the left-hand side, we see the evaluated NN with the boundaries of the polytopes that it has generated to split the input space. In the center, we see only the output of the NN - what the NN has "learnt" to reproduce so far. And on the right-hand side we see the original image; the tiny, barely perceptible red dots are the 5000 training points, and the blue dots are a validation set of 1000 points.

Here is a movie of the dynamics of the training run:

This is pretty neat - how about a differently-shaped NN? What happens if we force the NN through a 2-neuron bottleneck during the training process?
This network has 10 layers of 10 neurons, then one layer of 2 neurons, then another 3 layers of 10 neurons (sketched below). By number of parameters it is vaguely comparable to the other network, but it exhibits noticeably different training dynamics.
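In terms of the hypothetical make_mlp helper from the training sketch above, that shape would be specified roughly as:

```python
# Assumed layer list matching the description in the text: ten layers of ten
# neurons, a two-neuron bottleneck, then three more layers of ten neurons.
bottleneck_net = make_mlp([10] * 10 + [2] + [10] * 3)
```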

What happens if we dramatically overparametrize a network? Will it overfit our underlying data, and find a way to carve up the input space to reduce the error on the training set without reproducing a circle?

Let's try - how about a network with 20 neurons per layer, 40 layers deep? That should use something like 20k floating-point parameters in order to learn 5000 data points, so perhaps it will overfit?
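A rough back-of-the-envelope count (my arithmetic, assuming 2-dimensional inputs, 40 hidden layers of 20 neurons each, and a single linear output on top):

$$\underbrace{(2 \cdot 20 + 20)}_{\text{first layer}} + \underbrace{39 \cdot (20 \cdot 20 + 20)}_{\text{remaining layers}} + \underbrace{(20 + 1)}_{\text{output}} \approx 16{,}500 \text{ parameters,}$$

which is in the ballpark of the 20k quoted above, and far more than the 5000 training points.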
It turns out this network doesn't overfit, but it offers particularly rich dynamics as we watch it: around epoch 1000 we can see how the network seems to have the general shape of the circle figured out, and most polytope boundaries seem to migrate to this circle. The network wobbles a bit but seems to make headway. By epoch 2000 we think we have seen it all, and that the network will just consolidate around the circle. Between epoch 3000 and 4000 something breaks, the loss skyrockets, and it looks like the network is disintegrating and training is diverging. By epoch 4000 it has re-stabilized, but in a very different configuration of the input space partition. The video ends around epoch 5500.

This is quite fascinating. There is no sign of overfitting, but we can see how, as the network gets deeper, training gets less stable: the circle seems to wobble much more, and we get these strange catastrophic-seeming phase changes after which the network has to re-stabilize. It also appears that the network accurately captures the "circle" shape in spite of having only relatively few data points and more than enough capacity to overfit on them.

I will keep digging into this whenever time permits; I hope this was entertaining and/or informative. My next quest will be building a tool that - for a given point in input space - extracts the system of linear inequalities that describes the polytope this point lives on. Please do not hesitate to reach out if you ever wish to discuss any of this!
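As a rough teaser of the idea behind that tool (my own sketch, not the tool itself): every ReLU's on/off decision is a sign condition on a linear function of the augmented input, so collecting those sign conditions along a forward pass yields the inequalities describing the polytope.

```python
import numpy as np

def polytope_inequalities(weights, x):
    """For a ReLU net given as augmented matrices W̄_l, return (C, s) such that
    the polytope containing x is described by s_i * (C_i @ x̄) >= 0 for all i,
    where x̄ is the input with a 1 appended."""
    v = np.append(x, 1.0)
    M = np.eye(len(v))               # running linear map from x̄ to the current layer input
    rows, signs = [], []
    for W in weights:
        pre = W @ M                  # this layer's preactivations as linear functions of x̄
        z = W @ v
        a = (z > 0).astype(float)    # activation pattern at this layer
        a[-1] = 1.0                  # the appended constant is never clamped
        for n in range(len(z) - 1):  # one inequality per actual neuron
            rows.append(pre[n])
            signs.append(1.0 if z[n] > 0 else -1.0)
        M = np.diag(a) @ pre         # freeze the pattern and continue
        v = a * z
    return np.array(rows), np.array(signs)
```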
