From 7ea93fb3a1eacd5ec1f735f93f506d0ee6d94564 Mon Sep 17 00:00:00 2001 From: Thibaut Boissin Date: Mon, 4 Aug 2025 14:27:48 +0200 Subject: [PATCH 1/3] Added setup.py, imports are "from optimizers import " --- README.md | 9 ++- optimizers/__init__.py | 6 ++ optimizers/dion.py | 14 ++--- optimizers/dion_reference.py | 8 +-- requirements_dev.txt | 4 ++ requirements_dion.txt | 3 + requirements.txt => requirements_train.txt | 4 -- setup.py | 72 ++++++++++++++++++++++ 8 files changed, 103 insertions(+), 17 deletions(-) create mode 100644 optimizers/__init__.py create mode 100644 requirements_dev.txt create mode 100644 requirements_dion.txt rename requirements.txt => requirements_train.txt (59%) create mode 100644 setup.py diff --git a/README.md b/README.md index c7d7aa3..4394c76 100644 --- a/README.md +++ b/README.md @@ -41,9 +41,14 @@ This code is written for modern PyTorch (version 2.7 or newer) using DTensor-bas ## Quick Start -Install dependencies: +Install dependencies for Dion and training script: ```bash -pip install -r requirements.txt +pip install -e .[train] +``` + +Optimizers can also be installed in a standalone mode without the training script: +```bash +pip install git+https://github.com/microsoft/dion.git ``` Download pretokenized FineWeb dataset: diff --git a/optimizers/__init__.py b/optimizers/__init__.py new file mode 100644 index 0000000..5051e41 --- /dev/null +++ b/optimizers/__init__.py @@ -0,0 +1,6 @@ +from .dion import Dion +from .dion import DionMixedPrecisionConfig +from .dion_simple import Dion as DionSimple +from .dion_reference import Dion as DionReference +from .muon import Muon +from .muon_reference import Muon as MuonReference \ No newline at end of file diff --git a/optimizers/dion.py b/optimizers/dion.py index fd1a8b0..e301155 100644 --- a/optimizers/dion.py +++ b/optimizers/dion.py @@ -31,7 +31,7 @@ @dataclass -class DionParamConfig: +class _DionParamConfig: """ Per-parameter configuration for Dion optimizer. """ @@ -193,7 +193,7 @@ def __init__( # This is intentionally not in self.state so it doesn't get checkpointed # State here may change upon resharding a checkpoint, so we recompute it - self._param_config: Dict[Tensor, DionParamConfig] = {} + self._param_config: Dict[Tensor, _DionParamConfig] = {} self._replicate_mesh = replicate_mesh self._outer_shard_mesh = outer_shard_mesh @@ -495,7 +495,7 @@ def _get_or_initialize_state(self, param: Tensor, group: dict) -> dict: raise ValueError(f"Unknown algorithm: {algo}") return state - def _get_dion_param_config(self, x: Tensor) -> DionParamConfig: + def _get_dion_param_config(self, x: Tensor) -> _DionParamConfig: """ Get the Dion-specific parameter configuration for a given tensor. If the configuration is not already initialized, it will be created. 
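As a usage note for the packaging introduced above: once the project is installed with `pip install -e .[train]` or the standalone git install from the Quick Start, the optimizers become importable from the package root via the new `optimizers/__init__.py` (the package is renamed to `dion` later in this series). The sketch below only illustrates that import surface; the constructor arguments (`lr`, passing a flat `model.parameters()` list) are illustrative assumptions and are not defined by this patch — see `dion.py` for the actual signature and `train.py` for a real configuration.

```python
# Minimal usage sketch for the newly packaged optimizers.
# NOTE: the exact Dion constructor arguments are NOT specified by this patch;
# `lr` and the flat parameter list below are illustrative assumptions only.
import torch
from optimizers import Dion, DionMixedPrecisionConfig  # exported by the new __init__.py

model = torch.nn.Linear(1024, 1024, bias=False)
opt = Dion(model.parameters(), lr=1e-2)  # hypothetical arguments, for illustration

loss = model(torch.randn(8, 1024)).square().mean()
loss.backward()
opt.step()        # assumed to follow the standard torch.optim step/zero_grad cycle
opt.zero_grad()
```

Real setups will likely need per-parameter-group options (for example, separate handling of matrix and scalar parameters), as configured in `train.py`.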
@@ -526,7 +526,7 @@ def _get_dion_param_config(self, x: Tensor) -> DionParamConfig: ) # State is initialized for both matrix and scalar parameters - config = DionParamConfig() + config = _DionParamConfig() # By default, we transpose matrices so that dim0 >= dim1 # This can change depending on sharding @@ -748,7 +748,7 @@ def dion_update_ddp( mu: Tensor, # Momentum factor (scalar tensor) weight_decay: Tensor, # Weight decay (scalar tensor) epsilon: float, - param_config: DionParamConfig, # shared for all params in batch + param_config: _DionParamConfig, # shared for all params in batch replicate_mesh: Union[DeviceMesh, ProcessGroup, None] = None, replicate_mesh_grad_sync: bool = True, oversample: float = 1.25, @@ -884,7 +884,7 @@ def dion_update_fsdp( mu: Tensor, # Momentum factor (scalar tensor) weight_decay: Tensor, # Weight decay (scalar tensor) epsilon: float, - param_config: DionParamConfig, # shared for all params in batch + param_config: _DionParamConfig, # shared for all params in batch replicate_mesh: Optional[DeviceMesh] = None, replicate_mesh_grad_sync: bool = True, oversample: float = 1.25, @@ -1021,7 +1021,7 @@ def dion_update_fsdp_tp( mu: Tensor, # Momentum factor (scalar tensor) weight_decay: Tensor, # Weight decay (scalar tensor) epsilon: float, - param_config: DionParamConfig, # shared for all params in batch + param_config: _DionParamConfig, # shared for all params in batch replicate_mesh: Optional[DeviceMesh] = None, replicate_mesh_grad_sync: bool = True, oversample: float = 1.25, diff --git a/optimizers/dion_reference.py b/optimizers/dion_reference.py index 13a3edc..e4d67bd 100644 --- a/optimizers/dion_reference.py +++ b/optimizers/dion_reference.py @@ -20,7 +20,7 @@ @dataclass -class DionParamConfig: +class _DionParamConfig: """ Per-parameter configuration for Dion optimizer. """ @@ -191,7 +191,7 @@ def __init__( # This is intentionally not in self.state so it doesn't get checkpointed # State here may change upon resharding a checkpoint, so we recompute it - self._param_config: Dict[Tensor, DionParamConfig] = {} + self._param_config: Dict[Tensor, _DionParamConfig] = {} self._replicate_mesh = replicate_mesh self._outer_shard_mesh = outer_shard_mesh @@ -393,7 +393,7 @@ def synchronize_for_checkpoint(self): result = all_reduce(tensor, self._replicate_mesh) tensor.copy_(result) - def _get_dion_param_config(self, x: Tensor) -> DionParamConfig: + def _get_dion_param_config(self, x: Tensor) -> _DionParamConfig: """ Get the Dion-specific parameter configuration for a given tensor. If the configuration is not already initialized, it will be created. 
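The rename to `_DionParamConfig` in both `dion.py` and `dion_reference.py` keeps the per-parameter sharding metadata private while the package root re-exports only the public classes. As the existing comments note, this metadata lives in `self._param_config` rather than `self.state` precisely so it is never checkpointed and can be recomputed after resharding. The snippet below is a simplified, self-contained illustration of that lazy, non-checkpointed cache pattern; the field and the transpose heuristic are stand-ins inferred from the comments, not the real `_DionParamConfig` definition.

```python
from dataclasses import dataclass
from typing import Dict

from torch import Tensor


@dataclass
class _ParamConfig:
    # Stand-in for _DionParamConfig: one illustrative field only.
    transpose: bool = False


class _LazyConfigCache:
    """Derived, shard-dependent metadata kept outside optimizer state."""

    def __init__(self) -> None:
        # Intentionally not part of checkpointed state; rebuilt on demand,
        # mirroring the comment in dion.py about resharded checkpoints.
        self._param_config: Dict[Tensor, _ParamConfig] = {}

    def get(self, p: Tensor) -> _ParamConfig:
        cfg = self._param_config.get(p)
        if cfg is None:
            # "By default, we transpose matrices so that dim0 >= dim1"
            cfg = _ParamConfig(transpose=p.ndim == 2 and p.shape[0] < p.shape[1])
            self._param_config[p] = cfg
        return cfg
```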
@@ -424,7 +424,7 @@ def _get_dion_param_config(self, x: Tensor) -> DionParamConfig: ) # State is initialized for both matrix and scalar parameters - config = DionParamConfig() + config = _DionParamConfig() # By default, we transpose matrices so that dim0 >= dim1 # This can change depending on sharding diff --git a/requirements_dev.txt b/requirements_dev.txt new file mode 100644 index 0000000..935d825 --- /dev/null +++ b/requirements_dev.txt @@ -0,0 +1,4 @@ +setuptools +pytest +pytest-cov +pylint \ No newline at end of file diff --git a/requirements_dion.txt b/requirements_dion.txt new file mode 100644 index 0000000..166371e --- /dev/null +++ b/requirements_dion.txt @@ -0,0 +1,3 @@ +numpy +torch>=2.7.1 +triton \ No newline at end of file diff --git a/requirements.txt b/requirements_train.txt similarity index 59% rename from requirements.txt rename to requirements_train.txt index 49db021..4f52c69 100644 --- a/requirements.txt +++ b/requirements_train.txt @@ -1,9 +1,5 @@ -numpy -torch>=2.7.1 -triton huggingface-hub wandb -einops omegaconf datasets tiktoken \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..5190cfa --- /dev/null +++ b/setup.py @@ -0,0 +1,72 @@ +import os +from setuptools import find_packages +from setuptools import setup + +this_directory = os.path.dirname(__file__) +req_path = os.path.join(this_directory, "requirements_dion.txt") +req_dev_path = os.path.join(this_directory, "requirements_dev.txt") +req_train_path = os.path.join(this_directory, "requirements_train.txt") + +def read_requirements(path): + if not os.path.exists(path): + print(f"Warning: requirements file {path} does not exist.") + return [] + with open(path) as fp: + return [ + line.strip() + for line in fp + if line.strip() and not line.startswith("#") + ] + +# requirements_dion contains the dependencies for the standalone optimizer +install_requires = read_requirements(req_path) + +# requirements_dev contains the dependencies for development, e.g., testing, linting, etc. +install_dev_requires = install_requires + read_requirements(req_dev_path) + +# requirements_train contains the dependencies for training, e.g., datasets, etc. 
+install_train_requires = install_requires + read_requirements(req_train_path) + +readme_path = os.path.join(this_directory, "README.md") +readme_contents = "" +if os.path.exists(readme_path): + with open(readme_path, "r", encoding="utf-8") as fp: + readme_contents = fp.read().strip() + +## uncomment the following lines to read the version from a file +## will be useful if you use tools like `bump2version` to manage versions +# with open(os.path.join(this_directory, "dion/VERSION")) as f: +# version = f.read().strip() +version = "0.1.0" # versions < 1.0 are considered pre-release versions, which allow for breaking changes if necessary + +setup( + # Name of the package: + name="dion", + # Version of the package: + version=version, + # Find the package automatically (include everything): + packages=find_packages(include=["optimizers", "optimizers.*"]), + ## uncomment the following line to include version file + # package_data={ + # "dion": ["VERSION"], # Add the VERSION file + # }, + # Author information: + author="Ahn, Kwangjun and Xu, Byron and Abreu, Natalie and Langford, John", # as listed in the paper + author_email="{kwangjunahn, byronxu}@microsoft.com", # left this form to prevent bots from harvesting emails + # Description of the package: + description="Dion: Distributed Orthonormal Updates.", + long_description=readme_contents, + long_description_content_type="text/markdown", + # Plugins entry point + classifiers=[ + "Programming Language :: Python", + "Programming Language :: Python :: 3", + ], + python_requires=">=3.9", + license="MIT", + install_requires=install_requires, + extras_require={ + "dev": install_dev_requires, # Can be installed with `pip install dion[dev]` + "train": install_train_requires, + }, +) \ No newline at end of file From 2e9a166643dcfd2fa2623aaa49b22ce18f3a9217 Mon Sep 17 00:00:00 2001 From: Thibaut Boissin Date: Mon, 4 Aug 2025 14:33:01 +0200 Subject: [PATCH 2/3] Moved `optimizers` to `dion` to `from dion import Dion, DionParamConfig` --- {optimizers => dion}/__init__.py | 0 {optimizers => dion}/dion.py | 0 {optimizers => dion}/dion_reference.py | 0 {optimizers => dion}/dion_simple.py | 0 {optimizers => dion}/muon.py | 0 {optimizers => dion}/muon_reference.py | 0 {optimizers => dion}/newton_schulz_triton.py | 0 {optimizers => dion}/opt_utils.py | 0 {optimizers => dion}/scalar_opts.py | 0 setup.py | 2 +- train.py | 10 +++++----- 11 files changed, 6 insertions(+), 6 deletions(-) rename {optimizers => dion}/__init__.py (100%) rename {optimizers => dion}/dion.py (100%) rename {optimizers => dion}/dion_reference.py (100%) rename {optimizers => dion}/dion_simple.py (100%) rename {optimizers => dion}/muon.py (100%) rename {optimizers => dion}/muon_reference.py (100%) rename {optimizers => dion}/newton_schulz_triton.py (100%) rename {optimizers => dion}/opt_utils.py (100%) rename {optimizers => dion}/scalar_opts.py (100%) diff --git a/optimizers/__init__.py b/dion/__init__.py similarity index 100% rename from optimizers/__init__.py rename to dion/__init__.py diff --git a/optimizers/dion.py b/dion/dion.py similarity index 100% rename from optimizers/dion.py rename to dion/dion.py diff --git a/optimizers/dion_reference.py b/dion/dion_reference.py similarity index 100% rename from optimizers/dion_reference.py rename to dion/dion_reference.py diff --git a/optimizers/dion_simple.py b/dion/dion_simple.py similarity index 100% rename from optimizers/dion_simple.py rename to dion/dion_simple.py diff --git a/optimizers/muon.py b/dion/muon.py similarity index 100% rename 
from optimizers/muon.py rename to dion/muon.py diff --git a/optimizers/muon_reference.py b/dion/muon_reference.py similarity index 100% rename from optimizers/muon_reference.py rename to dion/muon_reference.py diff --git a/optimizers/newton_schulz_triton.py b/dion/newton_schulz_triton.py similarity index 100% rename from optimizers/newton_schulz_triton.py rename to dion/newton_schulz_triton.py diff --git a/optimizers/opt_utils.py b/dion/opt_utils.py similarity index 100% rename from optimizers/opt_utils.py rename to dion/opt_utils.py diff --git a/optimizers/scalar_opts.py b/dion/scalar_opts.py similarity index 100% rename from optimizers/scalar_opts.py rename to dion/scalar_opts.py diff --git a/setup.py b/setup.py index 5190cfa..19ca306 100644 --- a/setup.py +++ b/setup.py @@ -45,7 +45,7 @@ def read_requirements(path): # Version of the package: version=version, # Find the package automatically (include everything): - packages=find_packages(include=["optimizers", "optimizers.*"]), + packages=find_packages(include=["dion", "dion.*"]), ## uncomment the following line to include version file # package_data={ # "dion": ["VERSION"], # Add the VERSION file diff --git a/train.py b/train.py index ee66164..f6ce077 100644 --- a/train.py +++ b/train.py @@ -22,11 +22,11 @@ from models.gpt_model import GPT, GPTConfig, parallelize_gpt_model from models.gpt_utils import DistributedDataLoader -from optimizers.dion import Dion, DionMixedPrecisionConfig -from optimizers.dion_reference import Dion as DionReference -from optimizers.dion_simple import Dion as DionSimple -from optimizers.muon import Muon -from optimizers.muon_reference import Muon as MuonReference +from dion import Dion, DionMixedPrecisionConfig +from dion import DionReference +from dion import DionSimple +from dion import Muon +from dion import MuonReference @dataclass From b09f906f6d00539d2e5105de29eef7a0c8956871 Mon Sep 17 00:00:00 2001 From: Thibaut Boissin Date: Mon, 4 Aug 2025 14:44:43 +0200 Subject: [PATCH 3/3] moved tests in specific folder and benchmark in specific folder. --- benchmark/benchmark_newton_shultz.py | 189 +++++++++++++++++++++++++++ dion/newton_schulz_triton.py | 170 ------------------------ tests/test_newton_shultz.py | 80 ++++++++++++ 3 files changed, 269 insertions(+), 170 deletions(-) create mode 100644 benchmark/benchmark_newton_shultz.py create mode 100644 tests/test_newton_shultz.py diff --git a/benchmark/benchmark_newton_shultz.py b/benchmark/benchmark_newton_shultz.py new file mode 100644 index 0000000..115d237 --- /dev/null +++ b/benchmark/benchmark_newton_shultz.py @@ -0,0 +1,189 @@ +# benchmarks/bench_newton_schulz.py +""" +Newton-Schulz kernel benchmarks. 
+
+Examples
+--------
+# One-off timing (1024 x 1024, batch=1 & 4)
+python benchmark/benchmark_newton_shultz.py --single --m 1024 --n 1024
+python benchmark/benchmark_newton_shultz.py --single --m 1024 --n 1024 --batch_size 4
+
+# Grid sweep like the original 'benchmark_many_sizes'
+python benchmark/benchmark_newton_shultz.py --grid --batch_size 4 --expansion 1
+
+# TFLOPS plot (writes PNG & PDF in ./plots)
+python benchmark/benchmark_newton_shultz.py --plot --batch_size 1
+"""
+import argparse
+from pathlib import Path
+from typing import Iterable, Tuple
+import torch
+import triton.testing as tt
+
+from dion.newton_schulz_triton import (
+    newton_schulz_triton,
+    zeropower_via_newtonschulz5,
+)
+
+# -----------------------------------------------------------------------------
+# Helpers
+# -----------------------------------------------------------------------------
+
+
+def gemm_cost(m: int, n: int) -> int:
+    """
+    Return the FLOP count of the three GEMMs done per Newton-Schulz iteration.
+    Derivation: see paper / original comment.
+    """
+    return 4 * m * m * n + 2 * m * m * m  # == 4 m²n + 2 m³
+
+
+def tflops(ms: float, flops: int, steps: int, batch: int) -> float:
+    return batch * steps * flops * 1e-12 / (ms * 1e-3)
+
+
+def pretty_time(ms: float) -> str:
+    return f"{ms:7.3f} ms"
+
+
+def bench_once(
+    m: int,
+    n: int,
+    *,
+    batch_size: int = 1,
+    steps: int = 5,
+    dtype: torch.dtype = torch.bfloat16,
+) -> Tuple[float, float]:
+    """Time reference vs. Triton kernels once and return the two runtimes (ms)."""
+    if not torch.cuda.is_available():
+        raise RuntimeError("CUDA device required for this benchmark")
+
+    G = torch.randn(batch_size, m, n, dtype=dtype, device="cuda")
+    # reference
+    t_ref = tt.do_bench(lambda: zeropower_via_newtonschulz5(G))
+    # triton
+    # start with a warmup run
+    newton_schulz_triton(G)
+    # then measure the actual time
+    t_tri = tt.do_bench(lambda: newton_schulz_triton(G))
+
+    flops = gemm_cost(m, n)
+    ref_tflops = tflops(t_ref, flops, steps, batch_size)
+    tri_tflops = tflops(t_tri, flops, steps, batch_size)
+
+    print(
+        f"[{batch_size=} {m=}, {n=}] "
+        f"torch {pretty_time(t_ref)} {ref_tflops:5.2f} TFLOPS | "
+        f"triton {pretty_time(t_tri)} {tri_tflops:5.2f} TFLOPS "
+        f"(speed-up x{t_ref/t_tri:4.2f})"
+    )
+    return t_ref, t_tri
+
+
+def bench_grid(
+    dims: Iterable[int],
+    *,
+    expansion: int = 1,
+    batch_size: int = 1,
+    dtype: torch.dtype = torch.bfloat16,
+):
+    """Sweep over square/rectangular sizes (equiv. to original benchmark_many_sizes)."""
+    speedups = []
+    for d in dims:
+        tr, tt_ = bench_once(
+            d,
+            d * expansion,
+            batch_size=batch_size,
+            dtype=dtype,
+        )
+        speedups.append(tr / tt_)
+    print("Speed-ups:", ", ".join(f"{s:4.2f}x" for s in speedups))
+    print("Theoretical max:", f"{(4*expansion+2)/(3*expansion+1):4.2f}x")
+
+
+def bench_plot(batch_size: int, *, out_dir: Path = Path("plots")):
+    """Generate TFLOPS vs.
size curves using Triton's perf_report helper.""" + if tt is None: + raise RuntimeError("Triton not available - cannot build plots") + + @tt.perf_report( + tt.Benchmark( + x_names=["dim"], + x_vals=[128 * i for i in range(1, 8)], + line_arg="provider", + line_vals=["torch", "triton"], + line_names=["torch", "triton"], + ylabel="TFLOPS", + plot_name=f"newton_schulz_batch{batch_size}", + args={"batch_size": batch_size}, + ) + ) + def bench(dim: int, provider: str, batch_size: int): + G = torch.randn(batch_size, dim, dim, dtype=torch.bfloat16, device="cuda") + if provider == "torch": + ms = tt.do_bench(lambda: zeropower_via_newtonschulz5(G)) + else: # "triton" + ms = tt.do_bench(lambda: newton_schulz_triton(G)) + return tflops(ms, gemm_cost(dim, dim), steps=5, batch=batch_size) + + bench.run(print_data=True, save_path=str(out_dir)) + + +def parse() -> argparse.Namespace: + p = argparse.ArgumentParser( + description="Benchmarks for Newton-Schulz Triton kernels" + ) + # mutually exclusive groups + mode = p.add_mutually_exclusive_group(required=True) + mode.add_argument("--single", action="store_true", help="run a single benchmark") + mode.add_argument("--grid", action="store_true", help="sweep a list of sizes") + mode.add_argument( + "--plot", action="store_true", help="generate TFLOPS curves and write plots" + ) + # single run parameters + p.add_argument("--m", type=int, help="rows") + p.add_argument("--n", type=int, help="cols (defaults to m)") + # common options + p.add_argument("--batch_size", type=int, default=1) + p.add_argument( + "--expansion", type=int, default=1, help="n = m * expansion (grid mode)" + ) + p.add_argument( + "--dtype", + default="bfloat16", + choices=["float16", "bfloat16"], + help="input dtype", + ) + return p.parse_args() + + +def main(): + args = parse() + + # -----------------------------------------------------------------------------# + # General settings + # -----------------------------------------------------------------------------# + + # Allow a lot of recompiles in Torch-Triton + torch._dynamo.config.cache_size_limit = 100 # noqa: SLF001 + + dtype = getattr(torch, args.dtype) + + if args.grid: + dims = [512, 1024, 2048, 4096, 8192] + bench_grid( + dims, + expansion=args.expansion, + batch_size=args.batch_size, + dtype=dtype, + ) + elif args.plot: + bench_plot(args.batch_size) + else: # single run + m = args.m + n = args.n or m + bench_once(m, n, batch_size=args.batch_size, dtype=dtype) + + +if __name__ == "__main__": + main() diff --git a/dion/newton_schulz_triton.py b/dion/newton_schulz_triton.py index 19e1466..30e21a8 100644 --- a/dion/newton_schulz_triton.py +++ b/dion/newton_schulz_triton.py @@ -372,173 +372,3 @@ def newton_schulz_triton(G: Tensor, epsilon: float = 1e-7): if G.size(-2) > G.size(-1): X = X.mT return X - - -def check_result(result, correct, atol=1e-2): - assert ( - result.dtype == correct.dtype - ), f"Result dtype {result.dtype} does not match correct dtype {correct.dtype}" - assert ( - result.shape == correct.shape - ), f"Shape mismatch: {result.shape} != {correct.shape}" - - if torch.allclose(result, correct, atol=atol): - print("Test passed") - else: - print("Test failed") - if torch.allclose(result.triu(), correct.triu(), atol=atol): - print("- Upper triangular part matches") - if torch.allclose(result.tril(), correct.tril(), atol=atol): - print("- Lower triangular part matches") - abs_diff = torch.abs(result - correct) - print("- Max absolute difference:", abs_diff.max().item()) - print(abs_diff) - - -def test_ns_line_1(m, n, 
dtype=torch.bfloat16): - print(f"Testing ns_line_1 with shape ({m}, {n}) and dtype {dtype}") - - A = torch.randn(m, n, dtype=dtype, device="cuda") - result = ns_line_1(A) - correct = A @ A.mT - check_result(result, correct) - - # Test with batch dimension - A = torch.randn(4, m, n, dtype=dtype, device="cuda") - result = ns_line_1(A) - correct = A @ A.mT - check_result(result, correct) - - -def test_ns_line_2(m, dtype=torch.bfloat16): - print(f"Testing ns_line_2 with shape ({m}, {m}) and dtype {dtype}") - - A = torch.randn(m, m, dtype=dtype, device="cuda") - A = (A + A.mT) / 2 # Make symmetric - alpha = torch.randn(1).item() - beta = torch.randn(1).item() - result = ns_line_2(A, alpha=alpha, beta=beta) - - A = A.to(torch.float32) - correct = alpha * (A @ A.mT) + beta * A - check_result(result, correct.to(dtype)) - - # Test with batch dimension - A = torch.randn(4, m, m, dtype=dtype, device="cuda") - A = (A + A.mT) / 2 # Make symmetric - result = ns_line_2(A, alpha=alpha, beta=beta) - - A = A.to(torch.float32) - correct = alpha * (A @ A.mT) + beta * A - check_result(result, correct.to(dtype)) - - -def test_newton_schulz_triton(m, n, dtype=torch.bfloat16): - print(f"Testing newton_schulz_triton with shape ({m}, {n}) and dtype {dtype}") - - G = torch.randn(m, n, dtype=dtype, device="cuda") - result = newton_schulz_triton(G) - correct = zeropower_via_newtonschulz5(G) - check_result(result, correct) - - # Test with batch dimension - G = torch.randn(4, m, n, dtype=dtype, device="cuda") - result = newton_schulz_triton(G) - correct = zeropower_via_newtonschulz5(G) - check_result(result, correct) - - -def benchmark_newton_schulz_triton(m, n, batch_size=1, dtype=torch.bfloat16): - print( - f"Benchmarking newton_schulz_triton with shape ({m}, {n}) and batch size {batch_size}" - ) - G = torch.randn(batch_size, m, n, dtype=dtype, device="cuda") - if batch_size == 1: - G = G.squeeze(0) - - def estimate_tflops(ms): - steps = 5 # Number of Newton-Schulz iterations - mm_cost = (2 * m * n * m) + (2 * m * m * m) + (2 * m * m * n) - return batch_size * steps * mm_cost * 1e-12 / (ms * 1e-3) - - time_torch = triton.testing.do_bench(lambda: zeropower_via_newtonschulz5(G)) - print(f"Torch NS: {time_torch:.4f} ms, {estimate_tflops(time_torch):.2f} TFLOP/s") - - time_triton = triton.testing.do_bench(lambda: newton_schulz_triton(G)) - print( - f"Triton NS: {time_triton:.4f} ms, {estimate_tflops(time_triton):.2f} TFLOP/s" - ) - - print(f"Speedup: {time_torch / time_triton:.2f}x") - return time_torch, time_triton - - -def benchmark_many_sizes(batch_size=1, expansion=1, dtype=torch.bfloat16): - dim = [512, 1024, 2048, 4096, 8192] - speedups = [] - - for d in dim: - time_torch, time_triton = benchmark_newton_schulz_triton( - d, d * expansion, batch_size=batch_size, dtype=dtype - ) - time_ratio = time_torch / time_triton - speedups.append(time_ratio) - - print(f"Speedups: {speedups}") - max_speedup = (4 * expansion + 2) / (3 * expansion + 1) - print(f"Maximum theoretical speedup: {max_speedup:.2f}x") - - -def benchmark_plot(batch_size=1): - @triton.testing.perf_report( - triton.testing.Benchmark( - x_names=["d"], - x_vals=[128 * i for i in range(1, 33)], - line_arg="provider", - line_names=["torch", "triton"], - line_vals=["torch", "triton"], - ylabel="TFLOPS", - plot_name=f"newton_schulz_{batch_size=}", - args={"batch_size": batch_size}, - ) - ) - def benchmark(d: int, provider: str, batch_size: int): - G = torch.randn(batch_size, d, d, dtype=torch.bfloat16, device="cuda") - - if provider == "torch": - ms = 
triton.testing.do_bench(lambda: zeropower_via_newtonschulz5(G)) - elif provider == "triton": - ms = triton.testing.do_bench(lambda: newton_schulz_triton(G)) - - def estimate_tflops(ms): - steps = 5 - mm_cost = (2 * d * d * d) + (2 * d * d * d) + (2 * d * d * d) - return batch_size * steps * mm_cost * 1e-12 / (ms * 1e-3) - - return estimate_tflops(ms) - - benchmark.run(print_data=True, save_path="plots") - - -if __name__ == "__main__": - # Allow a lot of recompiles - torch._dynamo.config.cache_size_limit = 100 - - # Run tests - # test_ns_line_1(1024, 1024) - # test_ns_line_1(1024, 4096) - # test_ns_line_2(1024) - # test_newton_schulz_triton(1024, 1024) - # test_newton_schulz_triton(1024, 4096) - - # d = 1024 - # benchmark_newton_schulz_triton(d, d) - # benchmark_newton_schulz_triton(d, d, batch_size=4) - - # benchmark_many_sizes(batch_size=1, expansion=1) - # benchmark_many_sizes(batch_size=4, expansion=1) - # benchmark_many_sizes(batch_size=1, expansion=4) - # benchmark_many_sizes(batch_size=4, expansion=4) - - benchmark_plot(batch_size=1) - benchmark_plot(batch_size=4) diff --git a/tests/test_newton_shultz.py b/tests/test_newton_shultz.py new file mode 100644 index 0000000..80e5903 --- /dev/null +++ b/tests/test_newton_shultz.py @@ -0,0 +1,80 @@ +# tests/test_newton_schulz.py +import pytest +import torch + +from dion.newton_schulz_triton import ( + ns_line_1, + ns_line_2, + newton_schulz_triton, + zeropower_via_newtonschulz5, +) + +# -----------------------------------------------------------------------------# +# General settings +# -----------------------------------------------------------------------------# + +# Allow a lot of recompiles in Torch-Triton +torch._dynamo.config.cache_size_limit = 100 # noqa: SLF001 + +CUDA_AVAILABLE = torch.cuda.is_available() + +# -----------------------------------------------------------------------------# +# Helper +# -----------------------------------------------------------------------------# + + +def _assert_close(result: torch.Tensor, correct: torch.Tensor, *, tol: float = 5e-2): + """Assert two tensors are close enough for the test to pass.""" + assert ( + result.dtype == correct.dtype + ), f"dtype mismatch — got {result.dtype}, expected {correct.dtype}" + assert ( + result.shape == correct.shape + ), f"shape mismatch — got {result.shape}, expected {correct.shape}" + assert torch.allclose( + result, correct, atol=tol, rtol=tol + ), f"max-abs-diff {torch.abs(result - correct).max().item():.3e} > {tol}" + + +@pytest.mark.skipif(not CUDA_AVAILABLE, reason="CUDA device required") +@pytest.mark.parametrize("m,n", [(256, 256), (256, 1024)]) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) +def test_ns_line_1(m: int, n: int, dtype: torch.dtype): + """ns_line_1 should compute A @ A^T (batched and unbatched).""" + A = torch.randn(m, n, dtype=dtype, device="cuda") + _assert_close(ns_line_1(A), A @ A.mT) + + A_batched = torch.randn(4, m, n, dtype=dtype, device="cuda") + _assert_close(ns_line_1(A_batched), A_batched @ A_batched.mT) + + +@pytest.mark.skipif(not CUDA_AVAILABLE, reason="CUDA device required") +@pytest.mark.parametrize("m", [256]) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) +def test_ns_line_2(m: int, dtype: torch.dtype): + """ns_line_2 should compute alpha(A@A^T) + beta*A for symmetric A.""" + alpha, beta = torch.randn(1).item(), torch.randn(1).item() + + A = torch.randn(m, m, dtype=dtype, device="cuda") + A = (A + A.mT) / 2 # ensure symmetry + correct = alpha * (A @ A.mT) + beta * A + 
_assert_close(ns_line_2(A, alpha=alpha, beta=beta), correct) + + A_batched = torch.randn(4, m, m, dtype=dtype, device="cuda") + A_batched = (A_batched + A_batched.mT) / 2 + correct_batched = alpha * (A_batched @ A_batched.mT) + beta * A_batched + _assert_close(ns_line_2(A_batched, alpha=alpha, beta=beta), correct_batched) + + +@pytest.mark.skipif(not CUDA_AVAILABLE, reason="CUDA device required") +@pytest.mark.parametrize("m,n", [(256, 256), (256, 1024)]) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) +def test_newton_schulz_triton(m: int, n: int, dtype: torch.dtype): + """Fast Triton implementation should match the reference Newton-Schulz.""" + G = torch.randn(m, n, dtype=dtype, device="cuda") + _assert_close(newton_schulz_triton(G), zeropower_via_newtonschulz5(G)) + + G_batched = torch.randn(4, m, n, dtype=dtype, device="cuda") + _assert_close( + newton_schulz_triton(G_batched), zeropower_via_newtonschulz5(G_batched) + )
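With the tests and benchmark now in dedicated folders, a plausible post-merge workflow, using only the extras and paths defined in this series (the pytest and CLI invocations are standard, not repo-specific), is:

```bash
# Editable install with the development extras from setup.py
pip install -e .[dev]

# Run the relocated test suite; GPU-only tests are skipped when CUDA is unavailable
pytest tests/ -v

# Run the relocated Newton-Schulz benchmark (requires a CUDA device)
python benchmark/benchmark_newton_shultz.py --single --m 1024 --n 1024
```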