In a series of previous blog posts [1, 2, 3, 4], I ran experiments drawing the boundaries of the polytopes generated by a fully-connected leaky ReLU network while it was being trained to reproduce an input image.
As I tried to scale the experiments to larger networks, I noticed a dramatic slowdown, caused by the hash of the activation pattern being computed on the CPU: each training step was fast, but everything would grind to a halt for the visualisation. For each pixel, the code would forward-evaluate the NN (1024*1024 times in total), and after each prediction it would transfer the activation pattern to the CPU and hash it there. This was very slow, and very non-parallel.
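To give a feel for the bottleneck, here is a much-simplified sketch of that slow path (the function and variable names are illustrative, not the actual code from the script): one forward pass per pixel, a device-to-host copy of the activation pattern, and then Python-level hashing.

```python
import torch
import torch.nn as nn

def polytope_ids_slow(model: nn.Sequential, resolution: int = 1024, device: str = "cuda"):
    """One forward pass per pixel; the activation pattern is copied to the CPU
    and hashed there, which is slow and entirely serial."""
    ids = {}
    for y in range(resolution):
        for x in range(resolution):
            p = torch.tensor([[x / resolution, y / resolution]], device=device)
            signs = []
            h = p
            for layer in model:
                h = layer(h)
                if isinstance(layer, nn.LeakyReLU):
                    signs.append(h > 0)          # which side of each crease the pixel is on
            pattern = torch.cat(signs, dim=1).cpu().numpy().tobytes()  # GPU -> CPU copy
            ids[(x, y)] = hash(pattern)          # CPU hashing, once per pixel
    return ids
```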
I had contemplated writing some custom CUDA code to speed things up: there's no reason to store the activation pattern or transfer it at all. The "right" way to solve the problem is to compute a hash on the fly, ideally one with a commutative update function, so the order in which the different ReLU neurons update the hash doesn't matter.
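Something along these lines is what I had in mind. This is only a sketch, assuming the network is an nn.Sequential of alternating Linear and LeakyReLU layers; the coefficient scheme and hook wiring are illustrative rather than anything that ended up in the repo. Each LeakyReLU layer folds its sign pattern into a running per-pixel hash on the GPU, and because the update is a plain sum, the order in which the layers report in doesn't matter.

```python
import torch
import torch.nn as nn

def attach_gpu_hash(model: nn.Sequential, device: str = "cuda", seed: int = 0):
    """Register forward hooks so every LeakyReLU layer adds its contribution to a
    running per-sample hash on the GPU. Returns (reset, read) helpers."""
    gen = torch.Generator().manual_seed(seed)   # dedicated generator: the global RNG stays untouched
    state = {"hash": None}

    def make_hook(coeffs):
        def hook(module, inputs, output):
            # Commutative update: sum of per-neuron coefficients where the unit is "on".
            contribution = ((output > 0).to(torch.int64) * coeffs).sum(dim=-1)
            state["hash"] = contribution if state["hash"] is None else state["hash"] + contribution
        return hook

    for linear, act in zip(model[:-1], model[1:]):
        if isinstance(linear, nn.Linear) and isinstance(act, nn.LeakyReLU):
            coeffs = torch.randint(1, 2**62, (linear.out_features,), generator=gen).to(device)
            act.register_forward_hook(make_hook(coeffs))

    def reset():
        state["hash"] = None

    def read():  # one int64 hash per sample in the batch
        return state["hash"]

    return reset, read
```

The hooks only need to be active during the visualisation pass (they cost a little on every forward call), and since the coefficients are 62-bit random integers, collisions between distinct activation patterns are possible in principle but extremely unlikely.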
Then again, this is a hobby project, and I don't have the time to do anything overly smart at the moment. So I decided that, before doing anything sophisticated, I'd see whether one of the two coding assistants I use regularly could solve the problem for me.
So I created two different directories, checked out the same base repo into both, created branches in both, and then asked both Gemini CLI and Claude Code to perform the task, using the following prompt:
The Python script in this directory trains a fully connected leaky ReLU network on an input image and tries to reproduce it. It also draws pictures illustrating the boundaries of the polytopes generated by the creases that the ReLU creates in input space. Unfortunately, the code to generate the polytope visualisation is slow, because it involves 1024*1024 evaluations of the NN forward, and then it needs to hash the activation pattern into a hash to uniquely identify what polytope the pixel resides on. I would like to speed up this computation, by - instead of calculating a hash of the activation pattern at the end - somehow embedding the calculation of a hash into the forward pass on-GPU. This might be doable with PyTorch hooks, but I don't know precisely. What I do know is that if I run

```
python3 ./draw-poly-while-training.py --input ./centered_ring.png --shape [100]*20 --epochs 30 --seed 12345678 --points 5050 --save-interval 10
```

the output looks something like this:

```
(...)
Input size (MB): 0.01
Forward/backward pass size (MB): 16.39
Params size (MB): 0.77
Estimated Total Size (MB): 17.17
==========================================================================================
2025-07-08 15:15:25,811 - polytope_nn - INFO - Epoch 1/2000000 - Train Loss: 3.315190, Val Loss: 0.329414
2025-07-08 15:15:25,857 - polytope_nn - INFO - Epoch 2/2000000 - Train Loss: 1.045730, Val Loss: 0.065818
2025-07-08 15:15:25,901 - polytope_nn - INFO - Epoch 3/2000000 - Train Loss: 1.414065, Val Loss: 0.488735
2025-07-08 15:15:25,948 - polytope_nn - INFO - Epoch 4/2000000 - Train Loss: 0.201550, Val Loss: 0.102159
2025-07-08 15:15:26,100 - polytope_nn - INFO - Epoch 5/2000000 - Train Loss: 0.198983, Val Loss: 0.050712
2025-07-08 15:15:26,145 - polytope_nn - INFO - Epoch 6/2000000 - Train Loss: 0.255710, Val Loss: 0.060731
2025-07-08 15:15:26,189 - polytope_nn - INFO - Epoch 7/2000000 - Train Loss: 0.122960, Val Loss: 0.091274
2025-07-08 15:15:26,232 - polytope_nn - INFO - Epoch 8/2000000 - Train Loss: 0.180629, Val Loss: 0.053913
2025-07-08 15:15:26,276 - polytope_nn - INFO - Epoch 9/2000000 - Train Loss: 0.826762, Val Loss: 0.156673
2025-07-08 15:15:26,320 - polytope_nn - INFO - Epoch 10/2000000 - Train Loss: 0.211313, Val Loss: 0.117810
2025-07-08 15:16:27,853 - polytope_nn - INFO - Visualization @ epoch 10: 61.53s
2025-07-08 15:16:27,899 - polytope_nn - INFO - Epoch 11/2000000 - Train Loss: 0.174978, Val Loss: 0.053103
2025-07-08 15:16:27,943 - polytope_nn - INFO - Epoch 12/2000000 - Train Loss: 0.332561, Val Loss: 0.095801
2025-07-08 15:16:27,987 - polytope_nn - INFO - Epoch 13/2000000 - Train Loss: 0.192859, Val Loss: 0.064341
2025-07-08 15:16:28,031 - polytope_nn - INFO - Epoch 14/2000000 - Train Loss: 0.115424, Val Loss: 0.051763
2025-07-08 15:16:28,076 - polytope_nn - INFO - Epoch 15/2000000 - Train Loss: 0.362009, Val Loss: 0.128609
2025-07-08 15:16:28,122 - polytope_nn - INFO - Epoch 16/2000000 - Train Loss: 0.117143, Val Loss: 0.058641
2025-07-08 15:16:28,165 - polytope_nn - INFO - Epoch 17/2000000 - Train Loss: 0.335812, Val Loss: 0.082517
2025-07-08 15:16:28,211 - polytope_nn - INFO - Epoch 18/2000000 - Train Loss: 0.079342, Val Loss: 0.060753
2025-07-08 15:16:28,257 - polytope_nn - INFO - Epoch 19/2000000 - Train Loss: 0.104123, Val Loss: 0.047914
2025-07-08 15:16:28,304 - polytope_nn - INFO - Epoch 20/2000000 - Train Loss: 0.097466, Val Loss: 0.050452
2025-07-08 15:17:31,553 - polytope_nn - INFO - Visualization @ epoch 20: 63.25s
```

From this we can see that a single visualisation step takes more than a minute for a network of this size, and profiling shows that most of this time is spent in hashing things on the CPU, not the GPU. I would like you to find a way to do the calculation of the hash during the forward pass on the GPU, ideally without storing the activation vector in memory, and instead having a hash function that can be updated commutatively so each ReLU unit can update the final hash while it calculates the forward pass. I want you to:

1) Create a plausible plan for improving and speeding up the code.
2) Implement that plan.
3) Re-run the script with the specified command line, and observe if a speedup indeed took place -- e.g. check that (a) the visualisation was sped up and (b) the sum of 10 training steps and the visualisation together was sped up. It is frightfully easy to speed up the visualisation step but slow down the training steps so much that 10 training steps and 1 visualisation step get *slower*.

Please also verify that the image output is the same between the pre-change and post-change version, to ensure that the changes do not break anything.
I then let both models churn for a while. Both provided changes, but Gemini failed to actually verify that the results were the same. Claude one-shotted the problem; Gemini needed the following additional prompt:
I have run your example code, and checked the output. The output images are not identical between the pre-change and post-change version, and even the training loss changed. FWIW, none of the polytopes are visible in your version. Could you re-check your work, and this time make sure you check whether the outputs are the same?
With that extra prodding, Gemini's solution worked flawlessly, and was even a tiny bit faster than the Claude version.
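The pre/post check itself is nothing fancy; something along the following lines (with placeholder file names, not the script's actual output paths) is enough to confirm that the saved visualisations are bit-identical.

```python
import numpy as np
from PIL import Image

# Placeholder file names; the real script writes its own output paths.
before = np.asarray(Image.open("viz_epoch_10_before.png"))
after = np.asarray(Image.open("viz_epoch_10_after.png"))

print("identical:", before.shape == after.shape and np.array_equal(before, after))
```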
Let's look at the code that both models generated: the Gemini branch and the Claude branch. Reading the changes, a few things become clear:
- Gemini shot itself in the foot on the RNG: generating a bunch of random hash coefficients advanced the global RNG state, so the training runs were no longer comparable pre/post change (a dedicated torch.Generator avoids this; see the sketch after this list).
- Gemini uses torch.matmul for the hash computation, whereas Claude computes the hash as torch.sum(A * B); the two are just different ways of writing the same dot product.
- Claude broke the code up into more, smaller functions, whereas Gemini didn't. Claude's code is mildly more readable; Gemini's is the more minimal change.
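To make the last two points concrete, here is a minimal, self-contained sketch; the coefficient scheme and tensor shapes are made up for illustration and not lifted from either branch. The coefficients come from a dedicated torch.Generator, so drawing them never advances the global RNG state that the training run depends on, and the two hash formulations, a matrix-vector product versus an elementwise product followed by a sum, give exactly the same result.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# A dedicated generator keeps the global RNG state untouched, so the training
# run stays bit-for-bit comparable before and after the change.
gen = torch.Generator().manual_seed(0)
coeffs = torch.randint(1, 2**31, (100,), generator=gen).to(device=device, dtype=torch.float64)

# Stand-in for the per-pixel sign pattern of one layer (1024 pixels, 100 neurons).
pattern = (torch.randn(1024, 100, generator=gen) > 0).to(device=device, dtype=torch.float64)

h_matmul = pattern @ coeffs                     # matmul formulation (as in the Gemini branch)
h_sum    = torch.sum(pattern * coeffs, dim=-1)  # elementwise product and sum (as in the Claude branch)

# With 0/1 patterns and coefficients below 2**31, every partial sum is an integer
# below 2**53, so float64 represents it exactly and the two agree bit-for-bit.
assert torch.equal(h_matmul, h_sum)
```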
When training diverges (for larger and deeper nets), the divergence starts by messing up the linear functions, and only after those are gloriously messed up does the geometry of the polytopes start to go haywire, too.
Until then!