Friday, July 11, 2025

Understand Neural Nets better, post 5 of N -- Code Assistant shootout

In a series of previous blog posts [1, 2, 3, 4] I ran some experiments drawing the boundaries of the polytopes generated by a fully-connected leaky ReLU network while it was being trained to reproduce an input image.

As I tried to scale the experiments to larger networks, I noticed a dramatic slowdown, caused by the hashing of the activation pattern happening on the CPU. Each training step would be fast, but then everything would grind to a halt for the visualisation: for each pixel, the code would forward-evaluate the NN (1024*1024 times in total), and once the prediction was calculated, it would transfer the activation pattern to the CPU and hash it there. This was very slow, and very non-parallel.
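To make the bottleneck concrete, the per-pixel work looked roughly like the sketch below. This is a simplified reconstruction from memory, not the actual script: the function and variable names are mine, and it assumes the net is an nn.Sequential of Linear/LeakyReLU layers.

```python
import hashlib
import torch

def polytope_id(model, xy):
    """One forward pass for a single pixel; hash the activation pattern on the CPU."""
    pre_acts = []
    hooks = [m.register_forward_hook(lambda _m, _i, out: pre_acts.append(out))
             for m in model if isinstance(m, torch.nn.Linear)]
    with torch.no_grad():
        model(xy)                                  # forward pass on the GPU
    for h in hooks:
        h.remove()
    # The sign of each pre-activation says which side of each ReLU crease the pixel is on.
    pattern = torch.cat([(p > 0).flatten() for p in pre_acts])
    # GPU -> CPU copy, then a Python-side hash -- repeated 1024*1024 times per visualisation.
    return hashlib.sha1(pattern.cpu().numpy().tobytes()).hexdigest()
```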

I had contemplated writing some custom CUDA code to speed things up. There is no reason to store the activation pattern or transfer it; the "right" way to solve the problem is to compute a hash on the fly, ideally a hash with a commutative update function so the order in which the different ReLU neurons update the hash doesn't matter.
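What I had in mind is roughly the following: give every neuron a fixed random coefficient and let each layer add coefficient * (pre-activation > 0) into a running per-sample hash, modulo a large prime, entirely on the GPU. Because the update is just a sum, the order of the contributions doesn't matter. The sketch below is my own illustration (class and variable names are mine, not from the repo), again assuming an nn.Sequential of Linear/LeakyReLU layers.

```python
import torch

class OnTheFlyHasher:
    """Accumulate a per-sample polytope hash during the forward pass, on the GPU."""

    def __init__(self, model, device="cuda", prime=2**31 - 1, seed=0):
        # A dedicated generator, so drawing the coefficients does not disturb
        # the global RNG that the training loop depends on.
        gen = torch.Generator(device=device).manual_seed(seed)
        self.prime = prime
        self.coeffs = {}
        self.hash = None
        self.handles = []
        for layer in model:
            if isinstance(layer, torch.nn.Linear):
                self.coeffs[layer] = torch.randint(
                    1, prime, (layer.out_features,), generator=gen, device=device)
                self.handles.append(layer.register_forward_hook(self._update))

    def reset(self):
        self.hash = None

    def _update(self, layer, _inputs, output):
        # (output > 0) is the activation pattern of the leaky ReLU that follows
        # this Linear layer; each neuron contributes coeff * indicator, and the
        # contributions are summed -- a commutative, order-independent update.
        signs = (output > 0).to(torch.int64)
        contrib = (signs * self.coeffs[layer]).sum(dim=1) % self.prime
        self.hash = contrib if self.hash is None else (self.hash + contrib) % self.prime
```

With something like this, one batched forward pass over all 1024*1024 pixel coordinates yields one integer per pixel identifying its polytope (up to a small collision probability), with no activation patterns stored and nothing shipped to the CPU until the final IDs.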

Then again, this is a hobby project, and I don't have the time to do anything overly smart for the moment. So I decided that, before doing anything sophisticated, I'd see whether one of the two coding assistants I use regularly could solve the problem for me.

So I created two different directories, checked out the same base repo into both, created a branch in each, and then asked both Gemini CLI and Claude Code to perform the task, using the following prompt:

The Python script in this directory trains a fully connected leaky ReLU network on an input image and tries 
to reproduce it. It also draws pictures illustrating the boundaries of the polytopes generated by the creases
that the ReLU creates in input space. Unfortunately, the code to generate the polytope visualisation is slow,
because it involves 1024*1024 evaluations of the NN forward, and then it needs to hash the activation pattern
into a hash to uniquely identify what polytope the pixel resides on.

I would like to speed up this computation, by - instead of calculating a hash of the activation pattern at the 
end - somehow embedding the calculation of a hash into the forward pass on-GPU. This might be doable with 
PyTorch hooks, but I don't know precisely. 

What I do know is that if I run 
```
python3 ./draw-poly-while-training.py  --input ./centered_ring.png --shape [100]*20 --epochs 30 --seed 12345678 --points 5050 --save-interval 10
``` 

the output looks something like this: 
```
(...)
Input size (MB): 0.01
Forward/backward pass size (MB): 16.39
Params size (MB): 0.77
Estimated Total Size (MB): 17.17
==========================================================================================
2025-07-08 15:15:25,811 - polytope_nn - INFO - Epoch 1/2000000 - Train Loss: 3.315190, Val Loss: 0.329414
2025-07-08 15:15:25,857 - polytope_nn - INFO - Epoch 2/2000000 - Train Loss: 1.045730, Val Loss: 0.065818
2025-07-08 15:15:25,901 - polytope_nn - INFO - Epoch 3/2000000 - Train Loss: 1.414065, Val Loss: 0.488735
2025-07-08 15:15:25,948 - polytope_nn - INFO - Epoch 4/2000000 - Train Loss: 0.201550, Val Loss: 0.102159
2025-07-08 15:15:26,100 - polytope_nn - INFO - Epoch 5/2000000 - Train Loss: 0.198983, Val Loss: 0.050712
2025-07-08 15:15:26,145 - polytope_nn - INFO - Epoch 6/2000000 - Train Loss: 0.255710, Val Loss: 0.060731
2025-07-08 15:15:26,189 - polytope_nn - INFO - Epoch 7/2000000 - Train Loss: 0.122960, Val Loss: 0.091274
2025-07-08 15:15:26,232 - polytope_nn - INFO - Epoch 8/2000000 - Train Loss: 0.180629, Val Loss: 0.053913
2025-07-08 15:15:26,276 - polytope_nn - INFO - Epoch 9/2000000 - Train Loss: 0.826762, Val Loss: 0.156673
2025-07-08 15:15:26,320 - polytope_nn - INFO - Epoch 10/2000000 - Train Loss: 0.211313, Val Loss: 0.117810
2025-07-08 15:16:27,853 - polytope_nn - INFO - Visualization @ epoch 10: 61.53s
2025-07-08 15:16:27,899 - polytope_nn - INFO - Epoch 11/2000000 - Train Loss: 0.174978, Val Loss: 0.053103
2025-07-08 15:16:27,943 - polytope_nn - INFO - Epoch 12/2000000 - Train Loss: 0.332561, Val Loss: 0.095801
2025-07-08 15:16:27,987 - polytope_nn - INFO - Epoch 13/2000000 - Train Loss: 0.192859, Val Loss: 0.064341
2025-07-08 15:16:28,031 - polytope_nn - INFO - Epoch 14/2000000 - Train Loss: 0.115424, Val Loss: 0.051763
2025-07-08 15:16:28,076 - polytope_nn - INFO - Epoch 15/2000000 - Train Loss: 0.362009, Val Loss: 0.128609
2025-07-08 15:16:28,122 - polytope_nn - INFO - Epoch 16/2000000 - Train Loss: 0.117143, Val Loss: 0.058641
2025-07-08 15:16:28,165 - polytope_nn - INFO - Epoch 17/2000000 - Train Loss: 0.335812, Val Loss: 0.082517
2025-07-08 15:16:28,211 - polytope_nn - INFO - Epoch 18/2000000 - Train Loss: 0.079342, Val Loss: 0.060753
2025-07-08 15:16:28,257 - polytope_nn - INFO - Epoch 19/2000000 - Train Loss: 0.104123, Val Loss: 0.047914
2025-07-08 15:16:28,304 - polytope_nn - INFO - Epoch 20/2000000 - Train Loss: 0.097466, Val Loss: 0.050452
2025-07-08 15:17:31,553 - polytope_nn - INFO - Visualization @ epoch 20: 63.25s
```

From this we can see that a single visualisation step takes more than a minute for a network of this size, and 
profiling shows that most of this time is spent in hashing things on the CPU, not the GPU.
I would like you to find a way to do the calculation of the hash during the forward pass on the GPU, ideally 
without storing the activation vector in memory, and instead having a hash function that can be updated
commutatively so each ReLU unit can update the final hash while it calculates the forward pass.

I want you to:

1) Create a plausible plan for improving and speeding up the code.
2) Implement that plan.
3) Re-run the script with the specified command line, and observe if a speedup indeed took place -- e.g. check
that (a) the visualisation was sped up and (b) the sum of 10 training steps and the visualisation together was
sped up.

It is frightfully easy to speed up the visualisation step but slow down the training steps so much that 10
training steps and 1 visualisation step get *slower*.

Please also verify that the image output is the same between the pre-change and post-change version, to ensure
that the changes do not break anything.
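(For concreteness: the output check I had in mind is nothing fancier than a pixel-exact comparison of the rendered images, along the lines of the snippet below, with made-up file names.)

```python
import numpy as np
from PIL import Image

before = np.asarray(Image.open("viz_epoch_10_before.png"))
after = np.asarray(Image.open("viz_epoch_10_after.png"))
assert before.shape == after.shape and np.array_equal(before, after), "outputs differ"
```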

I then allowed both models to churn for a while. Both provided changes, but Gemini failed to actually verify that the results were the same. Claude one-shotted the problem; Gemini needed the following additional prompt:

I have run your example code, and checked the output. The output images are not identical between the
pre-change and post-change version, and even the training loss changed. FWIW, none of the polytopes
are visible in your version. Could you re-check your work, and this time make sure you check whether
the outputs are the same?

With that extra prodding, the solution Gemini provided worked flawlessly, and was even a tiny bit faster than the Claude version.

Let's look at the code that both models generated: The Gemini branch and the Claude branch. Reading the changes, a few things become clear:

  1. Gemini shot itself in the foot on the RNG: generating a bunch of random hash coefficients from the global RNG changed its state, so the training runs were no longer comparable pre/post change.
  2. Gemini uses torch.matmul for the hash computation, whereas Claude computes the hash as torch.sum(A * B); the two styles are compared in the sketch after this list.
  3. Claude broke the code up into more, smaller functions, whereas Gemini didn't. Claude's code is mildly more readable; Gemini's is the more minimal change.
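To make point 2 concrete, here is a toy comparison (shapes and seeds are made up); it also shows the dedicated-generator trick that would have avoided the RNG problem from point 1:

```python
import torch

batch, width = 4, 100
pre_act = torch.randn(batch, width)            # pre-activations of one hidden layer
signs = (pre_act > 0).to(torch.float32)        # the activation pattern

gen = torch.Generator().manual_seed(0)         # dedicated generator: drawing the
coeffs = torch.randn(width, generator=gen)     # coefficients leaves the global RNG alone

h_matmul = torch.matmul(signs, coeffs)         # Gemini-style: matrix-vector product
h_sum = torch.sum(signs * coeffs, dim=1)       # Claude-style: elementwise multiply + reduce
assert torch.allclose(h_matmul, h_sum)         # same per-sample value either way
```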
Interesting stuff. Neither solution is quite what I had in mind, but they are good enough for the moment, and they provide a pretty significant speedup over the (also vibe-coded) code I started out with. This is the first time a coding assistant has helped me optimize code in a nontrivial manner, and that's ... certainly something.

Anyhow, with these optimizations I can now run my data visualisation movie generation on slightly larger NNs with millions of parameters, so more studying is ahead. I now need to figure out how to upload YouTube videos programmatically, but in the meantime, here is a video of training a 100-neuron, 10-layer-deep network on the "circle drawing" task from my previous posts. Vibe coding randomly changed the color of my lines, but hey, that's ok.

As per usual, there are more questions than answers in this video. The thing that puzzles me most is the relative "instability" of the training in later epochs. It shows up as "flickers": seemingly at random, an SGD step hits a vastly higher loss, parts of the screen turn black, the loss spikes, and the training then needs to recover. Interestingly, the geometry of the polytopes doesn't change much in these situations, but the linear functions on many of them change at once, in a way that is very detrimental to overall performance. Once programmatic uploading works, I'll upload many more videos, because one of the intriguing observations I have is the following:

When training diverges (for larger and deeper nets), the divergence starts by first messing up the linear functions; only after those are gloriously messed up does the geometry of the polytopes start to go haywire, too.

Until then!





