faster NS algorithm (hybrid with 4 iterations) #10
This is my attempt to improve the speed of the Newton-Schulz algorithm by making it converge with only four iterations.
I want to highlight that this approach changes the underlying algorithm. Extra verification may be desirable before merging the PR. Any tests and comments are welcome!
Changes
Fewer iterations:
We remove the previous normalization and switch to AOL rescaling, which is further explained in this paper: https://arxiv.org/pdf/2208.03160
This consists of computing W@W^t using ns_line_1 and then computing the scaling factors fast_inv_sqrt(reduce_sum(abs(W@W^t), axis=-1)), which is a vector. Since the main operation needed to compute these factors corresponds to ns_line_1, we can fuse it with the first Newton-Schulz iterate. Furthermore, this gives a better starting point for the Newton-Schulz iterations, as the matrix is closer to orthogonal. Thanks to this, we can save one iteration of Newton-Schulz. However, the non-linear nature of AOL prevents the use of Jiacheng's approach to computing new polynomial factors, so we rely on a genetic algorithm to optimize them (see https://github.com/thib-s/flash-newton-schulz). This is done in the file opt_params.py, which can be run to find better polynomials. A minimal sketch of the fused rescaling is shown below.
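For illustration, here is a minimal sketch of the fused step (not the exact code of this PR): the coefficients a, b, c stand in for the optimized first-iteration polynomial, and X is assumed to have at least as many columns as rows, as in the usual Newton-Schulz setup.

```python
import torch

def fused_aol_first_ns_step(X: torch.Tensor, a: float, b: float, c: float) -> torch.Tensor:
    # A = X @ X^T is needed both for the AOL scaling factors and for the
    # quintic Newton-Schulz step, so it is computed only once.
    A = X @ X.mT
    # AOL scaling factors: s_i = 1 / sqrt(sum_j |A_ij|).
    # Rescaling the rows of X by s bounds its spectral norm by 1,
    # replacing the previous Frobenius-norm normalization.
    s = A.abs().sum(dim=-1).clamp_min(1e-12).rsqrt()
    X = s[..., :, None] * X
    A = s[..., :, None] * A * s[..., None, :]  # keep A consistent with the rescaled X
    # First (quintic) Newton-Schulz iteration on the rescaled matrix.
    B = b * A + c * A @ A
    return a * X + B @ X
```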
Triton kernel for ns_line_3:
I noticed that the ns_line_3 function was reading X multiple times, so I wrote a Triton kernel to avoid loading the same data more than once. This gives a marginal speedup on small matrices, where loading data is the bottleneck. (It can be removed for increased code readability.) A simplified sketch of the kernel is shown below.
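For illustration, here is a simplified sketch of what such a kernel can look like (block sizes, names, and the launch wrapper are illustrative, not the exact code of this PR): it fuses the a * X term into the epilogue of the B @ X matmul, so no separate elementwise pass over X and the intermediate result is needed.

```python
import torch
import triton
import triton.language as tl

@triton.jit
def ns_line_3_kernel(b_ptr, x_ptr, out_ptr, M, N,
                     stride_bm, stride_bk, stride_xk, stride_xn,
                     stride_om, stride_on, alpha,
                     BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    # Each program computes one BLOCK_M x BLOCK_N tile of  out = alpha * X + B @ X,
    # where B is (M, M) and X is (M, N).
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, M, BLOCK_K):  # B is square, so the reduction length is M
        b_mask = (offs_m[:, None] < M) & ((k + offs_k)[None, :] < M)
        x_mask = ((k + offs_k)[:, None] < M) & (offs_n[None, :] < N)
        b_tile = tl.load(b_ptr + offs_m[:, None] * stride_bm + (k + offs_k)[None, :] * stride_bk,
                         mask=b_mask, other=0.0)
        x_tile = tl.load(x_ptr + (k + offs_k)[:, None] * stride_xk + offs_n[None, :] * stride_xn,
                         mask=x_mask, other=0.0)
        acc = tl.dot(b_tile, x_tile, acc)
    # Fused epilogue: add alpha * X to the tile before it is written back,
    # instead of launching a separate elementwise kernel that re-reads X and the matmul output.
    out_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    x_out = tl.load(x_ptr + offs_m[:, None] * stride_xk + offs_n[None, :] * stride_xn,
                    mask=out_mask, other=0.0)
    acc += alpha * x_out.to(tl.float32)
    tl.store(out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on,
             acc.to(out_ptr.dtype.element_ty), mask=out_mask)

def ns_line_3(B: torch.Tensor, X: torch.Tensor, alpha: float) -> torch.Tensor:
    # out = alpha * X + B @ X, with B square.
    M, N = X.shape
    out = torch.empty_like(X)
    grid = (triton.cdiv(M, 64), triton.cdiv(N, 64))
    ns_line_3_kernel[grid](B, X, out, M, N,
                           B.stride(0), B.stride(1), X.stride(0), X.stride(1),
                           out.stride(0), out.stride(1), alpha,
                           BLOCK_M=64, BLOCK_N=64, BLOCK_K=64)
    return out
```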
Tests
Tests on the 160m training script do not show any direct regression.
While this is very promising, I cannot test it at a larger scale. It would be great if someone could confirm the absence of regression at larger scales.
Current results:
Using an L40S GPU, we obtain a decent speedup.
When tested on random uniform matrices, the outputs seem closer to orthogonal.
Extra tests also showed stable results on heavy-tailed distributions (Levy).
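For anyone who wants to reproduce these checks, here is a minimal sketch of the kind of measurement involved (orthogonalize is a placeholder for the Newton-Schulz routine under test, and the Cauchy samples are only a heavy-tailed stand-in for the Levy inputs mentioned above):

```python
import torch

def orthogonality_error(X: torch.Tensor) -> float:
    # Frobenius distance between X @ X^T and the identity; 0 means the rows are exactly orthonormal.
    m = X.shape[0]
    eye = torch.eye(m, device=X.device, dtype=X.dtype)
    return (X @ X.mT - eye).norm().item()

# Random uniform inputs, plus a heavy-tailed (Cauchy) proxy for the Levy case.
G_uniform = torch.rand(1024, 2048)
G_heavy = torch.distributions.Cauchy(0.0, 1.0).sample((1024, 2048))
# print(orthogonality_error(orthogonalize(G_uniform)), orthogonality_error(orthogonalize(G_heavy)))
```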