
Conversation

@thib-s (Contributor) commented Sep 1, 2025

This is my attempt to improve the speed of the Newton-Schulz algorithm by making it converge with only four iterations.

I want to highlight that this approach changes the underlying algorithm. Extra verification may be desirable before merging the PR. Any tests and comments are welcome!

Changes

Fewer iterations:

We remove the previous normalization and switch to AOL rescaling, which is explained in the paper: https://arxiv.org/pdf/2208.03160

This consists of computing W@W^t using ns_line_1, then computing the scaling factors as fast_inv_sqrt(reduce_sum(abs(WW^t), axis=-1)), which is a vector.

Since the main operation needed to compute these factors is exactly ns_line_1, we can fuse the rescaling with the first Newton-Schulz iteration. This also gives a better starting point for the Newton-Schulz iterations, as the rescaled matrix is closer to orthogonal.
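As a hedged sketch of the rescaling step (the function name `aol_rescale` is mine, not from the PR; the actual code fuses this with the first iteration rather than calling it separately), the scaling-factor computation described above looks like:

```python
import numpy as np

def aol_rescale(W, eps=1e-12):
    # AOL rescaling: compute W @ W^T (the same product as ns_line_1),
    # then per-row scaling factors 1/sqrt(sum_j |(W W^T)_ij|).
    # The rescaled matrix D @ W is guaranteed to have spectral norm <= 1.
    WWt = W @ W.T
    d = 1.0 / np.sqrt(np.abs(WWt).sum(axis=-1) + eps)  # fast_inv_sqrt of row sums
    return d[:, None] * W
```

The spectral-norm bound follows from the AOL lemma in the linked paper, which is why the rescaled matrix is a better starting point than a plain norm-based normalization.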

Thanks to this, we can save one iteration of Newton-Schulz. However, the non-linear nature of AOL prevents the use of Jiacheng's approach to computing new polynomial factors, so we rely on a genetic algorithm to optimize them (see https://github.com/thib-s/flash-newton-schulz).
This is done in opt_params.py, which can be run to find better polynomials.
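For context, the iteration whose coefficients are being optimized has the quintic Newton-Schulz form used in Muon. Below is a minimal NumPy sketch with the widely used fixed coefficients (3.4445, -4.7750, 2.0315); these are illustrative stand-ins, not the per-step coefficients found by the PR's genetic algorithm, and the Frobenius normalization shown is the baseline that the PR replaces with AOL rescaling:

```python
import numpy as np

def newton_schulz(G, steps=5, coeffs=(3.4445, -4.7750, 2.0315)):
    # Quintic Newton-Schulz orthogonalization (Muon-style baseline).
    a, b, c = coeffs
    X = G / (np.linalg.norm(G) + 1e-7)  # Frobenius normalization (replaced by AOL in the PR)
    for _ in range(steps):
        A = X @ X.T            # ns_line_1
        B = b * A + c * (A @ A)  # ns_line_2
        X = a * X + B @ X      # ns_line_3
    return X
```

After a few iterations the singular values of X are driven toward 1; the PR's contribution is reaching a comparable result in four iterations instead of five.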

Triton kernel for ns_line_3:

I noticed that the ns_line_3 function was reading X multiple times, so I wrote a Triton kernel to avoid loading the same data more than once. This gives a marginal speedup on small matrices, where data loading is the bottleneck. (It can be removed if code readability is preferred.)
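The Triton kernel itself is hardware-specific, but the fusion idea can be illustrated in plain NumPy. Assuming ns_line_3 has the form a*X + B@X (as in the reference Muon implementation; this is my assumption, not taken from the PR's diff), the naive expression streams X twice, while folding the scalar into B's diagonal turns it into a single matmul that reads X once:

```python
import numpy as np

def ns_line_3_fused(a, B, X):
    # a*X + B @ X reads X twice (once for the scaling, once for the matmul).
    # Rewriting it as (B + a*I) @ X streams X only once.
    B = B.copy()
    B[np.diag_indices_from(B)] += a  # B + a*I
    return B @ X
```

A dedicated Triton kernel goes further by fusing the scale-and-accumulate into the matmul epilogue, which is where the savings on small, bandwidth-bound matrices come from.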

Tests

Tests on the 160m training script show no regression:

[screenshot: 160m training run, 2025-09-25]

While these results are very promising, I cannot test at larger scale. It would be great if someone could confirm the absence of regression on larger runs.

Current results:

Using an L40S GPU, we obtain a decent speedup:

[speedup graph]

When tested on random uniform matrices, the resulting matrices are closer to orthogonal:

[orthogonality graph]

Extra tests also showed stable results on heavy-tailed (Lévy) distributions.

@thib-s thib-s marked this pull request as ready for review September 25, 2025 14:19