From c25daf5d470dd234a0e6c62e2cbd0d04705fbb75 Mon Sep 17 00:00:00 2001 From: Kwangjun Ahn Date: Tue, 16 Dec 2025 12:30:52 -0800 Subject: [PATCH 01/14] dion2 improvements --- .gitignore | 3 +- configs/dion2_160m.yaml | 16 +- configs/muon_160m.yaml | 8 +- dion/__init__.py | 3 +- dion/dion2.py | 2 +- dion/dion2_new.py | 660 ++++++++++++++++++++++++++++++++++++++++ dion/muon.py | 2 +- train.py | 7 + 8 files changed, 685 insertions(+), 16 deletions(-) create mode 100644 dion/dion2_new.py diff --git a/.gitignore b/.gitignore index b4a7f6d..34d7e8d 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,5 @@ wandb/ output/ .venv/ submit.ipynb -aztool/ \ No newline at end of file +aztool/ +dion.egg-info/ \ No newline at end of file diff --git a/configs/dion2_160m.yaml b/configs/dion2_160m.yaml index d92621c..681ca23 100644 --- a/configs/dion2_160m.yaml +++ b/configs/dion2_160m.yaml @@ -1,12 +1,12 @@ # — Model — model_dim: 768 -n_layer: 12 +n_layer: 6 n_head: 6 -sequence_length: 1024 +sequence_length: 512 # — Batching & Training — -batch_size: 1024 +batch_size: 128 device_batch_size: 32 num_iterations: 3000 @@ -21,7 +21,7 @@ weight_decay: 0.01 # — Validation & Checkpointing — val_loss_every: 125 val_tokens: 10485760 -save_every: 0 +save_every: 125 # — Weights & Biases logging — wandb_project_name: gpt-train @@ -29,9 +29,9 @@ wandb_job_name: null no_wandb: false # — Distributed training — -dp_size: null # data‐parallel size -fs_size: null # FSDP size -tp_size: null # DO NOT USE TP for Dion2 +dp_size: 1 # data‐parallel size +fs_size: 4 # FSDP size +tp_size: 1 # DO NOT USE TP for Dion2 # — Miscellaneous flags — debug: false @@ -39,7 +39,7 @@ no_compile: false no_triton: false optimizer: dion2 -rank_fraction: 0.5 +rank_fraction: 0.125 scalar_opt: lion adjust_lr: spectral_norm lr: 0.02 diff --git a/configs/muon_160m.yaml b/configs/muon_160m.yaml index 7153e5e..a787585 100644 --- a/configs/muon_160m.yaml +++ b/configs/muon_160m.yaml @@ -1,13 +1,13 @@ # — Model — model_dim: 768 
-n_layer: 12 +n_layer: 6 n_head: 6 -sequence_length: 1024 +sequence_length: 512 # — Batching & Training — -batch_size: 1024 -device_batch_size: 32 +batch_size: 128 +device_batch_size: 32 num_iterations: 3000 # — Learning‐rate schedule — diff --git a/dion/__init__.py b/dion/__init__.py index 34894e6..baef1ea 100644 --- a/dion/__init__.py +++ b/dion/__init__.py @@ -4,5 +4,6 @@ from .dion_reference import Dion as DionReference from .muon import Muon from .muon_reference import Muon as MuonReference -from .dion2 import Dion2 +from .dion2 import Dion2 as Dion2Old +from .dion2_new import Dion2 from .normuon import NorMuon diff --git a/dion/dion2.py b/dion/dion2.py index 446a983..78602c7 100644 --- a/dion/dion2.py +++ b/dion/dion2.py @@ -761,4 +761,4 @@ def dion2_update_post_orthogonalize( # Weight update U = torch._foreach_mul(U, adjusted_lr) - torch._foreach_sub_(X, U) + torch._foreach_sub_(X, U) \ No newline at end of file diff --git a/dion/dion2_new.py b/dion/dion2_new.py new file mode 100644 index 0000000..179e565 --- /dev/null +++ b/dion/dion2_new.py @@ -0,0 +1,660 @@ +""" +Dion2 Optimizer - Fully Optimized Implementation + +Key differences from Muon: +- Selects top-α fraction of rows (by L2 norm) for orthogonalization +- Only communicates and orthogonalizes the selected submatrix +- Applies error-feedback decay to selected rows after extraction + +Communication pattern (same as Muon): +- DDP: all-gather (each rank orthogonalizes one matrix, then gathers results) +- FSDP: all-to-all (shards → full matrix on owner → orthogonalize → shards) + +Row selection is done locally on each shard, so: +- DDP: selection on full matrix +- FSDP: selection on each shard independently (slightly different algorithm, similar performance) + +Optimizations: +- torch.compile on hot paths for kernel fusion and reduced Python overhead +- foreach operations for batched tensor updates +- Stacked tensor operations for row selection (all matrices in batch have same shape) +""" + +import math 
+import torch +import torch.distributed as dist +from itertools import chain +from torch import Tensor +from torch.distributed import ProcessGroup +from torch.distributed.tensor import DeviceMesh, DTensor +from torch.optim.optimizer import Optimizer, ParamsT +from typing import Callable, Generator, List, Optional, Tuple, Union + +from .newton_schulz_triton import newton_schulz_triton, zeropower_via_newtonschulz5 +from .opt_utils import ( + AsyncRuntime, + AsyncTask, + create_param_batches, + pad_batch, + to_local, +) +from .scalar_opts import adamw_update_foreach_async, lion_update_foreach_async + + +class Dion2(Optimizer): + """ + Distributed Dion2 optimizer for PyTorch FSDP2. Also compatible with DDP. + + Args: + params: Parameters for the optimizer. + distributed_mesh: DeviceMesh or ProcessGroup for distributed training. + Use DeviceMesh for FSDP2 and ProcessGroup for DistributedDataParallel. + lr: Base learning rate. Scaled based on matrix dimensions. + fraction: Fraction of rows to orthogonalize per update (0 < fraction <= 1). + ef_decay: Error-feedback decay factor applied to selected rows. + betas: Tuple of (beta1, beta2) for AdamW and Lion algorithms. + weight_decay: Weight decay factor. + epsilon: Small value to avoid division by zero. + adjust_lr: How to adjust learning rate ("spectral_norm", "rms_norm", or None). + flatten: Whether to flatten 3D+ tensors to 2D. + use_triton: Whether to use Triton kernel for Newton-Schulz. + newton_schulz_func: Custom Newton-Schulz function. 
+ """ + + def __init__( + self, + params: ParamsT, + distributed_mesh: Optional[Union[DeviceMesh, ProcessGroup]] = None, + lr: float = 0.01, + fraction: float = 0.25, + ef_decay: float = 0.95, + betas: Tuple[float, float] = (0.9, 0.95), + weight_decay: float = 0.01, + epsilon: float = 1e-8, + adjust_lr: Optional[str] = "spectral_norm", + flatten: bool = False, + use_triton: bool = False, + newton_schulz_func: Optional[Callable] = None, + ): + # Validate hyperparameters + if lr < 0.0: + raise ValueError(f"Invalid learning rate: {lr}") + if not (0.0 < fraction <= 1.0): + raise ValueError(f"fraction must be in (0, 1], got {fraction}") + if ef_decay < 0.0: + raise ValueError(f"Invalid ef_decay: {ef_decay}") + if len(betas) != 2 or betas[0] < 0.0 or betas[1] < 0.0: + raise ValueError(f"Invalid betas: {betas}") + if adjust_lr not in ("spectral_norm", "rms_norm", None): + raise ValueError(f"Invalid adjust_lr: {adjust_lr}") + + defaults = dict( + lr=lr, + ef_decay=ef_decay, + fraction=fraction, + beta1=betas[0], + beta2=betas[1], + weight_decay=weight_decay, + epsilon=epsilon, + flatten=flatten, + adjust_lr=adjust_lr, + algorithm="dion2", + step=0, + ) + super().__init__(params, defaults) + + # Distributed configuration + if isinstance(distributed_mesh, DeviceMesh): + if distributed_mesh.ndim != 1: + raise ValueError( + f"Only 1D DeviceMesh supported, got {distributed_mesh.ndim}D." 
+ ) + self._device_rank = distributed_mesh.get_local_rank() + self._world_size = distributed_mesh.size() + self._process_group = distributed_mesh.get_group() + elif isinstance(distributed_mesh, ProcessGroup): + self._device_rank = dist.get_rank(distributed_mesh) + self._world_size = dist.get_world_size(distributed_mesh) + self._process_group = distributed_mesh + elif distributed_mesh is None: + self._device_rank = 0 + self._world_size = 1 + self._process_group = None + else: + raise TypeError(f"Invalid distributed_mesh type: {type(distributed_mesh)}") + self._distributed_mesh = distributed_mesh + + # Newton-Schulz configuration + if newton_schulz_func is not None: + if not callable(newton_schulz_func): + raise TypeError(f"newton_schulz_func must be callable") + self._newton_schulz_func = newton_schulz_func + elif use_triton: + self._newton_schulz_func = newton_schulz_triton + else: + self._newton_schulz_func = zeropower_via_newtonschulz5 + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + dion2_groups = [] + lion_groups = [] + adamw_groups = [] + + for group in self.param_groups: + group["step"] += 1 + algo = group["algorithm"] + if algo == "dion2": + dion2_groups.append(group) + elif algo == "lion": + lion_groups.append(group) + elif algo == "adamw": + adamw_groups.append(group) + else: + raise ValueError(f"Unknown algorithm: {algo}") + + dion2_tasks = self._create_dion2_tasks(dion2_groups) + lion_tasks = self._create_lion_tasks(lion_groups) + adamw_tasks = self._create_adamw_tasks(adamw_groups) + + all_tasks = chain(dion2_tasks, lion_tasks, adamw_tasks) + runtime = AsyncRuntime(all_tasks, max_concurrent_tasks=3) + runtime.run() + + return loss + + def _get_or_initialize_state(self, param: Tensor, algo: str) -> dict: + """Initialize optimizer state (identical to Muon).""" + state = self.state[param] + if not state: + state["momentum"] = torch.zeros_like(param) + if algo == 
"adamw": + state["variance"] = torch.zeros_like(param) + return state + + def _create_dion2_tasks( + self, param_groups: List[dict] + ) -> Generator["AsyncTask", None, None]: + """Create batched Dion2 update tasks.""" + for group in param_groups: + assert group["algorithm"] == "dion2" + assert all(p.ndim >= 2 for p in group["params"]), \ + "Dion2 only supports matrix parameters." + + group_params = [p for p in group["params"] if p.grad is not None] + if not group_params: + continue + + # Hyperparameters as tensors for torch.compile + dion2_args = dict( + lr=torch.tensor(group["lr"]), + ef_decay=torch.tensor(group["ef_decay"]), + fraction=group["fraction"], + weight_decay=torch.tensor(group["weight_decay"]), + epsilon=torch.tensor(group["epsilon"]), + flatten=group["flatten"], + adjust_lr=group["adjust_lr"], + device_rank=self._device_rank, + world_size=self._world_size, + process_group=self._process_group, + newton_schulz_func=self._newton_schulz_func, + ) + + # Batch parameters by world_size (same as Muon) + for params in create_param_batches(group_params, batch_size=self._world_size): + gradients = [p.grad for p in params] + states = [self._get_or_initialize_state(p, "dion2") for p in params] + momentums = [s["momentum"] for s in states] + + # Determine sharding configuration + shard_dim = None + is_batch_sharded = False + + if isinstance(params[0], DTensor): + if not isinstance(self._distributed_mesh, DeviceMesh): + raise RuntimeError( + "Must use DeviceMesh for DTensor parameters." 
+ ) + + # Find sharded placements (skip size-1 mesh dims) + shard_placements = [ + (i, p) + for i, p in enumerate(params[0].placements) + if p.is_shard() and params[0].device_mesh.size(i) > 1 + ] + + # Check for batch vs matrix dimension sharding + if not group["flatten"]: + matrix_dims = {params[0].ndim - 1, params[0].ndim - 2} + is_batch_sharded = any( + p.dim not in matrix_dims for _, p in shard_placements + ) + shard_placements = [ + (i, p) for i, p in shard_placements if p.dim in matrix_dims + ] + + if len(shard_placements) == 1: + shard_dim = shard_placements[0][1].dim + elif len(shard_placements) > 1: + raise NotImplementedError( + "Multiple sharded dimensions not supported." + ) + + # Verify mesh alignment + if shard_placements: + mesh_dim = shard_placements[0][0] + if params[0].device_mesh.get_group(mesh_dim) != self._process_group: + raise RuntimeError("DTensor mesh doesn't match optimizer mesh.") + + # Handle batch-sharded 3D tensors (each device has different matrices) + if is_batch_sharded: + for x, g, m in zip(params, gradients, momentums): + yield AsyncTask( + dion2_update_batch_async( + X=[x], + G=[g], + M=[m], + shard_dim=None, + **dion2_args, + ) + ) + else: + yield AsyncTask( + dion2_update_batch_async( + X=pad_batch(params, self._world_size), + G=pad_batch(gradients, self._world_size), + M=pad_batch(momentums, self._world_size), + shard_dim=shard_dim, + **dion2_args, + ) + ) + + def _create_lion_tasks( + self, param_groups: List[dict] + ) -> Generator["AsyncTask", None, None]: + """Create Lion update tasks.""" + for group in param_groups: + assert group["algorithm"] == "lion" + + params = [p for p in group["params"] if p.grad is not None] + if not params: + continue + + gradients = [p.grad for p in params] + states = [self._get_or_initialize_state(p, "lion") for p in params] + momentums = [s["momentum"] for s in states] + + yield AsyncTask( + lion_update_foreach_async( + X=to_local(params), + G=to_local(gradients), + M=to_local(momentums), + 
lr=torch.tensor(group["lr"]), + beta1=torch.tensor(group["beta1"]), + beta2=torch.tensor(group["beta2"]), + weight_decay=torch.tensor(group["weight_decay"]), + ) + ) + + def _create_adamw_tasks( + self, param_groups: List[dict] + ) -> Generator["AsyncTask", None, None]: + """Create AdamW update tasks.""" + for group in param_groups: + assert group["algorithm"] == "adamw" + + params = [p for p in group["params"] if p.grad is not None] + if not params: + continue + + gradients = [p.grad for p in params] + states = [self._get_or_initialize_state(p, "adamw") for p in params] + momentums = [s["momentum"] for s in states] + variances = [s["variance"] for s in states] + + yield AsyncTask( + adamw_update_foreach_async( + X=to_local(params), + G=to_local(gradients), + M=to_local(momentums), + V=to_local(variances), + lr=torch.tensor(group["lr"]), + beta1=torch.tensor(group["beta1"]), + beta2=torch.tensor(group["beta2"]), + weight_decay=torch.tensor(group["weight_decay"]), + step=torch.tensor(group["step"]), + epsilon=torch.tensor(group["epsilon"]), + ) + ) + + +# ============================================================================= +# Core Dion2 Update Functions +# ============================================================================= + +def dion2_update_batch_async( + X: List[Tensor], # Parameters (DTensor or Tensor), padded to world_size + G: List[Tensor], # Gradients, padded to world_size + M: List[Tensor], # Momentum buffers (modified in place), padded to world_size + lr: Tensor, + ef_decay: Tensor, + fraction: float, + weight_decay: Tensor, + epsilon: Tensor, + flatten: bool, + adjust_lr: Optional[str], + device_rank: int, + world_size: int, + shard_dim: Optional[int] = None, + process_group: Optional[ProcessGroup] = None, + newton_schulz_func: Optional[Callable] = None, +) -> Generator[None, None, None]: + """ + Batched Dion2 update with fractional row selection. + + Algorithm: + 1. Update momentum: M = M + G + 2. 
Select top-α rows by L2 norm, extract submatrix + 3. Apply ef_decay to selected rows in M + 4. Communicate and orthogonalize only the submatrix + 5. Apply weight update to corresponding rows + + Communication patterns: + - FSDP (shard_dim is not None): + - Parameters are row-sharded across ranks + - Each rank selects top-k rows from its local shard + - All-to-all gathers selected rows to form full submatrix + - Orthogonalize, then all-to-all scatter back + - DDP (shard_dim is None, world_size > 1): + - Each rank has full matrices (batch of different matrices) + - Each rank orthogonalizes one matrix from the batch + - All-gather to distribute results + - Single GPU: direct computation + """ + assert len(X) == len(G) == len(M) + + # Step 1: Update momentum and select top-α rows (operates on local shards) + # All matrices in batch have identical shapes, enabling stacked operations + U_selected, row_indices_list = dion2_pre_orthogonalize( + G=to_local(G), + M=to_local(M), + fraction=fraction, + ef_decay=ef_decay, + ) + + # Step 2: Communicate and orthogonalize selected submatrices + # ------------------------------------------------------------------------- + # FSDP path: all-to-all + # ------------------------------------------------------------------------- + if shard_dim is not None: + assert len(X) == world_size + assert process_group is not None + assert isinstance(X[0], DTensor) + + recv_shards = [torch.empty_like(u) for u in U_selected] + work = dist.all_to_all(recv_shards, U_selected, group=process_group, async_op=True) + yield + work.wait() + + # Concatenate along row dimension to form full selected submatrix + full_submatrix = torch.cat(recv_shards, dim=-2) + + # Orthogonalize the full selected submatrix + full_submatrix = dion2_newton_schulz( + full_submatrix, newton_schulz_func, flatten=flatten, epsilon=epsilon + ) + + # Split back into shards + send_shards = [ + t.contiguous() + for t in torch.tensor_split(full_submatrix, world_size, dim=-2) + ] + + # 
All-to-all: scatter orthogonalized shards back to original owners + U_ortho = [torch.empty_like(u) for u in U_selected] + work = dist.all_to_all(U_ortho, send_shards, group=process_group, async_op=True) + yield + work.wait() + + # ------------------------------------------------------------------------- + # DDP path: all-gather + # ------------------------------------------------------------------------- + elif len(U_selected) > 1: + assert len(U_selected) == world_size + assert process_group is not None + + # This rank orthogonalizes the matrix at index device_rank + my_submatrix = dion2_newton_schulz( + U_selected[device_rank], newton_schulz_func, flatten=flatten, epsilon=epsilon + ) + + # All-gather: collect orthogonalized submatrices from all ranks + U_ortho = [torch.empty_like(u) for u in U_selected] + work = dist.all_gather( + U_ortho, my_submatrix.contiguous(), group=process_group, async_op=True + ) + yield + work.wait() + + # ------------------------------------------------------------------------- + # Single GPU path + # ------------------------------------------------------------------------- + else: + assert len(U_selected) == 1 + U_ortho = [ + dion2_newton_schulz( + U_selected[0], newton_schulz_func, flatten=flatten, epsilon=epsilon + ) + ] + + # Step 3: Compute adjusted learning rate (based on full/global matrix shape) + if adjust_lr is None: + adjusted_lr = lr + elif adjust_lr == "spectral_norm": + adjusted_lr = _adjust_lr_spectral_norm(lr, X[0].shape, flatten=flatten) + elif adjust_lr == "rms_norm": + adjusted_lr = _adjust_lr_rms_norm(lr, X[0].shape, flatten=flatten) + else: + raise ValueError(f"Unknown adjust_lr: {adjust_lr}") + + # Step 4: Apply weight update to selected rows only + dion2_post_orthogonalize( + X=to_local(X), + U_ortho=U_ortho, + row_indices=row_indices_list, + base_lr=lr, + adjusted_lr=adjusted_lr, + weight_decay=weight_decay, + ) + + +# ============================================================================= +# Optimized 
Pre-Orthogonalize Function (Stacked Operations) +# ============================================================================= +# +# KEY INSIGHT: All matrices in a batch have identical shapes! +# This enables stacked/batched tensor operations instead of loops. +# +# OPTIMIZATION 1: Stack into 3D tensor for batched ops +# ---------------------------------------------------- +# Stack (N, rows, cols) enables: +# - Single batched norm instead of N separate norms +# - Single batched topk instead of N separate topk calls +# - Single batched gather instead of N separate index_selects +# +# Why faster: +# - One kernel launch instead of N launches +# - Better GPU parallelism +# - Reduced Python loop overhead +# +# OPTIMIZATION 2: In-place ef_decay via loop (unavoidable) +# -------------------------------------------------------- +# torch.stack creates a copy, so we must apply ef_decay to originals. +# However, the loop benefits from torch.compile fusion. +# +# OPTIMIZATION 3: foreach for gradient accumulation +# ------------------------------------------------- +# Optimal for in-place batched additions. +# ============================================================================= + +@torch.compile(fullgraph=True) +def dion2_pre_orthogonalize( + G: List[Tensor], + M: List[Tensor], + fraction: float, + ef_decay: Tensor, +) -> Tuple[List[Tensor], List[Tensor]]: + """ + Update momentum and select top-α rows for orthogonalization. + + All matrices in the batch have identical shapes, enabling stacked operations. + + For each matrix M (shape: rows x cols): + 1. M += G (accumulate gradient into momentum) + 2. Compute L2 norm of each row + 3. Select top-k rows where k = ceil(fraction * rows) + 4. Extract selected rows as submatrix (k x cols) + 5. 
Apply ef_decay to selected rows in M (in-place) + + Returns: + U_selected: List of selected submatrices in bf16 for communication + row_indices: List of selected row indices for each matrix + """ + dtype = M[0].dtype + num_rows = M[0].size(-2) + num_cols = M[0].size(-1) + k = max(1, int(math.ceil(fraction * num_rows))) + + # OPTIMIZATION 1: foreach for batched gradient accumulation + # Single fused kernel for all M += G operations + G_casted = [g.to(dtype=dtype) for g in G] + torch._foreach_add_(M, G_casted) + + # OPTIMIZATION 2: Stack for batched norm and topk + # Shape: (batch_size, num_rows, num_cols) + M_stacked = torch.stack(M, dim=0) + + # Batched L2 norm: (batch_size, num_rows) + row_norms = M_stacked.norm(dim=-1) + + # Batched topk: indices shape (batch_size, k) + _, indices = torch.topk(row_norms, k, dim=-1, sorted=False) + + # OPTIMIZATION 3: Batched gather for row extraction + # (batch_size, k, num_cols) + indices_expanded = indices.unsqueeze(-1).expand(-1, -1, num_cols) + selected_stacked = torch.gather(M_stacked, dim=-2, index=indices_expanded) + + # Apply ef_decay to selected rows in original M tensors + # Must loop because M tensors are separate (stack created a copy) + # torch.compile will still optimize this loop + row_indices_list = list(indices.unbind(dim=0)) + for m, idx in zip(M, row_indices_list): + m[idx, :] *= ef_decay + + # Convert to bf16 and unstack for communication + U_selected = list(selected_stacked.to(dtype=torch.bfloat16).unbind(dim=0)) + + return U_selected, row_indices_list + + +# ============================================================================= +# Optimized Post-Orthogonalize Function +# ============================================================================= +# +# OPTIMIZATION 1: foreach for weight decay +# ---------------------------------------- +# Single fused kernel for all X *= (1 - lr * wd) operations. 
+# +# OPTIMIZATION 2: Batched dtype conversion +# ---------------------------------------- +# Convert all U tensors upfront for better memory planning. +# +# OPTIMIZATION 3: Loop with index_add_ (torch.compile optimized) +# -------------------------------------------------------------- +# While we can't use foreach for indexed updates, torch.compile +# will fuse operations within each iteration and optimize the loop. +# +# Note: Stacking X would require copy-back which negates benefits. +# The loop approach is cleaner and torch.compile handles it well. +# ============================================================================= + +@torch.compile(fullgraph=True) +def dion2_post_orthogonalize( + X: List[Tensor], + U_ortho: List[Tensor], + row_indices: List[Tensor], + base_lr: Tensor, + adjusted_lr: Tensor, + weight_decay: Tensor, +): + """ + Apply weight decay (to all rows) and update selected rows only. + + Weight decay: X = X * (1 - base_lr * weight_decay) [all rows] + Update: X[selected_rows] -= adjusted_lr * U_ortho [selected rows only] + """ + # OPTIMIZATION 1: foreach for batched weight decay + torch._foreach_mul_(X, 1 - base_lr * weight_decay) + + # OPTIMIZATION 2: Batch dtype conversion upfront + dtype = X[0].dtype + U_converted = [u.to(dtype=dtype) for u in U_ortho] + + # OPTIMIZATION 3: Precompute scaled updates + # This allows torch.compile to potentially fuse with index_add_ + neg_lr = -adjusted_lr + U_scaled = [neg_lr * u for u in U_converted] + + # Apply updates to selected rows + # torch.compile optimizes this loop + for x, u_scaled, indices in zip(X, U_scaled, row_indices): + x.index_add_(dim=-2, index=indices, source=u_scaled) + + +# ============================================================================= +# Newton-Schulz Wrapper (unchanged) +# ============================================================================= + +def dion2_newton_schulz( + X: Tensor, + newton_schulz_func: Callable, + flatten: bool, + epsilon: Tensor, +) -> Tensor: 
+ """Apply Newton-Schulz orthogonalization with optional flattening.""" + original_shape = X.shape + if flatten and X.ndim >= 3: + X = X.flatten(start_dim=1) + elif X.ndim >= 4: + X = X.flatten(end_dim=-3) + + return newton_schulz_func(X, epsilon=epsilon).reshape(original_shape) + + +# ============================================================================= +# Learning Rate Adjustment Functions (unchanged) +# ============================================================================= + +def _adjust_lr_spectral_norm(lr: Tensor, param_shape: torch.Size, flatten: bool) -> Tensor: + """Adjust LR based on spectral norm (for scale transfer).""" + if flatten: + fan_out = param_shape[0] + fan_in = math.prod(param_shape[1:]) + else: + fan_out, fan_in = param_shape[-2:] + return lr * math.sqrt(fan_out / fan_in) + + +def _adjust_lr_rms_norm(lr: Tensor, param_shape: torch.Size, flatten: bool) -> Tensor: + """Adjust LR based on RMS norm (for Adam/AdamW compatibility).""" + if flatten: + fan_out = param_shape[0] + fan_in = math.prod(param_shape[1:]) + else: + fan_out, fan_in = param_shape[-2:] + return lr * 0.2 * math.sqrt(max(fan_out, fan_in)) \ No newline at end of file diff --git a/dion/muon.py b/dion/muon.py index 92f444d..0966b84 100644 --- a/dion/muon.py +++ b/dion/muon.py @@ -672,4 +672,4 @@ def zeropower_via_newtonschulz5(G: Tensor, epsilon: float = 1e-7): if G.size(-2) > G.size(-1): X = X.mT - return X + return X \ No newline at end of file diff --git a/train.py b/train.py index e993d28..7b80672 100644 --- a/train.py +++ b/train.py @@ -1,3 +1,5 @@ +#train.py below + import argparse import math import os @@ -28,6 +30,7 @@ from dion import Muon from dion import MuonReference from dion import Dion2 +from dion import Dion2 as Dion2Old from dion import NorMuon @@ -1005,6 +1008,10 @@ def get_lr(it): # Otherwise, checkpoint results will not be consistent optimizer.synchronize_for_checkpoint() + + + + # Save a distributed checkpoint checkpoint_manager.save(step=step) 
From cb7f0abdbda440a96f98d91724560c86a3c96995 Mon Sep 17 00:00:00 2001 From: Kwangjun Ahn Date: Tue, 16 Dec 2025 13:05:56 -0800 Subject: [PATCH 02/14] some tests --- configs/dion2_160m.yaml | 10 +++++----- dion/__init__.py | 2 +- dion/dion2.py | 6 ++++-- train.py | 28 +++++++++++++++++++++++++++- 4 files changed, 37 insertions(+), 9 deletions(-) diff --git a/configs/dion2_160m.yaml b/configs/dion2_160m.yaml index 681ca23..6a1337e 100644 --- a/configs/dion2_160m.yaml +++ b/configs/dion2_160m.yaml @@ -29,17 +29,17 @@ wandb_job_name: null no_wandb: false # — Distributed training — -dp_size: 1 # data‐parallel size -fs_size: 4 # FSDP size -tp_size: 1 # DO NOT USE TP for Dion2 +dp_size: null # data‐parallel size +fs_size: null # FSDP size +tp_size: null # DO NOT USE TP for Dion2 # — Miscellaneous flags — debug: false no_compile: false no_triton: false -optimizer: dion2 -rank_fraction: 0.125 +optimizer: dion2old +rank_fraction: 0.5 scalar_opt: lion adjust_lr: spectral_norm lr: 0.02 diff --git a/dion/__init__.py b/dion/__init__.py index baef1ea..f14605b 100644 --- a/dion/__init__.py +++ b/dion/__init__.py @@ -4,6 +4,6 @@ from .dion_reference import Dion as DionReference from .muon import Muon from .muon_reference import Muon as MuonReference -from .dion2 import Dion2 as Dion2Old +from .dion2 import Dion2Old from .dion2_new import Dion2 from .normuon import NorMuon diff --git a/dion/dion2.py b/dion/dion2.py index 78602c7..e8dfd3a 100644 --- a/dion/dion2.py +++ b/dion/dion2.py @@ -35,7 +35,7 @@ def _full_dtype_and_shape(p: Tensor) -> Tuple[torch.Size, torch.dtype, torch.dev return p.size(), p.dtype, p.device -class Dion2(Optimizer): +class Dion2Old(Optimizer): """ Distributed Dion2 optimizer for PyTorch FSDP2. Also compatible with DDP. @@ -90,7 +90,8 @@ def __init__( raise ValueError( f"Invalid adjust_lr value: {adjust_lr}. Must be 'spectral_norm', 'rms_norm', or None." 
) - + + print(f"================Dion2 Old Optimizer initialized===========") defaults = dict( lr=lr, ef_decay=ef_decay, @@ -140,6 +141,7 @@ def __init__( self._newton_schulz_func = newton_schulz_triton else: self._newton_schulz_func = zeropower_via_newtonschulz5 + @torch.no_grad() def step(self, closure=None): diff --git a/train.py b/train.py index 7b80672..f8c2dcc 100644 --- a/train.py +++ b/train.py @@ -30,7 +30,7 @@ from dion import Muon from dion import MuonReference from dion import Dion2 -from dion import Dion2 as Dion2Old +from dion import Dion2Old from dion import NorMuon @@ -454,6 +454,32 @@ def init_optimizer( adjust_lr=hp.adjust_lr, use_triton=(not cli_args.no_triton), ) + elif hp.optimizer == "dion2old": + if device_mesh is not None: + # Ensure that we have a supported device mesh configuration for dion2 + if inner_shard_mesh is not None and inner_shard_mesh.size() > 1: + raise ValueError("Tensor parallel is not supported by dion2.") + distributed_mesh = ( + outer_shard_mesh if outer_shard_mesh.size() > 1 else replicate_mesh + ) + comm_method = "all-to-all" if outer_shard_mesh.size() > 1 else "all-gather" + else: + assert ddp_model is not None + distributed_mesh = ddp_model.process_group # using ProcessGroup for DDP + comm_method = "all-gather" + print0(f"LR adjust method: {hp.adjust_lr}") + print0(f"Triton Newton-Schulz kernels: {not cli_args.no_triton}") + print0(f"Distributed Dion2Old using: {comm_method}") + opt = Dion2Old( + param_groups, + distributed_mesh=distributed_mesh, + lr=hp.lr, + fraction=hp.rank_fraction, + ef_decay=hp.mu, + weight_decay=hp.weight_decay, + adjust_lr=hp.adjust_lr, + use_triton=(not cli_args.no_triton), + ) elif hp.optimizer == "normuon": if device_mesh is not None: # Ensure that we have a supported device mesh configuration for NorMuon From ed196d34ca0eb53ea89a8972111925cd45606bc1 Mon Sep 17 00:00:00 2001 From: Kwangjun Ahn Date: Wed, 17 Dec 2025 05:18:59 -0800 Subject: [PATCH 03/14] debug stmt --- configs/dion2_160m.yaml 
| 8 +- dion/dion2_new.py | 355 ++++++++++++++++++++++++---------------- train.py | 3 +- 3 files changed, 223 insertions(+), 143 deletions(-) diff --git a/configs/dion2_160m.yaml b/configs/dion2_160m.yaml index 6a1337e..9ba8055 100644 --- a/configs/dion2_160m.yaml +++ b/configs/dion2_160m.yaml @@ -29,16 +29,16 @@ wandb_job_name: null no_wandb: false # — Distributed training — -dp_size: null # data‐parallel size -fs_size: null # FSDP size -tp_size: null # DO NOT USE TP for Dion2 +dp_size: 1 # data‐parallel size +fs_size: 4 # FSDP size +tp_size: 1 # DO NOT USE TP for Dion2 # — Miscellaneous flags — debug: false no_compile: false no_triton: false -optimizer: dion2old +optimizer: dion2 rank_fraction: 0.5 scalar_opt: lion adjust_lr: spectral_norm diff --git a/dion/dion2_new.py b/dion/dion2_new.py index 179e565..577a45c 100644 --- a/dion/dion2_new.py +++ b/dion/dion2_new.py @@ -1,24 +1,3 @@ -""" -Dion2 Optimizer - Fully Optimized Implementation - -Key differences from Muon: -- Selects top-α fraction of rows (by L2 norm) for orthogonalization -- Only communicates and orthogonalizes the selected submatrix -- Applies error-feedback decay to selected rows after extraction - -Communication pattern (same as Muon): -- DDP: all-gather (each rank orthogonalizes one matrix, then gathers results) -- FSDP: all-to-all (shards → full matrix on owner → orthogonalize → shards) - -Row selection is done locally on each shard, so: -- DDP: selection on full matrix -- FSDP: selection on each shard independently (slightly different algorithm, similar performance) - -Optimizations: -- torch.compile on hot paths for kernel fusion and reduced Python overhead -- foreach operations for batched tensor updates -- Stacked tensor operations for row selection (all matrices in batch have same shape) -""" import math import torch @@ -40,6 +19,12 @@ ) from .scalar_opts import adamw_update_foreach_async, lion_update_foreach_async +# Reuse Muon's helper functions +from .muon import ( + 
muon_update_newton_schulz, + adjust_lr_spectral_norm, + adjust_lr_rms_norm, +) class Dion2(Optimizer): """ @@ -49,17 +34,27 @@ class Dion2(Optimizer): params: Parameters for the optimizer. distributed_mesh: DeviceMesh or ProcessGroup for distributed training. Use DeviceMesh for FSDP2 and ProcessGroup for DistributedDataParallel. - lr: Base learning rate. Scaled based on matrix dimensions. - fraction: Fraction of rows to orthogonalize per update (0 < fraction <= 1). - ef_decay: Error-feedback decay factor applied to selected rows. + lr: Base learning rate. For Muon, this will be scaled based on the matrix dimensions. + For element-wise update rules, this is the actual learning rate and no additional scaling is done. + fraction: Fraction of submatrix to orthogonalize per update (0 < fraction <= 1). + ef_decay: Error-feedback decay factor applied to selected submatrix. betas: Tuple of (beta1, beta2) for AdamW and Lion algorithms. - weight_decay: Weight decay factor. - epsilon: Small value to avoid division by zero. - adjust_lr: How to adjust learning rate ("spectral_norm", "rms_norm", or None). - flatten: Whether to flatten 3D+ tensors to 2D. - use_triton: Whether to use Triton kernel for Newton-Schulz. - newton_schulz_func: Custom Newton-Schulz function. - """ + weight_decay: Weight decay factor. + epsilon: Small value to avoid division by zero. + adjust_lr: How to adjust the learning rate for Muon updates ("spectral_norm" or "rms_norm" or None). + "spectral_norm": Adjust based on spectral norm, for learning rate transfer across model scale. + "rms_norm": Adjust based on RMS norm, for learning rate compatibility with Adam/AdamW. + None: Do not adjust the learning rate. + flatten: Whether to flatten 3D+ tensors to 2D for Muon updates. + True: Tensors with 3+ dimensions are flattened to 2D. Use this for convolutional layers. + False: Tensors are not flattened. 3D+ tensors are treated as batches of 2D matrices. 
+ use_triton: Whether to use Triton kernel for Newton-Schulz. Ignored if custom function is provided. + newton_schulz_func: Use a custom Newton-Schulz function for orthogonalization. + Signature is `func(input: Tensor, epsilon: float) -> Tensor`. + verbose: Whether to print debug information during updates. This prints whether rows or columns are selected for submatrix selection process. + + Dion2 optimizer by Ahn et al.: TBD + """ def __init__( self, @@ -75,6 +70,7 @@ def __init__( flatten: bool = False, use_triton: bool = False, newton_schulz_func: Optional[Callable] = None, + verbose: bool = False, ): # Validate hyperparameters if lr < 0.0: @@ -133,7 +129,7 @@ def __init__( self._newton_schulz_func = newton_schulz_triton else: self._newton_schulz_func = zeropower_via_newtonschulz5 - + self.verbose = verbose @torch.no_grad() def step(self, closure=None): loss = None @@ -157,7 +153,7 @@ def step(self, closure=None): else: raise ValueError(f"Unknown algorithm: {algo}") - dion2_tasks = self._create_dion2_tasks(dion2_groups) + dion2_tasks = self._create_dion2_tasks(dion2_groups, verbose=self.verbose) lion_tasks = self._create_lion_tasks(lion_groups) adamw_tasks = self._create_adamw_tasks(adamw_groups) @@ -177,7 +173,7 @@ def _get_or_initialize_state(self, param: Tensor, algo: str) -> dict: return state def _create_dion2_tasks( - self, param_groups: List[dict] + self, param_groups: List[dict], verbose: bool = False, ) -> Generator["AsyncTask", None, None]: """Create batched Dion2 update tasks.""" for group in param_groups: @@ -260,6 +256,7 @@ def _create_dion2_tasks( M=[m], shard_dim=None, **dion2_args, + verbose=verbose, ) ) else: @@ -270,6 +267,7 @@ def _create_dion2_tasks( M=pad_batch(momentums, self._world_size), shard_dim=shard_dim, **dion2_args, + verbose=verbose, ) ) @@ -332,6 +330,7 @@ def _create_adamw_tasks( ) + # ============================================================================= # Core Dion2 Update Functions # 
============================================================================= @@ -352,38 +351,81 @@ def dion2_update_batch_async( shard_dim: Optional[int] = None, process_group: Optional[ProcessGroup] = None, newton_schulz_func: Optional[Callable] = None, + verbose: bool = False, ) -> Generator[None, None, None]: """ - Batched Dion2 update with fractional row selection. + Batched Dion2 update with fractional submatrix selection. Algorithm: 1. Update momentum: M = M + G - 2. Select top-α rows by L2 norm, extract submatrix - 3. Apply ef_decay to selected rows in M + 2. Select top-α fraction along select_dim by L2 norm, extract submatrix + 3. Apply ef_decay to selected slices in M 4. Communicate and orthogonalize only the submatrix - 5. Apply weight update to corresponding rows + 5. Apply weight update to corresponding slices + + Selection dimension (select_dim): + - FSDP row-sharded (shard_dim=-2): select rows (select_dim=-2), row norms are local + - FSDP col-sharded (shard_dim=-1): select cols (select_dim=-1), col norms are local + - DDP/Single GPU: select rows by default (select_dim=-2) Communication patterns: - FSDP (shard_dim is not None): - - Parameters are row-sharded across ranks - - Each rank selects top-k rows from its local shard - - All-to-all gathers selected rows to form full submatrix + - All-to-all gathers selected slices to form full submatrix - Orthogonalize, then all-to-all scatter back - DDP (shard_dim is None, world_size > 1): - - Each rank has full matrices (batch of different matrices) - Each rank orthogonalizes one matrix from the batch - All-gather to distribute results - Single GPU: direct computation """ assert len(X) == len(G) == len(M) - # Step 1: Update momentum and select top-α rows (operates on local shards) + # Determine selection dimension based on sharding + # + # shard_dim from DTensor can be: + # - Absolute index (0, 1, 2, ...) 
+ # - Negative index (-2, -1) + # - None (not sharded) + # + # We need to map this to select_dim which is always -2 (rows) or -1 (cols) + # relative to the last two dimensions (the matrix dimensions). + # + # For FSDP: select along the sharded dimension so norms are local + # For DDP/Single-GPU: select along the SHORTER dimension to reduce Newton-Schulz compute + + ndim = X[0].ndim + + if shard_dim is not None: + # Convert shard_dim to normalized form relative to tensor + normalized_shard_dim = shard_dim if shard_dim < 0 else shard_dim - ndim + + # Check if shard_dim corresponds to a matrix dimension (last two dims) + if normalized_shard_dim == -2: + # Row-sharded: select rows, compute row norms (norm over cols) + select_dim = -2 + elif normalized_shard_dim == -1: + # Col-sharded: select cols, compute col norms (norm over rows) + select_dim = -1 + else: + # Batch dimension sharded: not a matrix dim, fall back to shorter dim + num_rows, num_cols = X[0].shape[-2:] + select_dim = -2 if num_rows <= num_cols else -1 + else: + # DDP/Single-GPU: choose shorter dimension to reduce Newton-Schulz compute + num_rows, num_cols = X[0].shape[-2:] + select_dim = -2 if num_rows <= num_cols else -1 + + # Debug: Print selection choice (only on first call per parameter shape) + if verbose: + _print_selection_choice(X[0].shape, shard_dim, select_dim, ndim) + + # Step 1: Update momentum and select top-α fraction along select_dim # All matrices in batch have identical shapes, enabling stacked operations - U_selected, row_indices_list = dion2_pre_orthogonalize( + U_selected, indices_list = dion2_pre_orthogonalize( G=to_local(G), M=to_local(M), fraction=fraction, ef_decay=ef_decay, + select_dim=select_dim, ) # Step 2: Communicate and orthogonalize selected submatrices @@ -400,18 +442,19 @@ def dion2_update_batch_async( yield work.wait() - # Concatenate along row dimension to form full selected submatrix - full_submatrix = torch.cat(recv_shards, dim=-2) + # Concatenate along selection 
dimension to form full selected submatrix + # select_dim matches shard_dim, so we concatenate along that dimension + full_submatrix = torch.cat(recv_shards, dim=select_dim) # Orthogonalize the full selected submatrix - full_submatrix = dion2_newton_schulz( + full_submatrix = muon_update_newton_schulz( full_submatrix, newton_schulz_func, flatten=flatten, epsilon=epsilon ) - # Split back into shards + # Split back into shards along the same dimension send_shards = [ t.contiguous() - for t in torch.tensor_split(full_submatrix, world_size, dim=-2) + for t in torch.tensor_split(full_submatrix, world_size, dim=select_dim) ] # All-to-all: scatter orthogonalized shards back to original owners @@ -428,7 +471,7 @@ def dion2_update_batch_async( assert process_group is not None # This rank orthogonalizes the matrix at index device_rank - my_submatrix = dion2_newton_schulz( + my_submatrix = muon_update_newton_schulz( U_selected[device_rank], newton_schulz_func, flatten=flatten, epsilon=epsilon ) @@ -446,7 +489,7 @@ def dion2_update_batch_async( else: assert len(U_selected) == 1 U_ortho = [ - dion2_newton_schulz( + muon_update_newton_schulz( U_selected[0], newton_schulz_func, flatten=flatten, epsilon=epsilon ) ] @@ -455,20 +498,21 @@ def dion2_update_batch_async( if adjust_lr is None: adjusted_lr = lr elif adjust_lr == "spectral_norm": - adjusted_lr = _adjust_lr_spectral_norm(lr, X[0].shape, flatten=flatten) + adjusted_lr = adjust_lr_spectral_norm(lr, X[0].shape, flatten=flatten) elif adjust_lr == "rms_norm": - adjusted_lr = _adjust_lr_rms_norm(lr, X[0].shape, flatten=flatten) + adjusted_lr = adjust_lr_rms_norm(lr, X[0].shape, flatten=flatten) else: raise ValueError(f"Unknown adjust_lr: {adjust_lr}") - # Step 4: Apply weight update to selected rows only + # Step 4: Apply weight update to selected slices only dion2_post_orthogonalize( X=to_local(X), U_ortho=U_ortho, - row_indices=row_indices_list, + indices=indices_list, base_lr=lr, adjusted_lr=adjusted_lr, 
weight_decay=weight_decay, + select_dim=select_dim, ) @@ -479,6 +523,17 @@ def dion2_update_batch_async( # KEY INSIGHT: All matrices in a batch have identical shapes! # This enables stacked/batched tensor operations instead of loops. # +# SELECTION DIMENSION: +# - select_dim=-2 (rows): Compute row norms (norm over cols), select top-k rows +# - select_dim=-1 (cols): Compute col norms (norm over rows), select top-k cols +# +# For FSDP, select_dim matches shard_dim so norms are computed locally. +# For DDP/Single-GPU, select_dim is the shorter dimension to reduce compute. +# +# NORM CHOICE: L1 norm (sum of absolute values) +# - Cheaper than L2: no squaring or sqrt needed +# - Effective proxy for selecting high-magnitude rows/cols +# # OPTIMIZATION 1: Stack into 3D tensor for batched ops # ---------------------------------------------------- # Stack (N, rows, cols) enables: @@ -486,17 +541,7 @@ def dion2_update_batch_async( # - Single batched topk instead of N separate topk calls # - Single batched gather instead of N separate index_selects # -# Why faster: -# - One kernel launch instead of N launches -# - Better GPU parallelism -# - Reduced Python loop overhead -# -# OPTIMIZATION 2: In-place ef_decay via loop (unavoidable) -# -------------------------------------------------------- -# torch.stack creates a copy, so we must apply ef_decay to originals. -# However, the loop benefits from torch.compile fusion. -# -# OPTIMIZATION 3: foreach for gradient accumulation +# OPTIMIZATION 2: foreach for gradient accumulation # ------------------------------------------------- # Optimal for in-place batched additions. 
# ============================================================================= @@ -505,61 +550,85 @@ def dion2_update_batch_async( def dion2_pre_orthogonalize( G: List[Tensor], M: List[Tensor], - fraction: float, + fraction: Tensor, ef_decay: Tensor, + select_dim: int, ) -> Tuple[List[Tensor], List[Tensor]]: """ - Update momentum and select top-α rows for orthogonalization. + Update momentum and select top-α fraction along select_dim. All matrices in the batch have identical shapes, enabling stacked operations. + Args: + G: List of gradients + M: List of momentum buffers (modified in place) + fraction: Fraction of rows/cols to select + ef_decay: Decay factor for selected slices + select_dim: Dimension to select along (-2 for rows, -1 for cols) + For each matrix M (shape: rows x cols): 1. M += G (accumulate gradient into momentum) - 2. Compute L2 norm of each row - 3. Select top-k rows where k = ceil(fraction * rows) - 4. Extract selected rows as submatrix (k x cols) - 5. Apply ef_decay to selected rows in M (in-place) + 2. Compute L1 norm along the OTHER dimension + - select_dim=-2: norm over cols (dim=-1) → row norms + - select_dim=-1: norm over rows (dim=-2) → col norms + 3. Select top-k indices where k = ceil(fraction * size_of_select_dim) + 4. Extract selected slices as submatrix + 5. 
Apply ef_decay to selected slices in M (in-place) Returns: U_selected: List of selected submatrices in bf16 for communication - row_indices: List of selected row indices for each matrix + indices_list: List of selected indices for each matrix """ dtype = M[0].dtype - num_rows = M[0].size(-2) - num_cols = M[0].size(-1) - k = max(1, int(math.ceil(fraction * num_rows))) + + # Determine sizes and norm dimension + # norm_dim is the dimension we compute norm OVER (the other dimension) + # select_dim is the dimension we SELECT from + num_select = M[0].size(select_dim) + norm_dim = -1 if select_dim == -2 else -2 + k = max(1, int(math.ceil(fraction * num_select))) # OPTIMIZATION 1: foreach for batched gradient accumulation - # Single fused kernel for all M += G operations G_casted = [g.to(dtype=dtype) for g in G] torch._foreach_add_(M, G_casted) # OPTIMIZATION 2: Stack for batched norm and topk - # Shape: (batch_size, num_rows, num_cols) + # Shape: (batch_size, rows, cols) M_stacked = torch.stack(M, dim=0) - # Batched L2 norm: (batch_size, num_rows) - row_norms = M_stacked.norm(dim=-1) + # Compute L1 norm along norm_dim (sum of absolute values) + # - If select_dim=-2 (rows): norm over dim=-1 → shape (batch, rows) + # - If select_dim=-1 (cols): norm over dim=-2 → shape (batch, cols) + slice_norms = M_stacked.norm(p=1, dim=norm_dim) # Batched topk: indices shape (batch_size, k) - _, indices = torch.topk(row_norms, k, dim=-1, sorted=False) + _, indices = torch.topk(slice_norms, k, dim=-1, sorted=False) - # OPTIMIZATION 3: Batched gather for row extraction - # (batch_size, k, num_cols) - indices_expanded = indices.unsqueeze(-1).expand(-1, -1, num_cols) - selected_stacked = torch.gather(M_stacked, dim=-2, index=indices_expanded) + # OPTIMIZATION 3: Batched gather for slice extraction + if select_dim == -2: + # Selecting rows: expand indices to (..., k, cols) + num_cols = M[0].size(-1) + indices_expanded = indices.unsqueeze(-1).expand(-1, -1, num_cols) + selected_stacked = 
torch.gather(M_stacked, dim=-2, index=indices_expanded) + else: + # Selecting cols: expand indices to (..., rows, k) + num_rows = M[0].size(-2) + indices_expanded = indices.unsqueeze(-2).expand(-1, num_rows, -1) + selected_stacked = torch.gather(M_stacked, dim=-1, index=indices_expanded) - # Apply ef_decay to selected rows in original M tensors + # Apply ef_decay to selected slices in original M tensors # Must loop because M tensors are separate (stack created a copy) - # torch.compile will still optimize this loop - row_indices_list = list(indices.unbind(dim=0)) - for m, idx in zip(M, row_indices_list): - m[idx, :] *= ef_decay + # Use index_copy_ with proper dimension handling for arbitrary batch dims + indices_list = list(indices.unbind(dim=0)) + for m, idx in zip(M, indices_list): + # Extract, scale, and copy back using the correct dimension + selected_slice = m.index_select(dim=select_dim, index=idx) + m.index_copy_(dim=select_dim, index=idx, source=selected_slice * ef_decay) # Convert to bf16 and unstack for communication U_selected = list(selected_stacked.to(dtype=torch.bfloat16).unbind(dim=0)) - return U_selected, row_indices_list + return U_selected, indices_list # ============================================================================= @@ -576,27 +645,33 @@ def dion2_pre_orthogonalize( # # OPTIMIZATION 3: Loop with index_add_ (torch.compile optimized) # -------------------------------------------------------------- -# While we can't use foreach for indexed updates, torch.compile -# will fuse operations within each iteration and optimize the loop. -# -# Note: Stacking X would require copy-back which negates benefits. -# The loop approach is cleaner and torch.compile handles it well. +# torch.compile fuses operations within each iteration. 
# ============================================================================= @torch.compile(fullgraph=True) def dion2_post_orthogonalize( X: List[Tensor], U_ortho: List[Tensor], - row_indices: List[Tensor], + indices: List[Tensor], base_lr: Tensor, adjusted_lr: Tensor, weight_decay: Tensor, + select_dim: int, ): """ - Apply weight decay (to all rows) and update selected rows only. + Apply weight decay (to all elements) and update selected slices only. + + Args: + X: List of parameters to update + U_ortho: List of orthogonalized update submatrices + indices: List of selected indices for each matrix + base_lr: Base learning rate (for weight decay) + adjusted_lr: Adjusted learning rate (for updates) + weight_decay: Weight decay factor + select_dim: Dimension that was selected (-2 for rows, -1 for cols) - Weight decay: X = X * (1 - base_lr * weight_decay) [all rows] - Update: X[selected_rows] -= adjusted_lr * U_ortho [selected rows only] + Weight decay: X = X * (1 - base_lr * weight_decay) [all elements] + Update: X[selected_slices] -= adjusted_lr * U_ortho [selected slices only] """ # OPTIMIZATION 1: foreach for batched weight decay torch._foreach_mul_(X, 1 - base_lr * weight_decay) @@ -606,55 +681,59 @@ def dion2_post_orthogonalize( U_converted = [u.to(dtype=dtype) for u in U_ortho] # OPTIMIZATION 3: Precompute scaled updates - # This allows torch.compile to potentially fuse with index_add_ neg_lr = -adjusted_lr U_scaled = [neg_lr * u for u in U_converted] - # Apply updates to selected rows + # Apply updates to selected slices # torch.compile optimizes this loop - for x, u_scaled, indices in zip(X, U_scaled, row_indices): - x.index_add_(dim=-2, index=indices, source=u_scaled) + for x, u_scaled, idx in zip(X, U_scaled, indices): + x.index_add_(dim=select_dim, index=idx, source=u_scaled) + -# ============================================================================= -# Newton-Schulz Wrapper (unchanged) -# 
============================================================================= + -def dion2_newton_schulz( - X: Tensor, - newton_schulz_func: Callable, - flatten: bool, - epsilon: Tensor, -) -> Tensor: - """Apply Newton-Schulz orthogonalization with optional flattening.""" - original_shape = X.shape - if flatten and X.ndim >= 3: - X = X.flatten(start_dim=1) - elif X.ndim >= 4: - X = X.flatten(end_dim=-3) - return newton_schulz_func(X, epsilon=epsilon).reshape(original_shape) # ============================================================================= -# Learning Rate Adjustment Functions (unchanged) +# Debug Helper: Print Selection Choice (once per configuration) # ============================================================================= -def _adjust_lr_spectral_norm(lr: Tensor, param_shape: torch.Size, flatten: bool) -> Tensor: - """Adjust LR based on spectral norm (for scale transfer).""" - if flatten: - fan_out = param_shape[0] - fan_in = math.prod(param_shape[1:]) - else: - fan_out, fan_in = param_shape[-2:] - return lr * math.sqrt(fan_out / fan_in) - +_printed_configs: set = set() -def _adjust_lr_rms_norm(lr: Tensor, param_shape: torch.Size, flatten: bool) -> Tensor: - """Adjust LR based on RMS norm (for Adam/AdamW compatibility).""" - if flatten: - fan_out = param_shape[0] - fan_in = math.prod(param_shape[1:]) - else: - fan_out, fan_in = param_shape[-2:] - return lr * 0.2 * math.sqrt(max(fan_out, fan_in)) \ No newline at end of file +def _print_selection_choice( + shape: torch.Size, + shard_dim: Optional[int], + select_dim: int, + ndim: int, +): + """Print the selection dimension choice once per unique configuration.""" + config_key = (tuple(shape), shard_dim, select_dim) + if config_key not in _printed_configs: + _printed_configs.add(config_key) + + num_rows, num_cols = shape[-2:] + select_info = "rows" if select_dim == -2 else "columns" + norm_info = "row norms" if select_dim == -2 else "col norms" + + if shard_dim is None: + mode = "DDP/Single-GPU" + 
shorter = "rows" if num_rows <= num_cols else "cols" + reason = f"shorter dim = {shorter} ({min(num_rows, num_cols)})" + else: + # Normalize shard_dim for display + normalized = shard_dim if shard_dim < 0 else shard_dim - ndim + if normalized == -2: + mode = "FSDP" + reason = f"row-sharded (shard_dim={shard_dim}→-2)" + elif normalized == -1: + mode = "FSDP" + reason = f"col-sharded (shard_dim={shard_dim}→-1)" + else: + mode = "FSDP batch-sharded" + shorter = "rows" if num_rows <= num_cols else "cols" + reason = f"shard_dim={shard_dim} (batch), shorter = {shorter}" + + print(f"[Dion2] Shape {tuple(shape)}: {mode}, {reason} → " + f"select top-α {select_info} by {norm_info}") diff --git a/train.py b/train.py index f8c2dcc..36dd321 100644 --- a/train.py +++ b/train.py @@ -76,7 +76,7 @@ class Hyperparameters: replicate_mesh_grad_sync: bool = False mixed_precision: bool = False adjust_lr: str = "spectral_norm" # for Muon only - + verbose: bool = False # Helper function to only print on global rank 0 MASTER_PROCESS = True @@ -453,6 +453,7 @@ def init_optimizer( weight_decay=hp.weight_decay, adjust_lr=hp.adjust_lr, use_triton=(not cli_args.no_triton), + verbose=hp.verbose, ) elif hp.optimizer == "dion2old": if device_mesh is not None: From 4d2207933a7f69b55ca5396d1b69f2ba73adcfc1 Mon Sep 17 00:00:00 2001 From: Kwangjun Ahn Date: Wed, 17 Dec 2025 08:51:40 -0800 Subject: [PATCH 04/14] clean up --- dion/dion2.py | 6 +- dion/dion2_new.py | 406 ++++++++++++++++++++++------------------------ dion/muon.py | 2 +- train.py | 13 +- 4 files changed, 199 insertions(+), 228 deletions(-) diff --git a/dion/dion2.py b/dion/dion2.py index e8dfd3a..bb09704 100644 --- a/dion/dion2.py +++ b/dion/dion2.py @@ -90,8 +90,7 @@ def __init__( raise ValueError( f"Invalid adjust_lr value: {adjust_lr}. Must be 'spectral_norm', 'rms_norm', or None." 
) - - print(f"================Dion2 Old Optimizer initialized===========") + defaults = dict( lr=lr, ef_decay=ef_decay, @@ -141,7 +140,6 @@ def __init__( self._newton_schulz_func = newton_schulz_triton else: self._newton_schulz_func = zeropower_via_newtonschulz5 - @torch.no_grad() def step(self, closure=None): @@ -763,4 +761,4 @@ def dion2_update_post_orthogonalize( # Weight update U = torch._foreach_mul(U, adjusted_lr) - torch._foreach_sub_(X, U) \ No newline at end of file + torch._foreach_sub_(X, U) diff --git a/dion/dion2_new.py b/dion/dion2_new.py index 577a45c..39a6567 100644 --- a/dion/dion2_new.py +++ b/dion/dion2_new.py @@ -1,4 +1,3 @@ - import math import torch import torch.distributed as dist @@ -26,6 +25,7 @@ adjust_lr_rms_norm, ) + class Dion2(Optimizer): """ Distributed Dion2 optimizer for PyTorch FSDP2. Also compatible with DDP. @@ -39,8 +39,8 @@ class Dion2(Optimizer): fraction: Fraction of submatrix to orthogonalize per update (0 < fraction <= 1). ef_decay: Error-feedback decay factor applied to selected submatrix. betas: Tuple of (beta1, beta2) for AdamW and Lion algorithms. - weight_decay: Weight decay factor. - epsilon: Small value to avoid division by zero. + weight_decay: Weight decay factor. + epsilon: Small value to avoid division by zero. adjust_lr: How to adjust the learning rate for Muon updates ("spectral_norm" or "rms_norm" or None). "spectral_norm": Adjust based on spectral norm, for learning rate transfer across model scale. "rms_norm": Adjust based on RMS norm, for learning rate compatibility with Adam/AdamW. @@ -51,10 +51,10 @@ class Dion2(Optimizer): use_triton: Whether to use Triton kernel for Newton-Schulz. Ignored if custom function is provided. newton_schulz_func: Use a custom Newton-Schulz function for orthogonalization. Signature is `func(input: Tensor, epsilon: float) -> Tensor`. - verbose: Whether to print debug information during updates. This prints whether rows or columns are selected for submatrix selection process. 
+ verbose: Whether to print debug information during updates. If True, it prints whether rows or columns are selected for the submatrix selection process. Dion2 optimizer by Ahn et al.: TBD - """ + """ def __init__( self, @@ -64,7 +64,7 @@ def __init__( fraction: float = 0.25, ef_decay: float = 0.95, betas: Tuple[float, float] = (0.9, 0.95), - weight_decay: float = 0.01, + weight_decay: float = 0.01, epsilon: float = 1e-8, adjust_lr: Optional[str] = "spectral_norm", flatten: bool = False, @@ -82,7 +82,9 @@ def __init__( if len(betas) != 2 or betas[0] < 0.0 or betas[1] < 0.0: raise ValueError(f"Invalid betas: {betas}") if adjust_lr not in ("spectral_norm", "rms_norm", None): - raise ValueError(f"Invalid adjust_lr: {adjust_lr}") + raise ValueError( + f"Invalid adjust_lr value: {adjust_lr}. Must be 'spectral_norm', 'rms_norm', or None." + ) defaults = dict( lr=lr, @@ -90,7 +92,7 @@ def __init__( fraction=fraction, beta1=betas[0], beta2=betas[1], - weight_decay=weight_decay, + weight_decay=weight_decay, epsilon=epsilon, flatten=flatten, adjust_lr=adjust_lr, @@ -103,7 +105,7 @@ def __init__( if isinstance(distributed_mesh, DeviceMesh): if distributed_mesh.ndim != 1: raise ValueError( - f"Only 1D DeviceMesh supported, got {distributed_mesh.ndim}D." + f"Only 1D DeviceMesh supported, but got {distributed_mesh.ndim}D. For HSDP, provide the 1D sharded sub-mesh." ) self._device_rank = distributed_mesh.get_local_rank() self._world_size = distributed_mesh.size() @@ -117,21 +119,29 @@ def __init__( self._world_size = 1 self._process_group = None else: - raise TypeError(f"Invalid distributed_mesh type: {type(distributed_mesh)}") + raise TypeError( + f"Invalid distributed_mesh type: {type(distributed_mesh)}. Expected DeviceMesh or ProcessGroup." 
+ ) self._distributed_mesh = distributed_mesh # Newton-Schulz configuration if newton_schulz_func is not None: if not callable(newton_schulz_func): - raise TypeError(f"newton_schulz_func must be callable") + raise TypeError( + f"newton_schulz_func must be a callable function, got {type(newton_schulz_func)}" + ) self._newton_schulz_func = newton_schulz_func elif use_triton: self._newton_schulz_func = newton_schulz_triton else: self._newton_schulz_func = zeropower_via_newtonschulz5 self.verbose = verbose + @torch.no_grad() def step(self, closure=None): + """ + Perform a single optimization step. + """ loss = None if closure is not None: with torch.enable_grad(): @@ -142,7 +152,10 @@ def step(self, closure=None): adamw_groups = [] for group in self.param_groups: + # Increment step group["step"] += 1 + + # Split parameter groups by algorithm algo = group["algorithm"] if algo == "dion2": dion2_groups.append(group) @@ -153,6 +166,7 @@ def step(self, closure=None): else: raise ValueError(f"Unknown algorithm: {algo}") + # Create async tasks for each algorithm dion2_tasks = self._create_dion2_tasks(dion2_groups, verbose=self.verbose) lion_tasks = self._create_lion_tasks(lion_groups) adamw_tasks = self._create_adamw_tasks(adamw_groups) @@ -164,7 +178,10 @@ def step(self, closure=None): return loss def _get_or_initialize_state(self, param: Tensor, algo: str) -> dict: - """Initialize optimizer state (identical to Muon).""" + """ + Get optimizer state for the given parameter tensor, + or lazy-initialize it if it doesn't exist. 
+ """ state = self.state[param] if not state: state["momentum"] = torch.zeros_like(param) @@ -173,19 +190,27 @@ def _get_or_initialize_state(self, param: Tensor, algo: str) -> dict: return state def _create_dion2_tasks( - self, param_groups: List[dict], verbose: bool = False, + self, + param_groups: List[dict], + verbose: bool = False, ) -> Generator["AsyncTask", None, None]: - """Create batched Dion2 update tasks.""" + """ + Helper function to create batches of Dion2 matrices and generate + AsyncTask objects so we can process multiple batches concurrently. + """ for group in param_groups: assert group["algorithm"] == "dion2" - assert all(p.ndim >= 2 for p in group["params"]), \ - "Dion2 only supports matrix parameters." + assert all( + p.ndim >= 2 for p in group["params"] + ), "Dion2 only supports matrix parameters." group_params = [p for p in group["params"] if p.grad is not None] if not group_params: continue - # Hyperparameters as tensors for torch.compile + # Most hyperparameters as tensors for torch.compile + # Here "fraction" only determines the dimension of the submatrix + # to be orthonormalized. 
Hence, it doesn't need to be a tensor dion2_args = dict( lr=torch.tensor(group["lr"]), ef_decay=torch.tensor(group["ef_decay"]), @@ -197,33 +222,39 @@ def _create_dion2_tasks( device_rank=self._device_rank, world_size=self._world_size, process_group=self._process_group, - newton_schulz_func=self._newton_schulz_func, + newton_schulz_func=self._newton_schulz_func, ) - # Batch parameters by world_size (same as Muon) - for params in create_param_batches(group_params, batch_size=self._world_size): + # Create batches of parameters of size self._world_size + for params in create_param_batches( + group_params, batch_size=self._world_size + ): gradients = [p.grad for p in params] states = [self._get_or_initialize_state(p, "dion2") for p in params] momentums = [s["momentum"] for s in states] - # Determine sharding configuration - shard_dim = None + # Get sharding state for DTensor is_batch_sharded = False + is_matrix_sharded = False + sharded_mesh_dim = None + sharded_tensor_dim = None if isinstance(params[0], DTensor): if not isinstance(self._distributed_mesh, DeviceMesh): raise RuntimeError( - "Must use DeviceMesh for DTensor parameters." + "Must create optimizer with DeviceMesh if using DTensor parameters." 
) - # Find sharded placements (skip size-1 mesh dims) + # Find the sharded placement and get its mesh and tensor dimensions + # Skip any Shard() placements on size-1 mesh dimension = Replicate() shard_placements = [ (i, p) for i, p in enumerate(params[0].placements) if p.is_shard() and params[0].device_mesh.size(i) > 1 ] - # Check for batch vs matrix dimension sharding + # If we don't flatten 3D matrices, we can ignore shard placements along batch dimensions + # Only keep placements that shard one of the two matrix dimensions if not group["flatten"]: matrix_dims = {params[0].ndim - 1, params[0].ndim - 2} is_batch_sharded = any( @@ -233,108 +264,138 @@ def _create_dion2_tasks( (i, p) for i, p in shard_placements if p.dim in matrix_dims ] + # Check that we have no more than 1 sharded matrix dimension + # Note that non-flattened 3D tensors can have additional sharded batch dimensions + # Flattened 3D tensors are limited to one sharded dimension out of all dimensions if len(shard_placements) == 1: - shard_dim = shard_placements[0][1].dim + is_matrix_sharded = True + sharded_mesh_dim = shard_placements[0][0] + sharded_tensor_dim = shard_placements[0][1].dim elif len(shard_placements) > 1: raise NotImplementedError( - "Multiple sharded dimensions not supported." + "Dion2 does not support parameters with multiple sharded dimensions." ) - # Verify mesh alignment - if shard_placements: - mesh_dim = shard_placements[0][0] - if params[0].device_mesh.get_group(mesh_dim) != self._process_group: - raise RuntimeError("DTensor mesh doesn't match optimizer mesh.") + # Check that the sharded mesh dimension matches optimizer's device mesh + if ( + sharded_mesh_dim is not None + and params[0].device_mesh.get_group(sharded_mesh_dim) + != self._process_group + ): + raise RuntimeError( + f"Got DTensor sharded over mesh dimension {sharded_mesh_dim} different from the optimizer's device mesh. 
" + f"DTensor has mesh: {params[0].device_mesh}, placements: {params[0].placements}, but optimizer was created with mesh: {self._distributed_mesh}." + ) - # Handle batch-sharded 3D tensors (each device has different matrices) - if is_batch_sharded: + # Special case for 3D tensors sharded along batch dimension + # As long as matrix dimensions are not sharded, each device will have whole matrices + # Each device already has different matrices of the batch, so we can't parallelize further + if is_batch_sharded and not is_matrix_sharded: for x, g, m in zip(params, gradients, momentums): yield AsyncTask( dion2_update_batch_async( X=[x], G=[g], M=[m], - shard_dim=None, + shard_dim=None, # No sharded matrix dim **dion2_args, verbose=verbose, ) ) + # Otherwise, we parallelize the Muon update across devices else: yield AsyncTask( dion2_update_batch_async( X=pad_batch(params, self._world_size), G=pad_batch(gradients, self._world_size), M=pad_batch(momentums, self._world_size), - shard_dim=shard_dim, + shard_dim=sharded_tensor_dim, **dion2_args, verbose=verbose, ) ) def _create_lion_tasks( - self, param_groups: List[dict] + self, + param_groups: List[dict], + algo_name: str = "lion", ) -> Generator["AsyncTask", None, None]: - """Create Lion update tasks.""" + """ + Helper function to generate AsyncTask objects for Lion updates. 
+ """ for group in param_groups: - assert group["algorithm"] == "lion" + assert group["algorithm"] == algo_name + # Get parameters and optimizer states params = [p for p in group["params"] if p.grad is not None] if not params: continue - gradients = [p.grad for p in params] - states = [self._get_or_initialize_state(p, "lion") for p in params] + states = [self._get_or_initialize_state(p, algo_name) for p in params] momentums = [s["momentum"] for s in states] + # Wrap hyperparameters in tensors for torch.compile + lr = torch.tensor(group["lr"]) + beta1 = torch.tensor(group["beta1"]) + beta2 = torch.tensor(group["beta2"]) + weight_decay = torch.tensor(group["weight_decay"]) + yield AsyncTask( lion_update_foreach_async( X=to_local(params), G=to_local(gradients), M=to_local(momentums), - lr=torch.tensor(group["lr"]), - beta1=torch.tensor(group["beta1"]), - beta2=torch.tensor(group["beta2"]), - weight_decay=torch.tensor(group["weight_decay"]), + lr=lr, + beta1=beta1, + beta2=beta2, + weight_decay=weight_decay, ) ) def _create_adamw_tasks( - self, param_groups: List[dict] + self, + param_groups: List[dict], + algo_name: str = "adamw", ) -> Generator["AsyncTask", None, None]: - """Create AdamW update tasks.""" + """ + Helper function to generate AsyncTask objects for AdamW updates. 
+ """ for group in param_groups: - assert group["algorithm"] == "adamw" + assert group["algorithm"] == algo_name + # Get parameters and optimizer states params = [p for p in group["params"] if p.grad is not None] if not params: continue - gradients = [p.grad for p in params] - states = [self._get_or_initialize_state(p, "adamw") for p in params] + states = [self._get_or_initialize_state(p, algo_name) for p in params] momentums = [s["momentum"] for s in states] variances = [s["variance"] for s in states] + # Wrap hyperparameters in tensors for torch.compile + lr = torch.tensor(group["lr"]) + beta1 = torch.tensor(group["beta1"]) + beta2 = torch.tensor(group["beta2"]) + weight_decay = torch.tensor(group["weight_decay"]) + epsilon = torch.tensor(group["epsilon"]) + step = torch.tensor(group["step"]) + yield AsyncTask( adamw_update_foreach_async( X=to_local(params), G=to_local(gradients), M=to_local(momentums), V=to_local(variances), - lr=torch.tensor(group["lr"]), - beta1=torch.tensor(group["beta1"]), - beta2=torch.tensor(group["beta2"]), - weight_decay=torch.tensor(group["weight_decay"]), - step=torch.tensor(group["step"]), - epsilon=torch.tensor(group["epsilon"]), + lr=lr, + beta1=beta1, + beta2=beta2, + weight_decay=weight_decay, + step=step, + epsilon=epsilon, ) ) - -# ============================================================================= -# Core Dion2 Update Functions -# ============================================================================= - def dion2_update_batch_async( X: List[Tensor], # Parameters (DTensor or Tensor), padded to world_size G: List[Tensor], # Gradients, padded to world_size @@ -350,26 +411,26 @@ def dion2_update_batch_async( world_size: int, shard_dim: Optional[int] = None, process_group: Optional[ProcessGroup] = None, - newton_schulz_func: Optional[Callable] = None, + newton_schulz_func: Optional[Callable] = None, verbose: bool = False, ) -> Generator[None, None, None]: """ Batched Dion2 update with fractional submatrix selection. 
- + Algorithm: 1. Update momentum: M = M + G 2. Select top-α fraction along select_dim by L2 norm, extract submatrix 3. Apply ef_decay to selected slices in M 4. Communicate and orthogonalize only the submatrix 5. Apply weight update to corresponding slices - + Selection dimension (select_dim): - FSDP row-sharded (shard_dim=-2): select rows (select_dim=-2), row norms are local - - FSDP col-sharded (shard_dim=-1): select cols (select_dim=-1), col norms are local + - FSDP col-sharded (shard_dim=-1): select cols (select_dim=-1), col norms are local - DDP/Single GPU: select rows by default (select_dim=-2) - + Communication patterns: - - FSDP (shard_dim is not None): + - FSDP (shard_dim is not None): - All-to-all gathers selected slices to form full submatrix - Orthogonalize, then all-to-all scatter back - DDP (shard_dim is None, world_size > 1): @@ -380,9 +441,9 @@ def dion2_update_batch_async( assert len(X) == len(G) == len(M) # Determine selection dimension based on sharding - # + # # shard_dim from DTensor can be: - # - Absolute index (0, 1, 2, ...) + # - Absolute index (0, 1, 2, ...) 
# - Negative index (-2, -1) # - None (not sharded) # @@ -391,13 +452,13 @@ def dion2_update_batch_async( # # For FSDP: select along the sharded dimension so norms are local # For DDP/Single-GPU: select along the SHORTER dimension to reduce Newton-Schulz compute - + ndim = X[0].ndim - + if shard_dim is not None: # Convert shard_dim to normalized form relative to tensor normalized_shard_dim = shard_dim if shard_dim < 0 else shard_dim - ndim - + # Check if shard_dim corresponds to a matrix dimension (last two dims) if normalized_shard_dim == -2: # Row-sharded: select rows, compute row norms (norm over cols) @@ -436,9 +497,11 @@ def dion2_update_batch_async( assert len(X) == world_size assert process_group is not None assert isinstance(X[0], DTensor) - + recv_shards = [torch.empty_like(u) for u in U_selected] - work = dist.all_to_all(recv_shards, U_selected, group=process_group, async_op=True) + work = dist.all_to_all( + recv_shards, U_selected, group=process_group, async_op=True + ) yield work.wait() @@ -472,7 +535,10 @@ def dion2_update_batch_async( # This rank orthogonalizes the matrix at index device_rank my_submatrix = muon_update_newton_schulz( - U_selected[device_rank], newton_schulz_func, flatten=flatten, epsilon=epsilon + U_selected[device_rank], + newton_schulz_func, + flatten=flatten, + epsilon=epsilon, ) # All-gather: collect orthogonalized submatrices from all ranks @@ -507,45 +573,15 @@ def dion2_update_batch_async( # Step 4: Apply weight update to selected slices only dion2_post_orthogonalize( X=to_local(X), - U_ortho=U_ortho, + U=U_ortho, indices=indices_list, base_lr=lr, adjusted_lr=adjusted_lr, weight_decay=weight_decay, - select_dim=select_dim, + select_dim=select_dim, ) -# ============================================================================= -# Optimized Pre-Orthogonalize Function (Stacked Operations) -# ============================================================================= -# -# KEY INSIGHT: All matrices in a batch have identical 
shapes! -# This enables stacked/batched tensor operations instead of loops. -# -# SELECTION DIMENSION: -# - select_dim=-2 (rows): Compute row norms (norm over cols), select top-k rows -# - select_dim=-1 (cols): Compute col norms (norm over rows), select top-k cols -# -# For FSDP, select_dim matches shard_dim so norms are computed locally. -# For DDP/Single-GPU, select_dim is the shorter dimension to reduce compute. -# -# NORM CHOICE: L1 norm (sum of absolute values) -# - Cheaper than L2: no squaring or sqrt needed -# - Effective proxy for selecting high-magnitude rows/cols -# -# OPTIMIZATION 1: Stack into 3D tensor for batched ops -# ---------------------------------------------------- -# Stack (N, rows, cols) enables: -# - Single batched norm instead of N separate norms -# - Single batched topk instead of N separate topk calls -# - Single batched gather instead of N separate index_selects -# -# OPTIMIZATION 2: foreach for gradient accumulation -# ------------------------------------------------- -# Optimal for in-place batched additions. -# ============================================================================= - @torch.compile(fullgraph=True) def dion2_pre_orthogonalize( G: List[Tensor], @@ -555,168 +591,104 @@ def dion2_pre_orthogonalize( select_dim: int, ) -> Tuple[List[Tensor], List[Tensor]]: """ - Update momentum and select top-α fraction along select_dim. - - All matrices in the batch have identical shapes, enabling stacked operations. - - Args: - G: List of gradients - M: List of momentum buffers (modified in place) - fraction: Fraction of rows/cols to select - ef_decay: Decay factor for selected slices - select_dim: Dimension to select along (-2 for rows, -1 for cols) - - For each matrix M (shape: rows x cols): - 1. M += G (accumulate gradient into momentum) - 2. Compute L1 norm along the OTHER dimension - - select_dim=-2: norm over cols (dim=-1) → row norms - - select_dim=-1: norm over rows (dim=-2) → col norms - 3. 
Select top-k indices where k = ceil(fraction * size_of_select_dim) - 4. Extract selected slices as submatrix - 5. Apply ef_decay to selected slices in M (in-place) - - Returns: - U_selected: List of selected submatrices in bf16 for communication - indices_list: List of selected indices for each matrix + Update momentum with gradient and compute the input to orthogonalization. + More specifically, it does the following steps: + - updates the momentum with gradient + - computes the top-k indices to determine submatrices + - does in-place error-feedback decay on the selected submatrices + - output submatrices and indices + Inputs and outputs should be lists of regular Tensor, not DTensor. + This is a separate function for compatibility with torch.compile(). """ dtype = M[0].dtype - - # Determine sizes and norm dimension - # norm_dim is the dimension we compute norm OVER (the other dimension) - # select_dim is the dimension we SELECT from + + # norm_dim is the dimension we compute norm over + # select_dim is the dimension we select submatrix from num_select = M[0].size(select_dim) norm_dim = -1 if select_dim == -2 else -2 k = max(1, int(math.ceil(fraction * num_select))) - - # OPTIMIZATION 1: foreach for batched gradient accumulation - G_casted = [g.to(dtype=dtype) for g in G] - torch._foreach_add_(M, G_casted) - - # OPTIMIZATION 2: Stack for batched norm and topk - # Shape: (batch_size, rows, cols) + + G = [g.to(dtype=dtype) for g in G] + torch._foreach_add_(M, G) + M_stacked = torch.stack(M, dim=0) - + # Compute L1 norm along norm_dim (sum of absolute values) - # - If select_dim=-2 (rows): norm over dim=-1 → shape (batch, rows) - # - If select_dim=-1 (cols): norm over dim=-2 → shape (batch, cols) slice_norms = M_stacked.norm(p=1, dim=norm_dim) - + # Batched topk: indices shape (batch_size, k) _, indices = torch.topk(slice_norms, k, dim=-1, sorted=False) - - # OPTIMIZATION 3: Batched gather for slice extraction + + # Batched gather for slice extraction if select_dim 
== -2: - # Selecting rows: expand indices to (..., k, cols) + # Selecting rows num_cols = M[0].size(-1) indices_expanded = indices.unsqueeze(-1).expand(-1, -1, num_cols) selected_stacked = torch.gather(M_stacked, dim=-2, index=indices_expanded) else: - # Selecting cols: expand indices to (..., rows, k) + # Selecting cols num_rows = M[0].size(-2) indices_expanded = indices.unsqueeze(-2).expand(-1, num_rows, -1) selected_stacked = torch.gather(M_stacked, dim=-1, index=indices_expanded) - - # Apply ef_decay to selected slices in original M tensors - # Must loop because M tensors are separate (stack created a copy) - # Use index_copy_ with proper dimension handling for arbitrary batch dims + + # Apply error feedback decay to selected slices in original M tensors indices_list = list(indices.unbind(dim=0)) for m, idx in zip(M, indices_list): - # Extract, scale, and copy back using the correct dimension selected_slice = m.index_select(dim=select_dim, index=idx) m.index_copy_(dim=select_dim, index=idx, source=selected_slice * ef_decay) - + # Convert to bf16 and unstack for communication U_selected = list(selected_stacked.to(dtype=torch.bfloat16).unbind(dim=0)) - - return U_selected, indices_list + return U_selected, indices_list -# ============================================================================= -# Optimized Post-Orthogonalize Function -# ============================================================================= -# -# OPTIMIZATION 1: foreach for weight decay -# ---------------------------------------- -# Single fused kernel for all X *= (1 - lr * wd) operations. -# -# OPTIMIZATION 2: Batched dtype conversion -# ---------------------------------------- -# Convert all U tensors upfront for better memory planning. -# -# OPTIMIZATION 3: Loop with index_add_ (torch.compile optimized) -# -------------------------------------------------------------- -# torch.compile fuses operations within each iteration. 
-# ============================================================================= @torch.compile(fullgraph=True) def dion2_post_orthogonalize( X: List[Tensor], - U_ortho: List[Tensor], + U: List[Tensor], indices: List[Tensor], base_lr: Tensor, adjusted_lr: Tensor, weight_decay: Tensor, - select_dim: int, + select_dim: int, ): """ - Apply weight decay (to all elements) and update selected slices only. - - Args: - X: List of parameters to update - U_ortho: List of orthogonalized update submatrices - indices: List of selected indices for each matrix - base_lr: Base learning rate (for weight decay) - adjusted_lr: Adjusted learning rate (for updates) - weight_decay: Weight decay factor - select_dim: Dimension that was selected (-2 for rows, -1 for cols) - - Weight decay: X = X * (1 - base_lr * weight_decay) [all elements] - Update: X[selected_slices] -= adjusted_lr * U_ortho [selected slices only] + Apply weight decay and weight update after orthogonalization. + Inputs and outputs should be lists of regular Tensor, not DTensor. + This is a separate function for compatibility with torch.compile(). 
""" - # OPTIMIZATION 1: foreach for batched weight decay torch._foreach_mul_(X, 1 - base_lr * weight_decay) - - # OPTIMIZATION 2: Batch dtype conversion upfront + + # Convert U to match parameter dtype dtype = X[0].dtype - U_converted = [u.to(dtype=dtype) for u in U_ortho] - - # OPTIMIZATION 3: Precompute scaled updates + U = [u.to(dtype=dtype) for u in U] + # Apply weight update neg_lr = -adjusted_lr - U_scaled = [neg_lr * u for u in U_converted] - - # Apply updates to selected slices - # torch.compile optimizes this loop + U_scaled = [neg_lr * u for u in U] for x, u_scaled, idx in zip(X, U_scaled, indices): x.index_add_(dim=select_dim, index=idx, source=u_scaled) - - - - - - - -# ============================================================================= -# Debug Helper: Print Selection Choice (once per configuration) -# ============================================================================= +# A helper function to print selection chocie for each matrix +# It only prints once `verbose` is set True _printed_configs: set = set() + def _print_selection_choice( - shape: torch.Size, - shard_dim: Optional[int], + shape: torch.Size, + shard_dim: Optional[int], select_dim: int, ndim: int, ): - """Print the selection dimension choice once per unique configuration.""" config_key = (tuple(shape), shard_dim, select_dim) if config_key not in _printed_configs: _printed_configs.add(config_key) - + num_rows, num_cols = shape[-2:] select_info = "rows" if select_dim == -2 else "columns" norm_info = "row norms" if select_dim == -2 else "col norms" - + if shard_dim is None: mode = "DDP/Single-GPU" shorter = "rows" if num_rows <= num_cols else "cols" @@ -734,6 +706,8 @@ def _print_selection_choice( mode = "FSDP batch-sharded" shorter = "rows" if num_rows <= num_cols else "cols" reason = f"shard_dim={shard_dim} (batch), shorter = {shorter}" - - print(f"[Dion2] Shape {tuple(shape)}: {mode}, {reason} → " - f"select top-α {select_info} by {norm_info}") + + print( + f"[Dion2] 
Shape {tuple(shape)}: {mode}, {reason} → " + f"select top-α {select_info} by {norm_info}" + ) diff --git a/dion/muon.py b/dion/muon.py index 0966b84..92f444d 100644 --- a/dion/muon.py +++ b/dion/muon.py @@ -672,4 +672,4 @@ def zeropower_via_newtonschulz5(G: Tensor, epsilon: float = 1e-7): if G.size(-2) > G.size(-1): X = X.mT - return X \ No newline at end of file + return X diff --git a/train.py b/train.py index 36dd321..fe8a2aa 100644 --- a/train.py +++ b/train.py @@ -1,4 +1,4 @@ -#train.py below +# train.py below import argparse import math @@ -67,7 +67,7 @@ class Hyperparameters: lr: float = 0.02 mu: float = 0.95 weight_decay: float = 0.01 - rank_fraction: float = 0.125 + rank_fraction: float = 0.125 # Optimizer specific hyperparameters qr_method: str = "rcqr" @@ -76,8 +76,11 @@ class Hyperparameters: replicate_mesh_grad_sync: bool = False mixed_precision: bool = False adjust_lr: str = "spectral_norm" # for Muon only + + # For printing out selection choice in Dion2 verbose: bool = False + # Helper function to only print on global rank 0 MASTER_PROCESS = True @@ -450,7 +453,7 @@ def init_optimizer( lr=hp.lr, fraction=hp.rank_fraction, ef_decay=hp.mu, - weight_decay=hp.weight_decay, + weight_decay=hp.weight_decay, adjust_lr=hp.adjust_lr, use_triton=(not cli_args.no_triton), verbose=hp.verbose, @@ -1035,10 +1038,6 @@ def get_lr(it): # Otherwise, checkpoint results will not be consistent optimizer.synchronize_for_checkpoint() - - - - # Save a distributed checkpoint checkpoint_manager.save(step=step) From cf0e5fe8e6d91974e37bdf125fb265b7cd0c2fd8 Mon Sep 17 00:00:00 2001 From: Kwangjun Ahn Date: Thu, 18 Dec 2025 04:11:21 -0800 Subject: [PATCH 05/14] improved dion2 impl --- configs/dion2_160m.yaml | 28 +- configs/dion_160m.yaml | 21 +- configs/muon_160m.yaml | 18 +- configs/normuon_160m.yaml | 10 +- dion/__init__.py | 5 +- dion/dion2.py | 668 ++++++++++++++++------------------- dion/dion2_new.py | 713 -------------------------------------- train.py | 62 +--- 8 files 
changed, 343 insertions(+), 1182 deletions(-) delete mode 100644 dion/dion2_new.py diff --git a/configs/dion2_160m.yaml b/configs/dion2_160m.yaml index 9ba8055..55ff678 100644 --- a/configs/dion2_160m.yaml +++ b/configs/dion2_160m.yaml @@ -1,12 +1,12 @@ # — Model — model_dim: 768 -n_layer: 6 +n_layer: 12 n_head: 6 -sequence_length: 512 +sequence_length: 1024 # — Batching & Training — -batch_size: 128 +batch_size: 1024 device_batch_size: 32 num_iterations: 3000 @@ -14,32 +14,30 @@ num_iterations: 3000 warmup_ratio: 0.0 warmdown_ratio: 0.2 -# — Optimizer & Hyperparameters — -mu: 0.95 -weight_decay: 0.01 - # — Validation & Checkpointing — val_loss_every: 125 -val_tokens: 10485760 -save_every: 125 +val_tokens: 10485760 # — Weights & Biases logging — wandb_project_name: gpt-train wandb_job_name: null no_wandb: false -# — Distributed training — -dp_size: 1 # data‐parallel size -fs_size: 4 # FSDP size -tp_size: 1 # DO NOT USE TP for Dion2 - # — Miscellaneous flags — debug: false no_compile: false no_triton: false +# — Distributed training — +dp_size: null # data‐parallel size +fs_size: null # FSDP size +tp_size: null # DO NOT USE TP for Dion2 + +# — Optimizer & Hyperparameters — optimizer: dion2 -rank_fraction: 0.5 scalar_opt: lion +mu: 0.95 +weight_decay: 0.01 +ortho_fraction: 0.5 adjust_lr: spectral_norm lr: 0.02 diff --git a/configs/dion_160m.yaml b/configs/dion_160m.yaml index 1c00053..6759019 100644 --- a/configs/dion_160m.yaml +++ b/configs/dion_160m.yaml @@ -14,33 +14,30 @@ num_iterations: 3000 warmup_ratio: 0.0 warmdown_ratio: 0.2 -# — Optimizer & Hyperparameters — -rank_fraction: 0.125 -mu: 0.95 -weight_decay: 0.01 -oversample: 1.25 - # — Validation & Checkpointing — val_loss_every: 125 -val_tokens: 10485760 -save_every: 0 +val_tokens: 10485760 # — Weights & Biases logging — wandb_project_name: gpt-train wandb_job_name: null no_wandb: false +# — Miscellaneous flags — +debug: false +no_compile: false + # — Distributed training — dp_size: null # data‐parallel size 
fs_size: null # FSDP size tp_size: null # tensor‐parallel size replicate_mesh_grad_sync: true -# — Miscellaneous flags — -debug: false -no_compile: false - +# — Optimizer & Hyperparameters — optimizer: dion scalar_opt: lion +mu: 0.95 +weight_decay: 0.01 +ortho_fraction: 0.125 lr: 0.02 mixed_precision: true diff --git a/configs/muon_160m.yaml b/configs/muon_160m.yaml index a787585..a39764e 100644 --- a/configs/muon_160m.yaml +++ b/configs/muon_160m.yaml @@ -1,27 +1,22 @@ # — Model — model_dim: 768 -n_layer: 6 +n_layer: 12 n_head: 6 -sequence_length: 512 +sequence_length: 1024 # — Batching & Training — -batch_size: 128 -device_batch_size: 32 +batch_size: 1024 +device_batch_size: 32 num_iterations: 3000 # — Learning‐rate schedule — warmup_ratio: 0.0 warmdown_ratio: 0.2 -# — Optimizer & Hyperparameters — -mu: 0.95 -weight_decay: 0.01 - # — Validation & Checkpointing — val_loss_every: 125 -val_tokens: 10485760 -save_every: 0 +val_tokens: 10485760 # — Weights & Biases logging — wandb_project_name: gpt-train @@ -38,7 +33,10 @@ debug: false no_compile: false no_triton: false +# — Optimizer & Hyperparameters — optimizer: muon scalar_opt: lion adjust_lr: spectral_norm +mu: 0.95 +weight_decay: 0.01 lr: 0.02 diff --git a/configs/normuon_160m.yaml b/configs/normuon_160m.yaml index 4b14373..9607d0b 100644 --- a/configs/normuon_160m.yaml +++ b/configs/normuon_160m.yaml @@ -14,14 +14,9 @@ num_iterations: 3000 warmup_ratio: 0.0 warmdown_ratio: 0.2 -# — Optimizer & Hyperparameters — -mu: 0.95 -weight_decay: 0.01 - # — Validation & Checkpointing — val_loss_every: 125 -val_tokens: 10485760 -save_every: 0 +val_tokens: 10485760 # — Weights & Biases logging — wandb_project_name: gpt-train @@ -38,7 +33,10 @@ debug: false no_compile: false no_triton: false +# — Optimizer & Hyperparameters — optimizer: normuon scalar_opt: lion adjust_lr: spectral_norm lr: 0.02 +mu: 0.95 +weight_decay: 0.01 diff --git a/dion/__init__.py b/dion/__init__.py index f14605b..a3d18f4 100644 --- a/dion/__init__.py 
+++ b/dion/__init__.py @@ -3,7 +3,6 @@ from .dion_simple import Dion as DionSimple from .dion_reference import Dion as DionReference from .muon import Muon -from .muon_reference import Muon as MuonReference -from .dion2 import Dion2Old -from .dion2_new import Dion2 +from .muon_reference import Muon as MuonReference +from .dion2 import Dion2 from .normuon import NorMuon diff --git a/dion/dion2.py b/dion/dion2.py index bb09704..274fa7f 100644 --- a/dion/dion2.py +++ b/dion/dion2.py @@ -8,7 +8,6 @@ from torch.optim.optimizer import Optimizer, ParamsT from typing import Callable, Generator, List, Optional, Tuple, Union - from .newton_schulz_triton import newton_schulz_triton, zeropower_via_newtonschulz5 from .opt_utils import ( AsyncRuntime, @@ -27,15 +26,7 @@ ) -def _full_dtype_and_shape(p: Tensor) -> Tuple[torch.Size, torch.dtype, torch.device]: - if isinstance(p, DTensor): - shape = p.size() # global shape - dev = p.to_local().device - return shape, p.dtype, dev - return p.size(), p.dtype, p.device - - -class Dion2Old(Optimizer): +class Dion2(Optimizer): """ Distributed Dion2 optimizer for PyTorch FSDP2. Also compatible with DDP. @@ -43,23 +34,26 @@ class Dion2Old(Optimizer): params: Parameters for the optimizer. distributed_mesh: DeviceMesh or ProcessGroup for distributed training. Use DeviceMesh for FSDP2 and ProcessGroup for DistributedDataParallel. - lr: Base learning rate. For dion2, this will be scaled based on the matrix dimensions. + lr: Base learning rate. For Muon, this will be scaled based on the matrix dimensions. For element-wise update rules, this is the actual learning rate and no additional scaling is done. - fraction: Fraction of rows/columns to orthogonalize per update (0 < fraction <= 1). - ef_decay: Error-feedback decay factor for dion2 algorithm. + fraction: Fraction of submatrix to orthogonalize per update (0 < fraction <= 1). + ef_decay: Error-feedback decay factor applied to selected submatrix. 
betas: Tuple of (beta1, beta2) for AdamW and Lion algorithms. weight_decay: Weight decay factor. epsilon: Small value to avoid division by zero. - adjust_lr: How to adjust the learning rate for dion2 updates ("spectral_norm" or "rms_norm" or None). + adjust_lr: How to adjust the learning rate for Muon updates ("spectral_norm" or "rms_norm" or None). "spectral_norm": Adjust based on spectral norm, for learning rate transfer across model scale. "rms_norm": Adjust based on RMS norm, for learning rate compatibility with Adam/AdamW. None: Do not adjust the learning rate. - flatten: Whether to flatten 3D+ tensors to 2D for dion2 updates. + flatten: Whether to flatten 3D+ tensors to 2D for Muon updates. True: Tensors with 3+ dimensions are flattened to 2D. Use this for convolutional layers. False: Tensors are not flattened. 3D+ tensors are treated as batches of 2D matrices. use_triton: Whether to use Triton kernel for Newton-Schulz. Ignored if custom function is provided. newton_schulz_func: Use a custom Newton-Schulz function for orthogonalization. Signature is `func(input: Tensor, epsilon: float) -> Tensor`. + verbose: Whether to print debug information during updates. If True, it prints whether rows or columns are selected for the submatrix selection process. 
+ + Dion2 optimizer by Ahn et al.: TBD """ def __init__( @@ -76,14 +70,15 @@ def __init__( flatten: bool = False, use_triton: bool = False, newton_schulz_func: Optional[Callable] = None, + verbose: bool = False, ): - # Chenk hyperparameter + # Validate hyperparameters if lr < 0.0: raise ValueError(f"Invalid learning rate: {lr}") if not (0.0 < fraction <= 1.0): raise ValueError(f"fraction must be in (0, 1], got {fraction}") if ef_decay < 0.0: - raise ValueError(f"Invalid error-feedback decay (ef_decay): {ef_decay}") + raise ValueError(f"Invalid ef_decay: {ef_decay}") if len(betas) != 2 or betas[0] < 0.0 or betas[1] < 0.0: raise ValueError(f"Invalid betas: {betas}") if adjust_lr not in ("spectral_norm", "rms_norm", None): @@ -94,6 +89,7 @@ def __init__( defaults = dict( lr=lr, ef_decay=ef_decay, + fraction=fraction, beta1=betas[0], beta2=betas[1], weight_decay=weight_decay, @@ -102,7 +98,6 @@ def __init__( adjust_lr=adjust_lr, algorithm="dion2", step=0, - fraction=fraction, ) super().__init__(params, defaults) @@ -110,7 +105,7 @@ def __init__( if isinstance(distributed_mesh, DeviceMesh): if distributed_mesh.ndim != 1: raise ValueError( - f"Only 1D DeviceMesh is supported, but got {distributed_mesh.ndim}D. For HSDP, provide the 1D sharded sub-mesh." + f"Only 1D DeviceMesh supported, but got {distributed_mesh.ndim}D. For HSDP, provide the 1D sharded sub-mesh." ) self._device_rank = distributed_mesh.get_local_rank() self._world_size = distributed_mesh.size() @@ -140,20 +135,24 @@ def __init__( self._newton_schulz_func = newton_schulz_triton else: self._newton_schulz_func = zeropower_via_newtonschulz5 + self.verbose = verbose @torch.no_grad() def step(self, closure=None): + """ + Perform a single optimization step. 
+ """ loss = None if closure is not None: with torch.enable_grad(): loss = closure() - # Group by optimizers dion2_groups = [] lion_groups = [] adamw_groups = [] for group in self.param_groups: + # Increment step group["step"] += 1 # Split parameter groups by algorithm @@ -168,7 +167,7 @@ def step(self, closure=None): raise ValueError(f"Unknown algorithm: {algo}") # Create async tasks for each algorithm - dion2_tasks = self._create_dion2_tasks(dion2_groups) + dion2_tasks = self._create_dion2_tasks(dion2_groups, verbose=self.verbose) lion_tasks = self._create_lion_tasks(lion_groups) adamw_tasks = self._create_adamw_tasks(adamw_groups) @@ -178,26 +177,6 @@ def step(self, closure=None): return loss - def _get_or_initialize_dion2_state_layer(self, param: Tensor) -> dict: - """ - Layer-sharded momentum state for dion2: - - 'momentum_full' lives only on the owner rank (owner is implicitly device_rank). - """ - st = self.state[param] - if "momentum_full" not in st: - st["momentum_full"] = None - return st - - def _get_or_initialize_dion2_state_local(self, param: Tensor) -> dict: - """ - Local-shard momentum state for dion2: - - Each rank keeps 'momentum_local' matching its local shard shape. - """ - st = self.state[param] - if "momentum_local" not in st: - st["momentum_local"] = torch.zeros_like(param) - return st - def _get_or_initialize_state(self, param: Tensor, algo: str) -> dict: """ Get optimizer state for the given parameter tensor, @@ -210,44 +189,32 @@ def _get_or_initialize_state(self, param: Tensor, algo: str) -> dict: state["variance"] = torch.zeros_like(param) return state - def _pad_states(self, states: List[dict], n: int) -> List[dict]: - """ - Pad states to length n. Real entries get is_pad=False; padded entries get is_pad=True. 
- """ - out = list(states) - # Mark existing entries explicitly as not padded - for st in out: - if "is_pad" not in st: - st["is_pad"] = False - # Append padded placeholders - while len(out) < n: - out.append({"momentum_full": None, "is_pad": True}) - return out - def _create_dion2_tasks( self, param_groups: List[dict], - algo_name: str = "dion2", + verbose: bool = False, ) -> Generator["AsyncTask", None, None]: """ - Helper function to create batches of matrices and generate + Helper function to create batches of Dion2 matrices and generate AsyncTask objects so we can process multiple batches concurrently. """ for group in param_groups: - assert group["algorithm"] == algo_name + assert group["algorithm"] == "dion2" assert all( p.ndim >= 2 for p in group["params"] - ), "dion2 optimizer only supports matrix parameters." + ), "Dion2 only supports matrix parameters." - params = [p for p in group["params"] if p.grad is not None] - if not params: + group_params = [p for p in group["params"] if p.grad is not None] + if not group_params: continue - # Wrap hyperparameters as tensors for torch.compile + # Most hyperparameters as tensors for torch.compile + # Here "fraction" only determines the dimension of the submatrix + # to be orthonormalized. 
Hence, it doesn't need to be a tensor dion2_args = dict( lr=torch.tensor(group["lr"]), ef_decay=torch.tensor(group["ef_decay"]), - fraction=torch.tensor(group["fraction"]), + fraction=group["fraction"], weight_decay=torch.tensor(group["weight_decay"]), epsilon=torch.tensor(group["epsilon"]), flatten=group["flatten"], @@ -259,10 +226,12 @@ def _create_dion2_tasks( ) # Create batches of parameters of size self._world_size - for batch_params in create_param_batches( - params, batch_size=self._world_size + for params in create_param_batches( + group_params, batch_size=self._world_size ): - grads = [p.grad for p in batch_params] + gradients = [p.grad for p in params] + states = [self._get_or_initialize_state(p, "dion2") for p in params] + momentums = [s["momentum"] for s in states] # Get sharding state for DTensor is_batch_sharded = False @@ -270,7 +239,7 @@ def _create_dion2_tasks( sharded_mesh_dim = None sharded_tensor_dim = None - if isinstance(batch_params[0], DTensor): + if isinstance(params[0], DTensor): if not isinstance(self._distributed_mesh, DeviceMesh): raise RuntimeError( "Must create optimizer with DeviceMesh if using DTensor parameters." 
@@ -279,25 +248,20 @@ def _create_dion2_tasks( # Find the sharded placement and get its mesh and tensor dimensions # Skip any Shard() placements on size-1 mesh dimension = Replicate() shard_placements = [ - (i, pl) - for i, pl in enumerate(batch_params[0].placements) - if pl.is_shard() and batch_params[0].device_mesh.size(i) > 1 + (i, p) + for i, p in enumerate(params[0].placements) + if p.is_shard() and params[0].device_mesh.size(i) > 1 ] # If we don't flatten 3D matrices, we can ignore shard placements along batch dimensions # Only keep placements that shard one of the two matrix dimensions if not group["flatten"]: - matrix_dims = { - batch_params[0].ndim - 1, - batch_params[0].ndim - 2, - } + matrix_dims = {params[0].ndim - 1, params[0].ndim - 2} is_batch_sharded = any( - pl.dim not in matrix_dims for _, pl in shard_placements + p.dim not in matrix_dims for _, p in shard_placements ) shard_placements = [ - (i, pl) - for i, pl in shard_placements - if pl.dim in matrix_dims + (i, p) for i, p in shard_placements if p.dim in matrix_dims ] # Check that we have no more than 1 sharded matrix dimension @@ -309,13 +273,13 @@ def _create_dion2_tasks( sharded_tensor_dim = shard_placements[0][1].dim elif len(shard_placements) > 1: raise NotImplementedError( - "dion2 does not support parameters with multiple sharded dimensions." + "Dion2 does not support parameters with multiple sharded dimensions." 
) # Check that the sharded mesh dimension matches optimizer's device mesh if ( sharded_mesh_dim is not None - and batch_params[0].device_mesh.get_group(sharded_mesh_dim) + and params[0].device_mesh.get_group(sharded_mesh_dim) != self._process_group ): raise RuntimeError( @@ -327,37 +291,29 @@ def _create_dion2_tasks( # As long as matrix dimensions are not sharded, each device will have whole matrices # Each device already has different matrices of the batch, so we can't parallelize further if is_batch_sharded and not is_matrix_sharded: - - # For this case, we use local momentum per shard - for x, g in zip(batch_params, grads): - st = self._get_or_initialize_dion2_state_local(x) - - # Create task for non-communicating local update + for x, g, m in zip(params, gradients, momentums): yield AsyncTask( - dion2_update_local_async( + dion2_update_batch_async( X=[x], G=[g], - STATE=st, + M=[m], + shard_dim=None, # No sharded matrix dim **dion2_args, + verbose=verbose, ) ) - continue - - # Otherwise we use layer-sharded momentum and owner mapping - states = [ - self._get_or_initialize_dion2_state_layer(p) for p in batch_params - ] - - # Create task for communicating batch update - yield AsyncTask( - dion2_update_batch_async( - X=pad_batch(batch_params, self._world_size), - G=pad_batch(grads, self._world_size), - STATES=self._pad_states(states, self._world_size), - shard_dim=sharded_tensor_dim, - **dion2_args, + # Otherwise, we parallelize the Muon update across devices + else: + yield AsyncTask( + dion2_update_batch_async( + X=pad_batch(params, self._world_size), + G=pad_batch(gradients, self._world_size), + M=pad_batch(momentums, self._world_size), + shard_dim=sharded_tensor_dim, + **dion2_args, + verbose=verbose, + ) ) - ) def _create_lion_tasks( self, @@ -367,10 +323,6 @@ def _create_lion_tasks( """ Helper function to generate AsyncTask objects for Lion updates. 
""" - # Check whether algo_name matches "lion" - if algo_name != "lion": - raise RuntimeError(f"lion is applied to {algo_name} groups") - for group in param_groups: assert group["algorithm"] == algo_name @@ -408,10 +360,6 @@ def _create_adamw_tasks( """ Helper function to generate AsyncTask objects for AdamW updates. """ - # Check whether algo_name matches "adamw" - if algo_name != "adamw": - raise RuntimeError(f"adamw is applied to {algo_name} groups") - for group in param_groups: assert group["algorithm"] == algo_name @@ -448,317 +396,291 @@ def _create_adamw_tasks( ) -def dion2_update_local_async( - X: List[Tensor], - G: List[Tensor], - STATE: dict, # Should put local momentum state here - lr: Tensor, - ef_decay: Tensor, - fraction: Tensor, - weight_decay: Tensor, - epsilon: Tensor, - flatten: bool, - adjust_lr: Optional[str], - newton_schulz_func: Optional[Callable] = None, -) -> Generator[None, None, None]: - assert len(X) == len(G) == 1 - x = X[0] - g = to_local(G)[0] # local shard grad - M = STATE["momentum_local"] # local shard momentum - - if adjust_lr is None: - adjusted_lr = lr - elif adjust_lr == "spectral_norm": - adjusted_lr = adjust_lr_spectral_norm(lr, x.shape, flatten=flatten) - elif adjust_lr == "rms_norm": - adjusted_lr = adjust_lr_rms_norm(lr, x.shape, flatten=flatten) - else: - raise ValueError(f"Unknown adjust_lr value: {adjust_lr}") - - # Error feedback on local shard and orthonormalize fraction - M.add_(g.to(dtype=M.dtype)) - O_local = fractional_orthonormalize_update( - M_full=M, - fraction=float(fraction), - ef_decay=ef_decay, - flatten=flatten, - epsilon=epsilon, - newton_schulz_func=newton_schulz_func, - ) - - # Apply update locally - dion2_update_post_orthogonalize( - X=to_local([x]), - U=[O_local], - base_lr=lr, - adjusted_lr=adjusted_lr, - weight_decay=weight_decay, - ) - yield - - def dion2_update_batch_async( - X: List[Tensor], # DTensors or local Tensors - G: List[Tensor], # local shards (regular Tensors) - STATES: List[dict], # 
layer-sharded optimizer states (each has 'momentum_full') - lr: Tensor, - ef_decay: Tensor, - fraction: Tensor, - weight_decay: Tensor, - epsilon: Tensor, - flatten: bool, - adjust_lr: Optional[str], - device_rank: int, - world_size: int, - shard_dim: Optional[int] = None, + X: List[Tensor], # Model weights (modified in place) + G: List[Tensor], # Gradient + M: List[Tensor], # Momentum buffer (modified in place) + lr: Tensor, # Learning rate (scalar tensor) + ef_decay: Tensor, # Error-feedback factor (scalar tensor) + fraction: float, # Fraction of submatrix to orthogonalize (0 < fraction <= 1) + weight_decay: Tensor, # Weight decay (scalar tensor) + epsilon: Tensor, # Epsilon (scalar tensor) + flatten: bool, # Whether to flatten 3D+ tensors to 2D + adjust_lr: Optional[str], # How to adjust learning rate + device_rank: int, # Rank of the current device + world_size: int, # Total number of devices to parallelize over + shard_dim: Optional[int] = None, # Shard dimension for DTensor (if applicable) process_group: Optional[ProcessGroup] = None, newton_schulz_func: Optional[Callable] = None, + verbose: bool = False, ) -> Generator[None, None, None]: """ - dion2 layer-sharded path: - - Matrix-dim sharded: all_to_all shards <-> full, compute once on owner, all_to_all back. - - Replicated/unsharded: compute once on owner (index=device_rank), all_gather dense updates. - - Single-GPU (batch size=1): compute once on owner, apply locally. + Batched version of Dion2 update. Batch size should be equal to number of GPUs. + All tensors in a batch should have identical shape, sharding, and dtype. + Identical hyperparameters are used for all tensors in the batch. """ + assert len(X) == len(G) + assert len(X) == len(M) - # Ownership-by-index: owner of batch index j is rank j. 
- assert 0 <= device_rank < world_size - assert len(X) == len(STATES) == world_size # Guaranteed by padding - - # convert gradients to local shards - G_local = to_local(G) - frac = float(fraction) + # Determine selection dimension based on sharding and tensor shape: + # For sharded matrices, we align select_dim with shard_dim + # For unsharded matrices (DDP or single-GPU), we select the shorter dimension + ndim = X[0].ndim + select_dim = None - # Compute adjusted lr from global shape (matrix dims last) - if adjust_lr is None: - adjusted_lr = lr - elif adjust_lr == "spectral_norm": - adjusted_lr = adjust_lr_spectral_norm(lr, X[0].shape, flatten=flatten) - elif adjust_lr == "rms_norm": - adjusted_lr = adjust_lr_rms_norm(lr, X[0].shape, flatten=flatten) - else: - raise ValueError(f"Unknown adjust_lr value: {adjust_lr}") + if shard_dim is not None: + # Normalize shard_dim to negative indexing for unified treatment + shard_dim = shard_dim if shard_dim < 0 else shard_dim - ndim + if shard_dim == -2: + select_dim = -2 # Row-sharded + elif shard_dim == -1: + select_dim = -1 # Column-sharded + + # Fall-back to shorter dimension when DDP, Single-GPU, or batch-sharded + if select_dim is None: + num_rows, num_cols = X[0].shape[-2:] + select_dim = -2 if num_rows <= num_cols else -1 + + # Print how the selection choice based on shard_dim and tensor shape + if verbose: + _print_selection_choice(X[0].shape, shard_dim, select_dim, ndim) + + # Update momentum and select top-α fraction along select_dim + U_selected, indices_list = dion2_pre_orthogonalize( + G=to_local(G), + M=to_local(M), + fraction=fraction, + ef_decay=ef_decay, + select_dim=select_dim, + ) - # Matrix-dimension sharded path + # Get one whole matrix for each device to orthogonalize if shard_dim is not None: + # Use all-to-all to transform from a batch of shards to a single whole matrix + # https://www.essential.ai/blog/infra assert len(X) == world_size, "Batch size must equal world size" assert ( process_group is not 
None ), "process_group must be provided for sharded DTensors" - assert isinstance(X[0], DTensor), "X should contain DTensors" - assert X[0].size(shard_dim) % world_size == 0, ( - f"Shard dimension {shard_dim} size {X[0].size(shard_dim)} " - f"is not divisible by world size {world_size}." - ) + assert isinstance(X[0], DTensor), "X should contain DTensors" + assert ( + X[0].size(shard_dim) % world_size == 0 + ), f"Shard dimension {shard_dim} size {X[0].size(shard_dim)} is not divisible by world size {world_size}." - # Shards -> single full gradient on owner (bf16 comm) - G_bf16 = [g.to(dtype=torch.bfloat16) for g in G_local] - recv_shards = [torch.empty_like(g) for g in G_bf16] - work = dist.all_to_all(recv_shards, G_bf16, group=process_group, async_op=True) + # Allocate buffers to receive shards of one whole submatrix from other devices + recv_shards = [torch.empty_like(u) for u in U_selected] + work = dist.all_to_all( + recv_shards, U_selected, group=process_group, async_op=True + ) yield work.wait() - full_grad_bf16 = torch.cat(recv_shards, dim=shard_dim) + # Concatentate shards to form a whole matrix to orthogonalize + # Only submatrix is orthogonalized! + full_submatrix = torch.cat(recv_shards, dim=select_dim) + full_submatrix = muon_update_newton_schulz( + full_submatrix, newton_schulz_func, flatten=flatten, epsilon=epsilon + ) - # Ownership-by-index contract: - # For layer-sharded path, the "owner" of batch index j is rank j (device_rank). - owner_state = STATES[device_rank] - owner_is_pad = owner_state.get("is_pad", False) + # Split result back into shards + # Contiguous is needed for all-to-all to work correctly + send_shards = [ + t.contiguous() + for t in torch.tensor_split(full_submatrix, world_size, dim=select_dim) + ] - # Build the shards to send (bf16), but only do actual work if not "pad" - if owner_is_pad: - # For pads, do NOT allocate momentum_full or run NS. - # Just return zero shards. 
- send_shards = [torch.zeros_like(g) for g in G_bf16] # bf16 payloads - else: - # Non-pads: allocate/accumulate, run NS, split to bf16 shards - if owner_state["momentum_full"] is None: - full_shape, param_dtype, param_device = _full_dtype_and_shape( - X[device_rank] - ) - owner_state["momentum_full"] = torch.zeros( - full_shape, dtype=param_dtype, device=param_device - ) - M_full = owner_state["momentum_full"] - full_grad = full_grad_bf16.to(dtype=M_full.dtype) - M_full.add_(full_grad) - - O_full = fractional_orthonormalize_update( - M_full=M_full, - fraction=frac, - ef_decay=ef_decay, - flatten=flatten, - epsilon=epsilon, - newton_schulz_func=newton_schulz_func, - ) - # Split back to shards - send_shards = [ - t.contiguous().to(torch.bfloat16) - for t in torch.tensor_split(O_full, world_size, dim=shard_dim) - ] - - # All-to-all back to shards - U = [torch.empty_like(g) for g in G_bf16] - work = dist.all_to_all(U, send_shards, group=process_group, async_op=True) + # Redistribute the orthogonalized tensor back to original layout + U_ortho = [torch.empty_like(u) for u in U_selected] + work = dist.all_to_all(U_ortho, send_shards, group=process_group, async_op=True) yield work.wait() - # Replicated / unsharded - else: - # Owner index is device_rank - x_owner = X[device_rank] - g_owner = G_local[device_rank] - st_owner = STATES[device_rank] - owner_is_pad = st_owner.get("is_pad", False) - - # Check whether we are in multi-GPU setting - multi_gpu = process_group is not None and world_size > 1 - - if multi_gpu: - # For pads, do not allocate momentum_full or run NS. 
- if owner_is_pad: - payload_bf16 = torch.zeros_like( - g_owner, dtype=torch.bfloat16 - ).contiguous() - # Non-pads: allocate/accumulate, run NS, prepare bf16 payload - else: - if st_owner["momentum_full"] is None: - full_shape, param_dtype, param_device = _full_dtype_and_shape( - x_owner - ) - st_owner["momentum_full"] = torch.zeros( - full_shape, dtype=param_dtype, device=param_device - ) - M_full = st_owner["momentum_full"] - M_full.add_(g_owner.to(dtype=M_full.dtype)) - - O_full = fractional_orthonormalize_update( - M_full=M_full, - fraction=frac, - ef_decay=ef_decay, - flatten=flatten, - epsilon=epsilon, - newton_schulz_func=newton_schulz_func, - ) - payload_bf16 = O_full.to(dtype=torch.bfloat16).contiguous() + # Matrices are not sharded, so we can distribute the batch across different devices + # Get a single matrix of the batch corresponding to this device + elif len(U_selected) > 1: + assert len(U_selected) == world_size, "Batch size must equal world size" + assert process_group is not None + + single_matrix = U_selected[device_rank] + assert not isinstance(single_matrix, DTensor) + + single_ortho = muon_update_newton_schulz( + single_matrix, + newton_schulz_func, + flatten=flatten, + epsilon=epsilon, + ) - # All-gather the computed updates - U = [torch.empty_like(payload_bf16) for _ in range(world_size)] - work = dist.all_gather(U, payload_bf16, group=process_group, async_op=True) - yield - work.wait() + # Allocate empty tensors to receive updates from other devices + U_ortho = [torch.empty_like(u) for u in U_selected] + # All gather orthogonalized results from other devices into buffer + work = dist.all_gather( + U_ortho, single_ortho.contiguous(), group=process_group, async_op=True + ) + yield + work.wait() - else: - # Single-GPU case: produce local update directly. - # No padding case handling required. 
- if st_owner["momentum_full"] is None: - full_shape, param_dtype, param_device = _full_dtype_and_shape(x_owner) - st_owner["momentum_full"] = torch.zeros( - full_shape, dtype=param_dtype, device=param_device - ) - M_full = st_owner["momentum_full"] - M_full.add_(g_owner.to(dtype=M_full.dtype)) - - O_full = fractional_orthonormalize_update( - M_full=M_full, - fraction=frac, - ef_decay=ef_decay, - flatten=flatten, - epsilon=epsilon, - newton_schulz_func=newton_schulz_func, + # Single tensor with no sharded dimension. This happens in 2 cases: + # - Running on a single GPU + # - 3D+ tensors sharded along a batch dimension (different whole matrices per device) + else: + assert len(U_selected) == 1 + U_ortho = [ + muon_update_newton_schulz( + U_selected[0], newton_schulz_func, flatten=flatten, epsilon=epsilon ) - U = [O_full] - - # Ensure foreach dtypes match parameter shards for the update - X_local = to_local(X) - U = [u.to(dtype=xi.dtype) for u, xi in zip(U, X_local)] + ] - dion2_update_post_orthogonalize( - X=X_local, - U=U, + # Compute scaled learning rate + # Do this before to_local(X) because we use the full tensor shape, not the shard shape + if adjust_lr is None: + adjusted_lr = lr + elif adjust_lr == "spectral_norm": + adjusted_lr = adjust_lr_spectral_norm(lr, X[0].shape, flatten=flatten) + elif adjust_lr == "rms_norm": + adjusted_lr = adjust_lr_rms_norm(lr, X[0].shape, flatten=flatten) + else: + raise ValueError(f"Unknown adjust_lr: {adjust_lr}") + + # Update model parameters with orthogonalized output + # Weight update is applied to selected slices only + dion2_post_orthogonalize( + X=to_local(X), + U=U_ortho, + indices=indices_list, base_lr=lr, adjusted_lr=adjusted_lr, weight_decay=weight_decay, + select_dim=select_dim, ) -def make_work_view(M: Tensor) -> Tuple[Tensor, bool]: - I, J = M.size(-2), M.size(-1) - if I < J: - return M.mT, True - return M, False +@torch.compile(fullgraph=True) +def dion2_pre_orthogonalize( + G: List[Tensor], + M: List[Tensor], + 
fraction: Tensor, + ef_decay: Tensor, + select_dim: int, +) -> Tuple[List[Tensor], List[Tensor]]: + """ + Update momentum with gradient and compute the input to orthogonalization. + More specifically, it does the following steps: + - updates the momentum with gradient + - computes the top-k indices to determine submatrices + - does in-place error-feedback decay on the selected submatrices + - output submatrices and indices + Inputs and outputs should be lists of regular Tensor, not DTensor. + This is a separate function for compatibility with torch.compile(). + """ + dtype = M[0].dtype + # norm_dim is the dimension we compute norm over + # select_dim is the dimension we select submatrix from + num_select = M[0].size(select_dim) + norm_dim = -1 if select_dim == -2 else -2 + k = max(1, int(math.ceil(fraction * num_select))) -def fractional_orthonormalize_update( - M_full: Tensor, - fraction: float, - ef_decay: Tensor, - flatten: bool, - epsilon: Tensor, - newton_schulz_func: Callable, -) -> Tensor: - M_work, transposed = make_work_view(M_full) - I, J = M_work.size(-2), M_work.size(-1) - if fraction >= 1.0: - # Full orthonormalization - ortho_update = muon_update_newton_schulz( - M_work, newton_schulz_func, flatten=flatten, epsilon=epsilon - ) - M_work.mul_(ef_decay) + # Update momentum: M = M + G + G = [g.to(dtype=dtype) for g in G] + torch._foreach_add_(M, G) + + M_stacked = torch.stack(M, dim=0) + + # Compute L1 norm along norm_dim (sum of absolute values) + slice_norms = M_stacked.norm(p=1, dim=norm_dim) + + # Batched topk: indices shape (batch_size, k) + _, indices = torch.topk(slice_norms, k, dim=-1, sorted=False) + + # Batched gather for slice extraction + if select_dim == -2: + # Selecting rows + num_cols = M[0].size(-1) + indices_expanded = indices.unsqueeze(-1).expand(-1, -1, num_cols) + selected_stacked = torch.gather(M_stacked, dim=-2, index=indices_expanded) else: - # Fractional orthonormalization - k = int(math.ceil(fraction * J)) - ortho_update = 
topk_and_orthonormalize( - M_work, - ef_decay=ef_decay, - k=k, - flatten=flatten, - epsilon=epsilon, - newton_schulz_func=newton_schulz_func, - ) - return ortho_update.mT.contiguous() if transposed else ortho_update + # Selecting cols + num_rows = M[0].size(-2) + indices_expanded = indices.unsqueeze(-2).expand(-1, num_rows, -1) + selected_stacked = torch.gather(M_stacked, dim=-1, index=indices_expanded) + # Apply error feedback decay to selected slices in original M tensors + indices_list = list(indices.unbind(dim=0)) + for m, idx in zip(M, indices_list): + selected_slice = m.index_select(dim=select_dim, index=idx) + m.index_copy_(dim=select_dim, index=idx, source=selected_slice * ef_decay) -def topk_and_orthonormalize( - M_work: Tensor, - ef_decay: Tensor, - k: int, - flatten: bool, - epsilon: Tensor, - newton_schulz_func, -) -> Tensor: - """ """ - # Compute the top-k columns by L1 norm - alpha = M_work.abs().sum(dim=-2) # [J] - K = torch.topk(alpha, k, sorted=False).indices # [k] - # Select and orthonormalize - M_sel = torch.index_select(M_work, dim=-1, index=K) # [I, k] - O_sel = muon_update_newton_schulz( - M_sel, newton_schulz_func, flatten=flatten, epsilon=epsilon - ) - # In-place error-feedback decay only on selected columns: - M_work[..., K] *= ef_decay - # Construct the full update matrix - O_full = torch.zeros_like(M_work, dtype=O_sel.dtype) - O_full.index_copy_(dim=-1, index=K, source=O_sel) - return O_full + # Convert to bf16 and unstack for communication + U_selected = list(selected_stacked.to(dtype=torch.bfloat16).unbind(dim=0)) + + return U_selected, indices_list -def dion2_update_post_orthogonalize( +@torch.compile(fullgraph=True) +def dion2_post_orthogonalize( X: List[Tensor], U: List[Tensor], + indices: List[Tensor], base_lr: Tensor, adjusted_lr: Tensor, weight_decay: Tensor, + select_dim: int, ): """ Apply weight decay and weight update after orthogonalization. Inputs and outputs should be lists of regular Tensor, not DTensor. 
This is a separate function for compatibility with torch.compile(). """ - # Apply weight decay torch._foreach_mul_(X, 1 - base_lr * weight_decay) - # Weight update - U = torch._foreach_mul(U, adjusted_lr) - torch._foreach_sub_(X, U) + # Convert U to match parameter dtype + dtype = X[0].dtype + U = [u.to(dtype=dtype) for u in U] + # Apply weight update + neg_lr = -adjusted_lr + U_scaled = [neg_lr * u for u in U] + for x, u_scaled, idx in zip(X, U_scaled, indices): + x.index_add_(dim=select_dim, index=idx, source=u_scaled) + + +# A helper function to print selection chocie for each matrix +# It only prints once `verbose` is set True +_printed_configs: set = set() + +def _print_selection_choice( + shape: torch.Size, + shard_dim: Optional[int], + select_dim: int, + ndim: int, +): + config_key = (tuple(shape), shard_dim, select_dim) + if config_key not in _printed_configs: + _printed_configs.add(config_key) + + num_rows, num_cols = shape[-2:] + select_info = "rows" if select_dim == -2 else "columns" + norm_info = "row norms" if select_dim == -2 else "col norms" + + if shard_dim is None: + mode = "DDP/Single-GPU" + shorter = "rows" if num_rows <= num_cols else "cols" + reason = f"shorter dim = {shorter} ({min(num_rows, num_cols)})" + else: + # Normalize shard_dim for display + normalized = shard_dim if shard_dim < 0 else shard_dim - ndim + if normalized == -2: + mode = "FSDP" + reason = f"row-sharded (shard_dim={shard_dim}→-2)" + elif normalized == -1: + mode = "FSDP" + reason = f"col-sharded (shard_dim={shard_dim}→-1)" + else: + mode = "FSDP batch-sharded" + shorter = "rows" if num_rows <= num_cols else "cols" + reason = f"shard_dim={shard_dim} (batch), shorter = {shorter}" + + print( + f"[Dion2] Shape {tuple(shape)}: {mode}, {reason} → " + f"select top-α {select_info} by {norm_info}" + ) diff --git a/dion/dion2_new.py b/dion/dion2_new.py deleted file mode 100644 index 39a6567..0000000 --- a/dion/dion2_new.py +++ /dev/null @@ -1,713 +0,0 @@ -import math -import torch 
-import torch.distributed as dist -from itertools import chain -from torch import Tensor -from torch.distributed import ProcessGroup -from torch.distributed.tensor import DeviceMesh, DTensor -from torch.optim.optimizer import Optimizer, ParamsT -from typing import Callable, Generator, List, Optional, Tuple, Union - -from .newton_schulz_triton import newton_schulz_triton, zeropower_via_newtonschulz5 -from .opt_utils import ( - AsyncRuntime, - AsyncTask, - create_param_batches, - pad_batch, - to_local, -) -from .scalar_opts import adamw_update_foreach_async, lion_update_foreach_async - -# Reuse Muon's helper functions -from .muon import ( - muon_update_newton_schulz, - adjust_lr_spectral_norm, - adjust_lr_rms_norm, -) - - -class Dion2(Optimizer): - """ - Distributed Dion2 optimizer for PyTorch FSDP2. Also compatible with DDP. - - Args: - params: Parameters for the optimizer. - distributed_mesh: DeviceMesh or ProcessGroup for distributed training. - Use DeviceMesh for FSDP2 and ProcessGroup for DistributedDataParallel. - lr: Base learning rate. For Muon, this will be scaled based on the matrix dimensions. - For element-wise update rules, this is the actual learning rate and no additional scaling is done. - fraction: Fraction of submatrix to orthogonalize per update (0 < fraction <= 1). - ef_decay: Error-feedback decay factor applied to selected submatrix. - betas: Tuple of (beta1, beta2) for AdamW and Lion algorithms. - weight_decay: Weight decay factor. - epsilon: Small value to avoid division by zero. - adjust_lr: How to adjust the learning rate for Muon updates ("spectral_norm" or "rms_norm" or None). - "spectral_norm": Adjust based on spectral norm, for learning rate transfer across model scale. - "rms_norm": Adjust based on RMS norm, for learning rate compatibility with Adam/AdamW. - None: Do not adjust the learning rate. - flatten: Whether to flatten 3D+ tensors to 2D for Muon updates. - True: Tensors with 3+ dimensions are flattened to 2D. 
Use this for convolutional layers. - False: Tensors are not flattened. 3D+ tensors are treated as batches of 2D matrices. - use_triton: Whether to use Triton kernel for Newton-Schulz. Ignored if custom function is provided. - newton_schulz_func: Use a custom Newton-Schulz function for orthogonalization. - Signature is `func(input: Tensor, epsilon: float) -> Tensor`. - verbose: Whether to print debug information during updates. If True, it prints whether rows or columns are selected for the submatrix selection process. - - Dion2 optimizer by Ahn et al.: TBD - """ - - def __init__( - self, - params: ParamsT, - distributed_mesh: Optional[Union[DeviceMesh, ProcessGroup]] = None, - lr: float = 0.01, - fraction: float = 0.25, - ef_decay: float = 0.95, - betas: Tuple[float, float] = (0.9, 0.95), - weight_decay: float = 0.01, - epsilon: float = 1e-8, - adjust_lr: Optional[str] = "spectral_norm", - flatten: bool = False, - use_triton: bool = False, - newton_schulz_func: Optional[Callable] = None, - verbose: bool = False, - ): - # Validate hyperparameters - if lr < 0.0: - raise ValueError(f"Invalid learning rate: {lr}") - if not (0.0 < fraction <= 1.0): - raise ValueError(f"fraction must be in (0, 1], got {fraction}") - if ef_decay < 0.0: - raise ValueError(f"Invalid ef_decay: {ef_decay}") - if len(betas) != 2 or betas[0] < 0.0 or betas[1] < 0.0: - raise ValueError(f"Invalid betas: {betas}") - if adjust_lr not in ("spectral_norm", "rms_norm", None): - raise ValueError( - f"Invalid adjust_lr value: {adjust_lr}. Must be 'spectral_norm', 'rms_norm', or None." 
- ) - - defaults = dict( - lr=lr, - ef_decay=ef_decay, - fraction=fraction, - beta1=betas[0], - beta2=betas[1], - weight_decay=weight_decay, - epsilon=epsilon, - flatten=flatten, - adjust_lr=adjust_lr, - algorithm="dion2", - step=0, - ) - super().__init__(params, defaults) - - # Distributed configuration - if isinstance(distributed_mesh, DeviceMesh): - if distributed_mesh.ndim != 1: - raise ValueError( - f"Only 1D DeviceMesh supported, but got {distributed_mesh.ndim}D. For HSDP, provide the 1D sharded sub-mesh." - ) - self._device_rank = distributed_mesh.get_local_rank() - self._world_size = distributed_mesh.size() - self._process_group = distributed_mesh.get_group() - elif isinstance(distributed_mesh, ProcessGroup): - self._device_rank = dist.get_rank(distributed_mesh) - self._world_size = dist.get_world_size(distributed_mesh) - self._process_group = distributed_mesh - elif distributed_mesh is None: - self._device_rank = 0 - self._world_size = 1 - self._process_group = None - else: - raise TypeError( - f"Invalid distributed_mesh type: {type(distributed_mesh)}. Expected DeviceMesh or ProcessGroup." - ) - self._distributed_mesh = distributed_mesh - - # Newton-Schulz configuration - if newton_schulz_func is not None: - if not callable(newton_schulz_func): - raise TypeError( - f"newton_schulz_func must be a callable function, got {type(newton_schulz_func)}" - ) - self._newton_schulz_func = newton_schulz_func - elif use_triton: - self._newton_schulz_func = newton_schulz_triton - else: - self._newton_schulz_func = zeropower_via_newtonschulz5 - self.verbose = verbose - - @torch.no_grad() - def step(self, closure=None): - """ - Perform a single optimization step. 
- """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - dion2_groups = [] - lion_groups = [] - adamw_groups = [] - - for group in self.param_groups: - # Increment step - group["step"] += 1 - - # Split parameter groups by algorithm - algo = group["algorithm"] - if algo == "dion2": - dion2_groups.append(group) - elif algo == "lion": - lion_groups.append(group) - elif algo == "adamw": - adamw_groups.append(group) - else: - raise ValueError(f"Unknown algorithm: {algo}") - - # Create async tasks for each algorithm - dion2_tasks = self._create_dion2_tasks(dion2_groups, verbose=self.verbose) - lion_tasks = self._create_lion_tasks(lion_groups) - adamw_tasks = self._create_adamw_tasks(adamw_groups) - - all_tasks = chain(dion2_tasks, lion_tasks, adamw_tasks) - runtime = AsyncRuntime(all_tasks, max_concurrent_tasks=3) - runtime.run() - - return loss - - def _get_or_initialize_state(self, param: Tensor, algo: str) -> dict: - """ - Get optimizer state for the given parameter tensor, - or lazy-initialize it if it doesn't exist. - """ - state = self.state[param] - if not state: - state["momentum"] = torch.zeros_like(param) - if algo == "adamw": - state["variance"] = torch.zeros_like(param) - return state - - def _create_dion2_tasks( - self, - param_groups: List[dict], - verbose: bool = False, - ) -> Generator["AsyncTask", None, None]: - """ - Helper function to create batches of Dion2 matrices and generate - AsyncTask objects so we can process multiple batches concurrently. - """ - for group in param_groups: - assert group["algorithm"] == "dion2" - assert all( - p.ndim >= 2 for p in group["params"] - ), "Dion2 only supports matrix parameters." - - group_params = [p for p in group["params"] if p.grad is not None] - if not group_params: - continue - - # Most hyperparameters as tensors for torch.compile - # Here "fraction" only determines the dimension of the submatrix - # to be orthonormalized. 
Hence, it doesn't need to be a tensor - dion2_args = dict( - lr=torch.tensor(group["lr"]), - ef_decay=torch.tensor(group["ef_decay"]), - fraction=group["fraction"], - weight_decay=torch.tensor(group["weight_decay"]), - epsilon=torch.tensor(group["epsilon"]), - flatten=group["flatten"], - adjust_lr=group["adjust_lr"], - device_rank=self._device_rank, - world_size=self._world_size, - process_group=self._process_group, - newton_schulz_func=self._newton_schulz_func, - ) - - # Create batches of parameters of size self._world_size - for params in create_param_batches( - group_params, batch_size=self._world_size - ): - gradients = [p.grad for p in params] - states = [self._get_or_initialize_state(p, "dion2") for p in params] - momentums = [s["momentum"] for s in states] - - # Get sharding state for DTensor - is_batch_sharded = False - is_matrix_sharded = False - sharded_mesh_dim = None - sharded_tensor_dim = None - - if isinstance(params[0], DTensor): - if not isinstance(self._distributed_mesh, DeviceMesh): - raise RuntimeError( - "Must create optimizer with DeviceMesh if using DTensor parameters." 
- ) - - # Find the sharded placement and get its mesh and tensor dimensions - # Skip any Shard() placements on size-1 mesh dimension = Replicate() - shard_placements = [ - (i, p) - for i, p in enumerate(params[0].placements) - if p.is_shard() and params[0].device_mesh.size(i) > 1 - ] - - # If we don't flatten 3D matrices, we can ignore shard placements along batch dimensions - # Only keep placements that shard one of the two matrix dimensions - if not group["flatten"]: - matrix_dims = {params[0].ndim - 1, params[0].ndim - 2} - is_batch_sharded = any( - p.dim not in matrix_dims for _, p in shard_placements - ) - shard_placements = [ - (i, p) for i, p in shard_placements if p.dim in matrix_dims - ] - - # Check that we have no more than 1 sharded matrix dimension - # Note that non-flattened 3D tensors can have additional sharded batch dimensions - # Flattened 3D tensors are limited to one sharded dimension out of all dimensions - if len(shard_placements) == 1: - is_matrix_sharded = True - sharded_mesh_dim = shard_placements[0][0] - sharded_tensor_dim = shard_placements[0][1].dim - elif len(shard_placements) > 1: - raise NotImplementedError( - "Dion2 does not support parameters with multiple sharded dimensions." - ) - - # Check that the sharded mesh dimension matches optimizer's device mesh - if ( - sharded_mesh_dim is not None - and params[0].device_mesh.get_group(sharded_mesh_dim) - != self._process_group - ): - raise RuntimeError( - f"Got DTensor sharded over mesh dimension {sharded_mesh_dim} different from the optimizer's device mesh. " - f"DTensor has mesh: {params[0].device_mesh}, placements: {params[0].placements}, but optimizer was created with mesh: {self._distributed_mesh}." 
- ) - - # Special case for 3D tensors sharded along batch dimension - # As long as matrix dimensions are not sharded, each device will have whole matrices - # Each device already has different matrices of the batch, so we can't parallelize further - if is_batch_sharded and not is_matrix_sharded: - for x, g, m in zip(params, gradients, momentums): - yield AsyncTask( - dion2_update_batch_async( - X=[x], - G=[g], - M=[m], - shard_dim=None, # No sharded matrix dim - **dion2_args, - verbose=verbose, - ) - ) - # Otherwise, we parallelize the Muon update across devices - else: - yield AsyncTask( - dion2_update_batch_async( - X=pad_batch(params, self._world_size), - G=pad_batch(gradients, self._world_size), - M=pad_batch(momentums, self._world_size), - shard_dim=sharded_tensor_dim, - **dion2_args, - verbose=verbose, - ) - ) - - def _create_lion_tasks( - self, - param_groups: List[dict], - algo_name: str = "lion", - ) -> Generator["AsyncTask", None, None]: - """ - Helper function to generate AsyncTask objects for Lion updates. 
- """ - for group in param_groups: - assert group["algorithm"] == algo_name - - # Get parameters and optimizer states - params = [p for p in group["params"] if p.grad is not None] - if not params: - continue - gradients = [p.grad for p in params] - states = [self._get_or_initialize_state(p, algo_name) for p in params] - momentums = [s["momentum"] for s in states] - - # Wrap hyperparameters in tensors for torch.compile - lr = torch.tensor(group["lr"]) - beta1 = torch.tensor(group["beta1"]) - beta2 = torch.tensor(group["beta2"]) - weight_decay = torch.tensor(group["weight_decay"]) - - yield AsyncTask( - lion_update_foreach_async( - X=to_local(params), - G=to_local(gradients), - M=to_local(momentums), - lr=lr, - beta1=beta1, - beta2=beta2, - weight_decay=weight_decay, - ) - ) - - def _create_adamw_tasks( - self, - param_groups: List[dict], - algo_name: str = "adamw", - ) -> Generator["AsyncTask", None, None]: - """ - Helper function to generate AsyncTask objects for AdamW updates. - """ - for group in param_groups: - assert group["algorithm"] == algo_name - - # Get parameters and optimizer states - params = [p for p in group["params"] if p.grad is not None] - if not params: - continue - gradients = [p.grad for p in params] - states = [self._get_or_initialize_state(p, algo_name) for p in params] - momentums = [s["momentum"] for s in states] - variances = [s["variance"] for s in states] - - # Wrap hyperparameters in tensors for torch.compile - lr = torch.tensor(group["lr"]) - beta1 = torch.tensor(group["beta1"]) - beta2 = torch.tensor(group["beta2"]) - weight_decay = torch.tensor(group["weight_decay"]) - epsilon = torch.tensor(group["epsilon"]) - step = torch.tensor(group["step"]) - - yield AsyncTask( - adamw_update_foreach_async( - X=to_local(params), - G=to_local(gradients), - M=to_local(momentums), - V=to_local(variances), - lr=lr, - beta1=beta1, - beta2=beta2, - weight_decay=weight_decay, - step=step, - epsilon=epsilon, - ) - ) - - -def dion2_update_batch_async( - 
X: List[Tensor], # Parameters (DTensor or Tensor), padded to world_size - G: List[Tensor], # Gradients, padded to world_size - M: List[Tensor], # Momentum buffers (modified in place), padded to world_size - lr: Tensor, - ef_decay: Tensor, - fraction: float, - weight_decay: Tensor, - epsilon: Tensor, - flatten: bool, - adjust_lr: Optional[str], - device_rank: int, - world_size: int, - shard_dim: Optional[int] = None, - process_group: Optional[ProcessGroup] = None, - newton_schulz_func: Optional[Callable] = None, - verbose: bool = False, -) -> Generator[None, None, None]: - """ - Batched Dion2 update with fractional submatrix selection. - - Algorithm: - 1. Update momentum: M = M + G - 2. Select top-α fraction along select_dim by L2 norm, extract submatrix - 3. Apply ef_decay to selected slices in M - 4. Communicate and orthogonalize only the submatrix - 5. Apply weight update to corresponding slices - - Selection dimension (select_dim): - - FSDP row-sharded (shard_dim=-2): select rows (select_dim=-2), row norms are local - - FSDP col-sharded (shard_dim=-1): select cols (select_dim=-1), col norms are local - - DDP/Single GPU: select rows by default (select_dim=-2) - - Communication patterns: - - FSDP (shard_dim is not None): - - All-to-all gathers selected slices to form full submatrix - - Orthogonalize, then all-to-all scatter back - - DDP (shard_dim is None, world_size > 1): - - Each rank orthogonalizes one matrix from the batch - - All-gather to distribute results - - Single GPU: direct computation - """ - assert len(X) == len(G) == len(M) - - # Determine selection dimension based on sharding - # - # shard_dim from DTensor can be: - # - Absolute index (0, 1, 2, ...) - # - Negative index (-2, -1) - # - None (not sharded) - # - # We need to map this to select_dim which is always -2 (rows) or -1 (cols) - # relative to the last two dimensions (the matrix dimensions). 
- # - # For FSDP: select along the sharded dimension so norms are local - # For DDP/Single-GPU: select along the SHORTER dimension to reduce Newton-Schulz compute - - ndim = X[0].ndim - - if shard_dim is not None: - # Convert shard_dim to normalized form relative to tensor - normalized_shard_dim = shard_dim if shard_dim < 0 else shard_dim - ndim - - # Check if shard_dim corresponds to a matrix dimension (last two dims) - if normalized_shard_dim == -2: - # Row-sharded: select rows, compute row norms (norm over cols) - select_dim = -2 - elif normalized_shard_dim == -1: - # Col-sharded: select cols, compute col norms (norm over rows) - select_dim = -1 - else: - # Batch dimension sharded: not a matrix dim, fall back to shorter dim - num_rows, num_cols = X[0].shape[-2:] - select_dim = -2 if num_rows <= num_cols else -1 - else: - # DDP/Single-GPU: choose shorter dimension to reduce Newton-Schulz compute - num_rows, num_cols = X[0].shape[-2:] - select_dim = -2 if num_rows <= num_cols else -1 - - # Debug: Print selection choice (only on first call per parameter shape) - if verbose: - _print_selection_choice(X[0].shape, shard_dim, select_dim, ndim) - - # Step 1: Update momentum and select top-α fraction along select_dim - # All matrices in batch have identical shapes, enabling stacked operations - U_selected, indices_list = dion2_pre_orthogonalize( - G=to_local(G), - M=to_local(M), - fraction=fraction, - ef_decay=ef_decay, - select_dim=select_dim, - ) - - # Step 2: Communicate and orthogonalize selected submatrices - # ------------------------------------------------------------------------- - # FSDP path: all-to-all - # ------------------------------------------------------------------------- - if shard_dim is not None: - assert len(X) == world_size - assert process_group is not None - assert isinstance(X[0], DTensor) - - recv_shards = [torch.empty_like(u) for u in U_selected] - work = dist.all_to_all( - recv_shards, U_selected, group=process_group, async_op=True - ) - 
yield - work.wait() - - # Concatenate along selection dimension to form full selected submatrix - # select_dim matches shard_dim, so we concatenate along that dimension - full_submatrix = torch.cat(recv_shards, dim=select_dim) - - # Orthogonalize the full selected submatrix - full_submatrix = muon_update_newton_schulz( - full_submatrix, newton_schulz_func, flatten=flatten, epsilon=epsilon - ) - - # Split back into shards along the same dimension - send_shards = [ - t.contiguous() - for t in torch.tensor_split(full_submatrix, world_size, dim=select_dim) - ] - - # All-to-all: scatter orthogonalized shards back to original owners - U_ortho = [torch.empty_like(u) for u in U_selected] - work = dist.all_to_all(U_ortho, send_shards, group=process_group, async_op=True) - yield - work.wait() - - # ------------------------------------------------------------------------- - # DDP path: all-gather - # ------------------------------------------------------------------------- - elif len(U_selected) > 1: - assert len(U_selected) == world_size - assert process_group is not None - - # This rank orthogonalizes the matrix at index device_rank - my_submatrix = muon_update_newton_schulz( - U_selected[device_rank], - newton_schulz_func, - flatten=flatten, - epsilon=epsilon, - ) - - # All-gather: collect orthogonalized submatrices from all ranks - U_ortho = [torch.empty_like(u) for u in U_selected] - work = dist.all_gather( - U_ortho, my_submatrix.contiguous(), group=process_group, async_op=True - ) - yield - work.wait() - - # ------------------------------------------------------------------------- - # Single GPU path - # ------------------------------------------------------------------------- - else: - assert len(U_selected) == 1 - U_ortho = [ - muon_update_newton_schulz( - U_selected[0], newton_schulz_func, flatten=flatten, epsilon=epsilon - ) - ] - - # Step 3: Compute adjusted learning rate (based on full/global matrix shape) - if adjust_lr is None: - adjusted_lr = lr - elif 
adjust_lr == "spectral_norm": - adjusted_lr = adjust_lr_spectral_norm(lr, X[0].shape, flatten=flatten) - elif adjust_lr == "rms_norm": - adjusted_lr = adjust_lr_rms_norm(lr, X[0].shape, flatten=flatten) - else: - raise ValueError(f"Unknown adjust_lr: {adjust_lr}") - - # Step 4: Apply weight update to selected slices only - dion2_post_orthogonalize( - X=to_local(X), - U=U_ortho, - indices=indices_list, - base_lr=lr, - adjusted_lr=adjusted_lr, - weight_decay=weight_decay, - select_dim=select_dim, - ) - - -@torch.compile(fullgraph=True) -def dion2_pre_orthogonalize( - G: List[Tensor], - M: List[Tensor], - fraction: Tensor, - ef_decay: Tensor, - select_dim: int, -) -> Tuple[List[Tensor], List[Tensor]]: - """ - Update momentum with gradient and compute the input to orthogonalization. - More specifically, it does the following steps: - - updates the momentum with gradient - - computes the top-k indices to determine submatrices - - does in-place error-feedback decay on the selected submatrices - - output submatrices and indices - Inputs and outputs should be lists of regular Tensor, not DTensor. - This is a separate function for compatibility with torch.compile(). 
- """ - dtype = M[0].dtype - - # norm_dim is the dimension we compute norm over - # select_dim is the dimension we select submatrix from - num_select = M[0].size(select_dim) - norm_dim = -1 if select_dim == -2 else -2 - k = max(1, int(math.ceil(fraction * num_select))) - - G = [g.to(dtype=dtype) for g in G] - torch._foreach_add_(M, G) - - M_stacked = torch.stack(M, dim=0) - - # Compute L1 norm along norm_dim (sum of absolute values) - slice_norms = M_stacked.norm(p=1, dim=norm_dim) - - # Batched topk: indices shape (batch_size, k) - _, indices = torch.topk(slice_norms, k, dim=-1, sorted=False) - - # Batched gather for slice extraction - if select_dim == -2: - # Selecting rows - num_cols = M[0].size(-1) - indices_expanded = indices.unsqueeze(-1).expand(-1, -1, num_cols) - selected_stacked = torch.gather(M_stacked, dim=-2, index=indices_expanded) - else: - # Selecting cols - num_rows = M[0].size(-2) - indices_expanded = indices.unsqueeze(-2).expand(-1, num_rows, -1) - selected_stacked = torch.gather(M_stacked, dim=-1, index=indices_expanded) - - # Apply error feedback decay to selected slices in original M tensors - indices_list = list(indices.unbind(dim=0)) - for m, idx in zip(M, indices_list): - selected_slice = m.index_select(dim=select_dim, index=idx) - m.index_copy_(dim=select_dim, index=idx, source=selected_slice * ef_decay) - - # Convert to bf16 and unstack for communication - U_selected = list(selected_stacked.to(dtype=torch.bfloat16).unbind(dim=0)) - - return U_selected, indices_list - - -@torch.compile(fullgraph=True) -def dion2_post_orthogonalize( - X: List[Tensor], - U: List[Tensor], - indices: List[Tensor], - base_lr: Tensor, - adjusted_lr: Tensor, - weight_decay: Tensor, - select_dim: int, -): - """ - Apply weight decay and weight update after orthogonalization. - Inputs and outputs should be lists of regular Tensor, not DTensor. - This is a separate function for compatibility with torch.compile(). 
- """ - torch._foreach_mul_(X, 1 - base_lr * weight_decay) - - # Convert U to match parameter dtype - dtype = X[0].dtype - U = [u.to(dtype=dtype) for u in U] - # Apply weight update - neg_lr = -adjusted_lr - U_scaled = [neg_lr * u for u in U] - for x, u_scaled, idx in zip(X, U_scaled, indices): - x.index_add_(dim=select_dim, index=idx, source=u_scaled) - - -# A helper function to print selection chocie for each matrix -# It only prints once `verbose` is set True -_printed_configs: set = set() - - -def _print_selection_choice( - shape: torch.Size, - shard_dim: Optional[int], - select_dim: int, - ndim: int, -): - config_key = (tuple(shape), shard_dim, select_dim) - if config_key not in _printed_configs: - _printed_configs.add(config_key) - - num_rows, num_cols = shape[-2:] - select_info = "rows" if select_dim == -2 else "columns" - norm_info = "row norms" if select_dim == -2 else "col norms" - - if shard_dim is None: - mode = "DDP/Single-GPU" - shorter = "rows" if num_rows <= num_cols else "cols" - reason = f"shorter dim = {shorter} ({min(num_rows, num_cols)})" - else: - # Normalize shard_dim for display - normalized = shard_dim if shard_dim < 0 else shard_dim - ndim - if normalized == -2: - mode = "FSDP" - reason = f"row-sharded (shard_dim={shard_dim}→-2)" - elif normalized == -1: - mode = "FSDP" - reason = f"col-sharded (shard_dim={shard_dim}→-1)" - else: - mode = "FSDP batch-sharded" - shorter = "rows" if num_rows <= num_cols else "cols" - reason = f"shard_dim={shard_dim} (batch), shorter = {shorter}" - - print( - f"[Dion2] Shape {tuple(shape)}: {mode}, {reason} → " - f"select top-α {select_info} by {norm_info}" - ) diff --git a/train.py b/train.py index fe8a2aa..18b22f2 100644 --- a/train.py +++ b/train.py @@ -1,5 +1,3 @@ -# train.py below - import argparse import math import os @@ -30,7 +28,6 @@ from dion import Muon from dion import MuonReference from dion import Dion2 -from dion import Dion2Old from dion import NorMuon @@ -67,7 +64,7 @@ class Hyperparameters: 
lr: float = 0.02 mu: float = 0.95 weight_decay: float = 0.01 - rank_fraction: float = 0.125 + ortho_fraction: float = 0.25 # Optimizer specific hyperparameters qr_method: str = "rcqr" @@ -132,12 +129,6 @@ def parse_cli_args(): default=None, help="Adjust learning rate method for Muon", ) - parser.add_argument( - "--inv_rank_fraction", - type=int, - default=None, - help="1/r rank fraction for Dion", - ) parser.add_argument( "--qr_method", type=str, default=None, choices=["qr", "cqr", "rcqr"] ) @@ -365,7 +356,7 @@ def init_optimizer( dion_mixed_precision_config = None if hp.optimizer == "dion": - print0(f"Dion rank fraction: {hp.rank_fraction}") + print0(f"Dion rank fraction: {hp.ortho_fraction}") print0(f"Dion mixed precision: {hp.mixed_precision}") print0(f"Compressed data-parallel gradient sync: {hp.replicate_mesh_grad_sync}") opt = Dion( @@ -374,7 +365,7 @@ def init_optimizer( outer_shard_mesh=outer_shard_mesh, inner_shard_mesh=inner_shard_mesh, replicate_mesh_grad_sync=hp.replicate_mesh_grad_sync, - rank_fraction=hp.rank_fraction, + rank_fraction=hp.ortho_fraction, lr=hp.lr, mu=hp.mu, weight_decay=hp.weight_decay, @@ -385,7 +376,7 @@ def init_optimizer( ) elif hp.optimizer == "dion_reference": - print0(f"Dion rank fraction: {hp.rank_fraction}") + print0(f"Dion rank fraction: {hp.ortho_fraction}") print0(f"Dion QR method: {hp.qr_method}") print0(f"Dion mixed precision: {hp.mixed_precision}") print0(f"Compressed data-parallel gradient sync: {hp.replicate_mesh_grad_sync}") @@ -395,7 +386,7 @@ def init_optimizer( outer_shard_mesh=outer_shard_mesh, inner_shard_mesh=inner_shard_mesh, replicate_mesh_grad_sync=hp.replicate_mesh_grad_sync, - rank_fraction=hp.rank_fraction, + rank_fraction=hp.ortho_fraction, lr=hp.lr, mu=hp.mu, weight_decay=hp.weight_decay, @@ -433,9 +424,9 @@ def init_optimizer( ) elif hp.optimizer == "dion2": if device_mesh is not None: - # Ensure that we have a supported device mesh configuration for dion2 + # Ensure that we have a supported device mesh 
configuration for Dion2 if inner_shard_mesh is not None and inner_shard_mesh.size() > 1: - raise ValueError("Tensor parallel is not supported by dion2.") + raise ValueError("Tensor parallel is not supported by Dion2.") distributed_mesh = ( outer_shard_mesh if outer_shard_mesh.size() > 1 else replicate_mesh ) @@ -451,38 +442,12 @@ def init_optimizer( param_groups, distributed_mesh=distributed_mesh, lr=hp.lr, - fraction=hp.rank_fraction, - ef_decay=hp.mu, - weight_decay=hp.weight_decay, - adjust_lr=hp.adjust_lr, - use_triton=(not cli_args.no_triton), - verbose=hp.verbose, - ) - elif hp.optimizer == "dion2old": - if device_mesh is not None: - # Ensure that we have a supported device mesh configuration for dion2 - if inner_shard_mesh is not None and inner_shard_mesh.size() > 1: - raise ValueError("Tensor parallel is not supported by dion2.") - distributed_mesh = ( - outer_shard_mesh if outer_shard_mesh.size() > 1 else replicate_mesh - ) - comm_method = "all-to-all" if outer_shard_mesh.size() > 1 else "all-gather" - else: - assert ddp_model is not None - distributed_mesh = ddp_model.process_group # using ProcessGroup for DDP - comm_method = "all-gather" - print0(f"LR adjust method: {hp.adjust_lr}") - print0(f"Triton Newton-Schulz kernels: {not cli_args.no_triton}") - print0(f"Distributed Dion2Old using: {comm_method}") - opt = Dion2Old( - param_groups, - distributed_mesh=distributed_mesh, - lr=hp.lr, - fraction=hp.rank_fraction, + fraction=hp.ortho_fraction, ef_decay=hp.mu, weight_decay=hp.weight_decay, adjust_lr=hp.adjust_lr, use_triton=(not cli_args.no_triton), + verbose=hp.verbose, ) elif hp.optimizer == "normuon": if device_mesh is not None: @@ -514,13 +479,13 @@ def init_optimizer( elif hp.optimizer == "dion_simple": assert device_mesh is None, f"{hp.optimizer} does not support device mesh" - print0(f"Dion rank fraction: {hp.rank_fraction}") + print0(f"Dion rank fraction: {hp.ortho_fraction}") opt = DionSimple( param_groups, lr=hp.lr, mu=hp.mu, 
weight_decay=hp.weight_decay, - rank=round(hp.rank_fraction * hp.model_dim), + rank=round(hp.ortho_fraction * hp.model_dim), mixed_precision_config=dion_mixed_precision_config, ) @@ -675,9 +640,6 @@ def main(): hp = Hyperparameters() hp = override_args_from_cli(hp, cli_args) - if cli_args.inv_rank_fraction: - hp.rank_fraction = 1.0 / cli_args.inv_rank_fraction - if hp.checkpoint_freq > 0: if not hp.checkpoint_dir: raise ValueError("Must specify --checkpoint_dir to save checkpoints") @@ -848,7 +810,7 @@ def get_lr(it): # Create a name to identify this run run_name = f"({hp.optimizer}+{hp.scalar_opt})" if "dion" in hp.optimizer or "dion2" in hp.optimizer: - run_name += f"frac={hp.rank_fraction}" + run_name += f"frac={hp.ortho_fraction}" if cli_args.dp_size is not None: run_name += f"_dp={cli_args.dp_size}_fs={cli_args.fs_size}_tp={cli_args.tp_size}_gradsync={cli_args.replicate_mesh_grad_sync}" if cli_args.wandb_job_name: From 843e4c3161c557fe21e83d0a634f801cc3675863 Mon Sep 17 00:00:00 2001 From: Kwangjun Ahn Date: Thu, 18 Dec 2025 05:40:22 -0800 Subject: [PATCH 06/14] train.py update, readme update --- README.md | 64 +++++++++++++++++++++++++-------------- configs/dion2_160m.yaml | 5 ++- configs/dion_160m.yaml | 2 +- configs/muon_160m.yaml | 3 +- configs/normuon_160m.yaml | 3 +- train.py | 25 +++++++++------ 6 files changed, 62 insertions(+), 40 deletions(-) diff --git a/README.md b/README.md index 8adf4ec..c0be4e5 100644 --- a/README.md +++ b/README.md @@ -3,10 +3,10 @@ This repository provides efficient implementations of orthonormal optimizers for distributed ML training. You can find the following optimizers: * [Muon](https://kellerjordan.github.io/posts/muon/) -* [Dion](https://arxiv.org/pdf/2504.05295) -* Dion2 +* Dion2 and [Dion](https://arxiv.org/pdf/2504.05295) Dion is a legacy optimizer; we recommend using Dion2) * [NorMuon](https://arxiv.org/abs/2510.05491) + ## Table of Contents
Show/Hide @@ -52,7 +52,7 @@ pip install git+https://github.com/microsoft/dion.git Then in your code, you can use: ```python -from dion import Dion, Dion2, Muon, NorMuon +from dion import Dion2, Muon, NorMuon, Dion ``` Please carefully go through this readme for detailed instructions on using our optimizers. There are major differences compared to PyTorch built-in optimizers, such as `Adam`/`AdamW`. @@ -73,41 +73,61 @@ python data/cached_fineweb10B.py 30 ### Distributed Data Parallel (DDP) Training -To train a GPT-small model using Dion2 with 8 GPUs (adjust as needed for your setup): +To train a GPT-small model using Dion2 with 4 GPUs (adjust as needed for your setup): ```bash -torchrun --standalone --nproc_per_node=8 train.py --config configs/dion_160m.yaml +torchrun --standalone --nproc_per_node=4 train.py --config configs/dion2_160m.yaml ``` This will launch Distributed Data Parallel (DDP) training. -### Advanced FSDP / TP / Hybrid Sharded Training +### Distributed Training: FSDP / TP / Hybrid Sharding -To enable more advanced distributed strategies such as Fully Sharded Data Parallel (FSDP) and Tensor Parallelism (TP), you can specify the configuration in the `dion_160m.yaml` file: +#### Fully Sharded Data Parallel (FSDP) -```yaml -# Example of sharding configuration -dp_size: 2 # data‐parallel size -fs_size: 2 # FSDP size -tp_size: 2 # tensor‐parallel size +To enable FSDP, specify the FSDP group size using `--fs_size`: +```bash +torchrun --standalone --nproc_per_node=4 train.py \ + --config configs/dion2_160m.yaml \ + --fs_size 4 ``` -This example sets up a hybrid configuration with DDP × FSDP × TP = 2 × 2 × 2. +This configuration trains a GPT-small model using Dion2 with FSDP sharding across all 4 GPUs (a single FSDP group of size 4). 
-Alternatively, you can override these values directly from the command line: +#### Hybrid Sharded Data Parallel (HSDP) +To use Hybrid Sharded Data Parallel, where multiple FSDP groups are replicated using Data Parallel (DP), set `--fs_size` smaller than the total number of GPUs and specify the data parallel dimension via `--dp_size`: ```bash -torchrun --standalone --nproc_per_node=8 train.py --config configs/dion_160m.yaml \ - --dp_size 2 --fs_size 2 --tp_size 2 +torchrun --standalone --nproc_per_node=4 train.py \ + --config configs/dion2_160m.yaml \ + --fs_size 2 \ + --dp_size 2 ``` -All three values must be explicitly given, but a size may be set to `1` to omit a parallelism dimension. For instance, for FSDP over 8 devices, you can either configure from `.yaml` as: +This configuration creates: +- **2 FSDP groups**, each spanning 2 GPUs +- **2-way data parallelism** across the FSDP groups +- **Total**: 4 GPUs with 2-way FSDP × 2-way DP + +#### Tensor Parallelism (TP) + +**Note**: Currently, only Dion (our legacy implementation) supports Tensor Parallelism. -```yaml -# Example of pure FSDP configuration -dp_size: 1 # data‐parallel size -fs_size: 8 # FSDP size -tp_size: 1 # tensor‐parallel size +You can combine all three parallelism strategies (DP × FSDP × TP). For example, a 2 × 2 × 2 configuration across 8 GPUs: +```bash +torchrun --standalone --nproc_per_node=8 train.py \ + --config configs/dion_160m.yaml \ + --dp_size 2 \ + --fs_size 2 \ + --tp_size 2 ``` +This configuration creates: +- **2-way data parallelism** (outer replication) +- **2-way FSDP** +- **2-way tensor parallelism** +- **Total**: 8 GPUs with 2-way DP × 2-way FSDP × 2-way TP + +**General rule**: The product `dp_size × fs_size × tp_size` must equal `nproc_per_node`. Any unspecified dimension defaults to 1. 
+ ## Introduction diff --git a/configs/dion2_160m.yaml b/configs/dion2_160m.yaml index 55ff678..17cfdc5 100644 --- a/configs/dion2_160m.yaml +++ b/configs/dion2_160m.yaml @@ -6,7 +6,7 @@ n_head: 6 sequence_length: 1024 # — Batching & Training — -batch_size: 1024 +batch_size: 128 device_batch_size: 32 num_iterations: 3000 @@ -30,8 +30,7 @@ no_triton: false # — Distributed training — dp_size: null # data‐parallel size -fs_size: null # FSDP size -tp_size: null # DO NOT USE TP for Dion2 +fs_size: null # FSDP size # — Optimizer & Hyperparameters — optimizer: dion2 diff --git a/configs/dion_160m.yaml b/configs/dion_160m.yaml index 6759019..68b98d1 100644 --- a/configs/dion_160m.yaml +++ b/configs/dion_160m.yaml @@ -38,6 +38,6 @@ optimizer: dion scalar_opt: lion mu: 0.95 weight_decay: 0.01 -ortho_fraction: 0.125 +ortho_fraction: 0.5 lr: 0.02 mixed_precision: true diff --git a/configs/muon_160m.yaml b/configs/muon_160m.yaml index a39764e..466a20e 100644 --- a/configs/muon_160m.yaml +++ b/configs/muon_160m.yaml @@ -25,8 +25,7 @@ no_wandb: false # — Distributed training — dp_size: null # data‐parallel size -fs_size: null # FSDP size -tp_size: null # DO NOT USE TP for Muon +fs_size: null # FSDP size # — Miscellaneous flags — debug: false diff --git a/configs/normuon_160m.yaml b/configs/normuon_160m.yaml index 9607d0b..e35dba3 100644 --- a/configs/normuon_160m.yaml +++ b/configs/normuon_160m.yaml @@ -25,8 +25,7 @@ no_wandb: false # — Distributed training — dp_size: null # data‐parallel size -fs_size: null # FSDP size -tp_size: null # DO NOT USE TP for NorMuon +fs_size: null # FSDP size # — Miscellaneous flags — debug: false diff --git a/train.py b/train.py index 18b22f2..f6ab86f 100644 --- a/train.py +++ b/train.py @@ -75,7 +75,7 @@ class Hyperparameters: adjust_lr: str = "spectral_norm" # for Muon only # For printing out selection choice in Dion2 - verbose: bool = False + verbose: bool = True # Helper function to only print on global rank 0 @@ -267,11 +267,12 @@ def 
init_distributed(dp_size, fs_size, tp_size) -> Optional[DeviceMesh]: print0(f"World size: {world_size}") else: - # Use device mesh for distributed training - # All mesh dimensions must be specified - assert all( - d is not None for d in mesh_dims - ), f"All mesh dimensions (dp_size, fs_size, tp_size) must be specified, but got ({dp_size}, {fs_size}, {tp_size})" + # Use device mesh for distributed training + # Fill None values with 1 + dp_size = dp_size if dp_size is not None else 1 + fs_size = fs_size if fs_size is not None else 1 + tp_size = tp_size if tp_size is not None else 1 + # Check if we have the right number of GPUs total_gpus = dp_size * fs_size * tp_size @@ -808,11 +809,15 @@ def get_lr(it): # --- Logging initialization --- # Load hyperparameters and update with CLI arguments # Create a name to identify this run - run_name = f"({hp.optimizer}+{hp.scalar_opt})" + optimizer_name = hp.optimizer if "dion" in hp.optimizer or "dion2" in hp.optimizer: - run_name += f"frac={hp.ortho_fraction}" - if cli_args.dp_size is not None: - run_name += f"_dp={cli_args.dp_size}_fs={cli_args.fs_size}_tp={cli_args.tp_size}_gradsync={cli_args.replicate_mesh_grad_sync}" + optimizer_name = f"{hp.ortho_fraction}-{hp.optimizer}" + + run_name = f"({optimizer_name}+{hp.scalar_opt})" + + if device_mesh is not None: + dp, fs, tp = device_mesh.size(0), device_mesh.size(1), device_mesh.size(2) + run_name += f"_(dp={dp}, fs={fs}, tp={tp})" if cli_args.wandb_job_name: run_name += f"_{cli_args.wandb_job_name}" From 3f8a1b30f50547f8785467a7fab8672fcaff6835 Mon Sep 17 00:00:00 2001 From: Kwangjun Ahn Date: Thu, 18 Dec 2025 05:42:05 -0800 Subject: [PATCH 07/14] readme --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index c0be4e5..bed2967 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ This repository provides efficient implementations of orthonormal optimizers for distributed ML training. 
You can find the following optimizers: * [Muon](https://kellerjordan.github.io/posts/muon/) -* Dion2 and [Dion](https://arxiv.org/pdf/2504.05295) Dion is a legacy optimizer; we recommend using Dion2) +* Dion2 and [Dion](https://arxiv.org/pdf/2504.05295) (Dion is a legacy optimizer; we recommend using Dion2) * [NorMuon](https://arxiv.org/abs/2510.05491) @@ -126,7 +126,7 @@ This configuration creates: - **2-way tensor parallelism** - **Total**: 8 GPUs with 2-way DP × 2-way FSDP × 2-way TP -**General rule**: The product `dp_size × fs_size × tp_size` must equal `nproc_per_node`. Any unspecified dimension defaults to 1. +**General rule**: The product `dp_size × fs_size × tp_size` must equal `world_size`. Any unspecified dimension defaults to 1. ## Introduction From c54135f21fe424830203929451473e7320912e93 Mon Sep 17 00:00:00 2001 From: Kwangjun Ahn Date: Thu, 18 Dec 2025 06:32:19 -0800 Subject: [PATCH 08/14] readme --- README.md | 39 ++++++++++++++++++++------------------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index bed2967..4eb44c3 100644 --- a/README.md +++ b/README.md @@ -131,16 +131,16 @@ This configuration creates: ## Introduction -Optimization algorithms are essential to training neural networks, converting gradients into model weight updates to minimize loss. For many years, the state-of-the-art method has been [Adam](https://arxiv.org/abs/1412.6980)/[AdamW](https://arxiv.org/abs/1711.05101). However, recent work has shown that **orthonormal matrix optimizers** can significantly accelerate model convergence. Check out blog posts by [Jeremy Bernstein](https://jeremybernste.in/writing/deriving-muon) and [Laker Newhouse](https://www.lakernewhouse.com/writing/muon-1) for more details. +Optimization algorithms are essential to training neural networks, converting gradients into model weight updates to minimize loss. 
For many years, the method of choice has been [Adam](https://arxiv.org/abs/1412.6980)/[AdamW](https://arxiv.org/abs/1711.05101). However, recent work has shown that **orthonormal optimizers** can significantly accelerate model convergence. Check out blog posts by [Jeremy Bernstein](https://jeremybernste.in/writing/deriving-muon) and [Laker Newhouse](https://www.lakernewhouse.com/writing/muon-1) for more details. -The practical effectiveness of orthonormal updates was first demonstrated by [Muon](https://kellerjordan.github.io/posts/muon/) in the [NanoGPT speedrun](https://github.com/KellerJordan/modded-nanogpt), and has since been validated at scale by models such as [Kimi K2](https://arxiv.org/abs/2507.20534) and [GLM-4.5](https://z.ai/blog/glm-4.5). Muon implements orthonormalization via *Newton-Schulz iterations*, which relies on repeated matrix-matrix multiplications. However, large-scale training relies on model sharding, where weight matrices and optimizer states are distributed across multiple processes. As discussed by [Essential AI](https://www.essential.ai/blog/infra), orthonormalizing a sharded matrix with Newton-Schulz iterations involves the communication-intensive procedure of reconstructing the full matrices from their individual shards. +The practical effectiveness of orthonormal optimizers was first demonstrated by [Muon](https://kellerjordan.github.io/posts/muon/) in the [NanoGPT speedrun](https://github.com/KellerJordan/modded-nanogpt), and has since been validated at scale by models such as [Kimi K2](https://arxiv.org/abs/2507.20534) and [GLM-4.5](https://z.ai/blog/glm-4.5). Muon implements orthonormalization via *Newton-Schulz iterations*, which relies on repeated matrix-matrix multiplications. However, large-scale training relies on model sharding, where weight matrices and optimizer states are distributed across multiple processes. 
As discussed by [Essential AI](https://www.essential.ai/blog/infra), orthonormalizing a sharded matrix with Newton-Schulz iterations involves the communication-intensive procedure of reconstructing the full matrices from their individual shards. -**Dion/Dion2** are our methods for building a **scalable, communication-efficient** optimizer. Like Muon, it computes orthonormal weight updates and has the same benefits of faster model convergence. The key difference is that Dion/Dion2 **shrink the matrix before orthonormalization**. Dion uses power iteration to compute a low-rank approximation, while Dion2 applies a simple submatrix-selection procedure. To reduce information loss, both methods include an error-feedback mechanism that tracks the discrepancy between the original matrix and its compressed approximation. +**Dion/Dion2** are our methods for building a **scalable, communication-efficient** optimizer. Like Muon, they compute matrix weight updates based on matrix orthonormalization and share similar practical benefits. The key difference is that Dion and Dion2 **shrink the matrix before orthonormalization**, reducing both computational and communication costs. Dion uses power iteration to compute a low-rank approximation, while Dion2 applies a simple submatrix-selection procedure. To reduce information loss, both methods include an error-feedback mechanism that tracks the discrepancy between the original matrix and its compressed approximation. 
## Optimizers -Our main implementations of Dion (`dion.py`) and Muon (`muon.py`) support the following parallelization techniques: +Our current implementations support the following parallelization techniques: | Parallelization | Dion | Dion2 | Muon | NorMuon | |--------------------|------|-------|------|---------| @@ -149,13 +149,13 @@ Our main implementations of Dion (`dion.py`) and Muon (`muon.py`) support the fo | PyTorch FSDP2 | Yes | Yes | Yes | Yes | | PyTorch FSDP2 + TP | Yes | No | No | No | -For faster performance, both of these optimizers will process parameters in batches and interleave multiple batches to overlap compute with communication. +For faster performance, these optimizers will process parameters in batches and interleave multiple batches to overlap compute with communication. We include optimizer implementations in the `dion/` directory of this repo. * `dion.py`: High-performance version of Dion. Depending on how each batch of matrices is sharded, we select the best communication patterns to compute Dion's orthonormal update. All-reduce operations may be split into reduce-scatter and all-gather across the batch dimension to more efficiently distribute work and avoid redundant computation. * `muon.py`: High-performance version of Muon. For sharded matrices, all-to-all communication is used to simultaneously unshard and distribute a batch of matrices. For replicated matrices, Muon will distribute work across all devices and all-gather final results. -* `dion2.py`: A preliminary implementation of Dion2, which uses a similar all-to-all communication pattern to distribute orthonormalization. Only an $\alpha$-fraction of the momentum matrix is orthonormalized, leaving room for additional communication optimizations. +* **`dion2.py`**: High-performance implementation of Dion2, using a similar all-to-all communication pattern for distributed orthonormalization. 
Only an α-fraction of the momentum matrix is communicated and orthonormalized, significantly reducing both communication overhead and computation cost. * `normuon.py`: A variant of the Muon optimizer that introduces neuron-wise normalization to improve stability and convergence efficiency, modified to take similar arguments as `muon.py`. See [the paper](https://arxiv.org/abs/2510.05491) for more details. We also provide some reference implementations: @@ -184,11 +184,11 @@ We summarize the above in this table. Let `d_in` be the input dimension of the u | Type | Example parameters | Optimizer `algorithm` | Learning rate `lr` | |---------------|---------------------------------------------|-----------------------|------------------------| -| Weight matrix | `nn.Linear.weight` | `"dion"` / `"muon"` | `lr` | -| Bias vector | `nn.Linear.bias` | `"lion"` / `"adamw"` | `lr` | -| Normalization | `nn.LayerNorm.weight`, `nn.LayerNorm.bias` | `"lion"` / `"adamw"` | `lr` | -| Embedding | `nn.Embedding.weight` | `"lion"` / `"adamw"` | `lr` | -| Unembedding | `nn.Linear.weight` (must identify manually) | `"lion"` / `"adamw"` | `lr / math.sqrt(d_in)` | +| Weight matrix | `nn.Linear.weight` | `"dion2"` / `"muon"` | `lr` | +| Bias vector | `nn.Linear.bias` | `"adamw"` / `"lion"` | `lr` | +| Normalization | `nn.LayerNorm.weight`, `nn.LayerNorm.bias` | `"adamw"` / `"lion"` | `lr` | +| Embedding | `nn.Embedding.weight` | `"adamw"` / `"lion"` | `lr` | +| Unembedding | `nn.Linear.weight` (must identify manually) | `"adamw"` / `"lion"` | `lr / math.sqrt(d_in)` | We emphasize again that **particular care** needs to be taken with **embedding and unembedding layers**. They must be isolated from ordinary matrix parameters, and the unembedding layer furthermore should use a scaled learning rate. 
Merely checking the dimensions of a parameter (such as `if p.ndim == 2`) or the type of the module (such as `if isinstance(module, nn.Linear)`) **is not sufficient** to identify these special parameters. This is why we require manual parameter group creation. @@ -214,9 +214,9 @@ lm_head_params= list(model.lm_head.parameters()) param_groups = [ dict(params=matrix_params), # will default to "dion" algorithm - dict(params=vector_params, algorithm="lion"), - dict(params=embed_params, algorithm="lion"), - dict(params=lm_head_params, algorithm="lion", lr=lr / math.sqrt(model_dim)) + dict(params=vector_params, algorithm="adamw"), + dict(params=embed_params, algorithm="adamw"), + dict(params=lm_head_params, algorithm="adamw", lr=lr / math.sqrt(model_dim)) ] optimizer = Dion( @@ -231,16 +231,17 @@ Additional hyperparameters may be specified on a per-parameter-group basis to ov ```python param_groups = [ dict(params=matrix_params), - dict(params=vector_params, algorithm="lion"), - dict(params=embed_params, algorithm="lion", weight_decay=0), - dict(params=lm_head_params, algorithm="lion", lr=lr / math.sqrt(model_dim), weight_decay=0) + dict(params=vector_params, algorithm="adamw"), + dict(params=embed_params, algorithm="adamw", weight_decay=0), + dict(params=lm_head_params, algorithm="adamw", lr=lr / math.sqrt(model_dim), weight_decay=0) ] ``` ## Distributed Training Configuration -In order for our efficient distributed optimizers to work, they must know about the parallelization scheme for training your model. This is done by passing in `DeviceMesh` objects when constructing the optimizer. +For our efficient distributed optimizers to work correctly, they need information about the model's parallelization scheme. This is provided by passing `DeviceMesh` objects during optimizer construction. +We demonstrate this for Dion first, since it has the most comprehensive parallelism support in our current implementation. 
### Device Mesh for Dion @@ -289,7 +290,7 @@ optimizer = Dion( ) ``` -### Device Mesh for Muon +### Device Mesh for Dion2, Muon and Others Muon uses different device mesh arguments from Dion. From 2dc80322a2374a6389fed14ecb1c381e71475fdc Mon Sep 17 00:00:00 2001 From: Kwangjun Ahn Date: Thu, 18 Dec 2025 09:35:53 -0800 Subject: [PATCH 09/14] read me --- README.md | 85 +++++++++++++++++++++++++++++-------------------------- 1 file changed, 45 insertions(+), 40 deletions(-) diff --git a/README.md b/README.md index 4eb44c3..279f5bd 100644 --- a/README.md +++ b/README.md @@ -237,34 +237,38 @@ param_groups = [ ] ``` - ## Distributed Training Configuration For our efficient distributed optimizers to work correctly, they need information about the model's parallelization scheme. This is provided by passing `DeviceMesh` objects during optimizer construction. -We demonstrate this for Dion first, since it has the most comprehensive parallelism support in our current implementation. - -### Device Mesh for Dion -Dion supports up to two sharded mesh dimensions and any number of data-parallel replicated mesh dimensions. The sharded meshes are referred to as `outer_shard_mesh` and `inner_shard_mesh`. Dion's internal optimizer states can be sharded over both meshes. During the update computation, Dion will orthonormalize a low-rank matrix that is replicated across `outer_shard_mesh`, but always remains sharded across `inner_shard_mesh`. Thus, the `inner_shard_mesh` is more communication-intensive and works best with intra-node tensor parallelism. Both sharding meshes must be one-dimensional. +### 1D Sharding Configuration (Dion2, Muon, NorMuon) -Unused meshes may be omitted or given as `None`. If only one sharding dimension is used (e.g. only FSDP without TP), we recommend providing it as the `outer_shard_mesh`. Dion will execute a faster single-device orthonormalization routine in this case, since the input matrix to be orthonormalized will not be sharded. 
+Most optimizers in this codebase (Dion2, Muon, NorMuon) currently support only 1D sharding. They accept a single 1D device mesh via the `distributed_mesh` argument and adapt their behavior based on how this mesh is used: +- **If the mesh is used for parameter sharding**: The optimizer efficiently unshards parameters using all-to-all communication +- **If the mesh is not used for sharding**: The optimizer distributes work across devices and all-gathers the final results + +For a hybrid sharded data parallel (HSDP) configuration with both replicated and sharded dimensions, pass only the sharded sub-mesh to the optimizer: ```python -# Example with a 3D mesh mesh = init_device_mesh( device_type="cuda", - mesh_shape=(dp_size, fs_size, tp_size), - mesh_dim_names=("dp", "fs", "tp") + mesh_shape=(replicate_size, shard_size), + mesh_dim_names=("replicate", "shard"), ) -optimizer = Dion( +# Apply HSDP with 2D device mesh +# Parameters are sharded across the 1st dim and replicated across the 0th dim +# https://docs.pytorch.org/docs/stable/distributed.fsdp.fully_shard.html +fully_shard(model, mesh=mesh) + +# Pass only the sharded dimension to the optimizer +optimizer = Dion2( # or Muon or NorMuon param_groups, - replicate_mesh = mesh["dp"], # Replicated data parallel - outer_shard_mesh = mesh["fs"], # Sharded data parallel - inner_shard_mesh = mesh["tp"], # Tensor parallel + distributed_mesh=mesh["shard"], # 1D sub-mesh (sharded dimension only) ... ) ``` + ### Flattened Meshes @@ -273,8 +277,8 @@ When more advanced parallelism strategies are used (such as context parallel or ```python mesh = init_device_mesh( device_type="cuda", - mesh_shape=(dp_size, cp_size, tp_size), - mesh_dim_names=("dp", "cp", "tp") + mesh_shape=(dp_size, cp_size), + mesh_dim_names=("dp", "cp") ) # FSDP sharding applied across combined DP and CP meshes @@ -290,31 +294,6 @@ optimizer = Dion( ) ``` -### Device Mesh for Dion2, Muon and Others - -Muon uses different device mesh arguments from Dion. 
- -Our implementation of Muon takes a single 1D device mesh as a generic `distributed_mesh` argument. If this mesh is used for sharding parameters, Muon will efficiently perform unsharding using all-to-all. If this mesh is not used for sharding, Muon will distribute work across this mesh and all-gather the final results. - -2D sharding is not supported by Muon---use Dion instead. For hybrid-sharded data parallel, with a replicated mesh dimension and a sharded dimension, pass only the sharded sub-mesh to Muon. - -```python -mesh = init_device_mesh( - device_type="cuda", - mesh_shape=(replicate_size, shard_size), - mesh_dim_names=("replicate", "shard"), -) - -# Hybrid sharded data parallel with 2D device mesh -fully_shard(model, mesh=mesh) - -optimizer = Muon( - param_groups, - distributed_mesh = mesh["shard"], # 1D sub-mesh - ... -) -``` - ### Usage with DDP ProcessGroup Training with DistributedDataParallel (DDP) is also supported. DDP uses PyTorch `ProcessGroup` instead of `DeviceMesh`, which is stored in the DDP-wrapped model's `process_group` field. Providing this to the optimizer will allow it to efficiently distribute work across all GPUs. If no `process_group` is provided, the optimizer will run in single-GPU mode, and every device in the DDP world will redundantly perform the same work. @@ -336,6 +315,32 @@ optimizer = Muon( ``` + +### Device Mesh for Dion + +We demonstrate this for Dion first, since it has the most comprehensive parallelism support in our current implementation. +Dion supports up to two sharded mesh dimensions and any number of data-parallel replicated mesh dimensions. The sharded meshes are referred to as `outer_shard_mesh` and `inner_shard_mesh`. Dion's internal optimizer states can be sharded over both meshes. During the update computation, Dion will orthonormalize a low-rank matrix that is replicated across `outer_shard_mesh`, but always remains sharded across `inner_shard_mesh`. 
Thus, the `inner_shard_mesh` is more communication-intensive and works best with intra-node tensor parallelism. Both sharding meshes must be one-dimensional. + +Unused meshes may be omitted or given as `None`. If only one sharding dimension is used (e.g. only FSDP without TP), we recommend providing it as the `outer_shard_mesh`. Dion will execute a faster single-device orthonormalization routine in this case, since the input matrix to be orthonormalized will not be sharded. + +```python +# Example with a 3D mesh +mesh = init_device_mesh( + device_type="cuda", + mesh_shape=(dp_size, fs_size, tp_size), + mesh_dim_names=("dp", "fs", "tp") +) + +optimizer = Dion( + param_groups, + replicate_mesh = mesh["dp"], # Replicated data parallel + outer_shard_mesh = mesh["fs"], # Sharded data parallel + inner_shard_mesh = mesh["tp"], # Tensor parallel + ... +) +``` + + ## Compressed Data-Parallel Gradient Sync Dion is capable of *skipping the usual full-gradient all-reduce* by only synchronizing low-rank matrices instead. Depending on the rank fraction used, we can greatly compress the amount of communication needed while producing the exact same end result (up to numerical precision). This technique originates from PowerSGD---see [Vogels et al., 2019](https://arxiv.org/abs/1905.13727) for more details. From 8ff2aadf11f4498d90d40863d98f22ca93016ad4 Mon Sep 17 00:00:00 2001 From: Kwangjun Ahn Date: Thu, 18 Dec 2025 10:10:52 -0800 Subject: [PATCH 10/14] readme update --- README.md | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index 279f5bd..7f746aa 100644 --- a/README.md +++ b/README.md @@ -270,9 +270,9 @@ optimizer = Dion2( # or Muon or NorMuon ``` -### Flattened Meshes +#### Flattened Meshes -When more advanced parallelism strategies are used (such as context parallel or expert parallel), it is common for multiple mesh dimensions to be "flattened" into a 1D sub-mesh for sharding. 
In this scenario, the flattened mesh needs to be given to Dion. +When more advanced parallelism strategies are used (such as context parallel or expert parallel), it is common for multiple mesh dimensions to be "flattened" into a 1D sub-mesh for sharding. In this scenario, the flattened mesh needs to be given to the optimizer. ```python mesh = init_device_mesh( @@ -285,29 +285,21 @@ mesh = init_device_mesh( fs_mesh = mesh["dp", "cp"]._flatten() fully_shard(model, mesh=fs_mesh) -optimizer = Dion( +optimizer = Dion2( # or Muon or NorMuon param_groups, - replicate_mesh = None, # No replicated data parallel used - outer_shard_mesh = fs_mesh, # Sharded data parallel across flattened mesh - inner_shard_mesh = mesh["tp"], # Tensor parallel + distributed_mesh = fs_mesh, # Sharded data parallel across flattened mesh ... ) ``` -### Usage with DDP ProcessGroup +#### Usage with DDP ProcessGroup Training with DistributedDataParallel (DDP) is also supported. DDP uses PyTorch `ProcessGroup` instead of `DeviceMesh`, which is stored in the DDP-wrapped model's `process_group` field. Providing this to the optimizer will allow it to efficiently distribute work across all GPUs. If no `process_group` is provided, the optimizer will run in single-GPU mode, and every device in the DDP world will redundantly perform the same work. ```python ddp_model = DistributedDataParallel(model, ...) -optimizer = Dion( - param_groups, - replicated_mesh=ddp_model.process_group, - ... -) -# - or - -optimizer = Muon( +optimizer = Dion2( # or Muon or NorMuon param_groups, distributed_mesh=ddp_model.process_group, ... 
From aee72510d2f72a01b32fdaa3bcc781ffd073823d Mon Sep 17 00:00:00 2001 From: Kwangjun Ahn Date: Thu, 18 Dec 2025 10:29:18 -0800 Subject: [PATCH 11/14] readme reorganize --- README.md | 60 ++++++++++++++++++++++++------------------------------- 1 file changed, 26 insertions(+), 34 deletions(-) diff --git a/README.md b/README.md index 7f746aa..99916ef 100644 --- a/README.md +++ b/README.md @@ -308,9 +308,8 @@ optimizer = Dion2( # or Muon or NorMuon -### Device Mesh for Dion - -We demonstrate this for Dion first, since it has the most comprehensive parallelism support in our current implementation. +### 2D-Sharding Support for Dion + Dion supports up to two sharded mesh dimensions and any number of data-parallel replicated mesh dimensions. The sharded meshes are referred to as `outer_shard_mesh` and `inner_shard_mesh`. Dion's internal optimizer states can be sharded over both meshes. During the update computation, Dion will orthonormalize a low-rank matrix that is replicated across `outer_shard_mesh`, but always remains sharded across `inner_shard_mesh`. Thus, the `inner_shard_mesh` is more communication-intensive and works best with intra-node tensor parallelism. Both sharding meshes must be one-dimensional. Unused meshes may be omitted or given as `None`. If only one sharding dimension is used (e.g. only FSDP without TP), we recommend providing it as the `outer_shard_mesh`. Dion will execute a faster single-device orthonormalization routine in this case, since the input matrix to be orthonormalized will not be sharded. @@ -333,7 +332,21 @@ optimizer = Dion( ``` -## Compressed Data-Parallel Gradient Sync + + +## Best Practices + +* **Dion rank fraction:** The most important Dion-specific hyperparameter is the *rank fraction*, which controls the amount of low-rank compression. Setting `rank_fraction=1.0` results in full-rank updates without any compression, similar to Muon. Empirically, it appears that larger models are more tolerant of low-rank compression. 
At 3B parameters, `rank_fraction=0.25` (1/4 rank) achieves nearly equivalent performance as full-rank, and we expect that 1/8, 1/16, and perhaps lower rank fractions will work well at 10B+ scale. +* **Lion vs. AdamW:** We have found that Lion performs better than AdamW for optimizing scalar parameters when used with Dion/Muon for orthonormal matrix updates. +* **2D sharding:** If weights are sharded with both FSDP and TP, it is required that the sharding methods are applied to different matrix dimensions. The TP sharding dimension is controlled via `RowwiseParallel` and `ColwiseParallel`, but the FSDP sharding dimension needs to be manually specified when applied on top of TP. See `models/gpt_model.py` for an example of explicitly providing `fully_shard()` with per-parameter shard dimensions. Double-sharded matrices along the same dimension will raise an error in Dion. +* **Learning rate scaling:** Dion will automatically scale the provided learning rate by `sqrt(d_out / d_in)` for matrix parameters. Muon will apply the same scaling by default, but also supports the `0.2 * sqrt(max(d_in, d_out))` scale factor recommended by Moonshot AI. Our default scale factor is intended to induce a consistent change to activation vector values, which enables learning rate transfer across model size. See [Deriving Muon](https://jeremybernste.in/writing/deriving-muon) for more information. +* **Nesterov momentum:** In Muon, we set Nesterov momentum to `False` by default, as we observed better performance without it. Dion does not implement Nesterov momentum. + + +## Other Features + + +### Compressed Data-Parallel Gradient Sync for Dion Dion is capable of *skipping the usual full-gradient all-reduce* by only synchronizing low-rank matrices instead. Depending on the rank fraction used, we can greatly compress the amount of communication needed while producing the exact same end result (up to numerical precision). 
This technique originates from PowerSGD---see [Vogels et al., 2019](https://arxiv.org/abs/1905.13727) for more details. @@ -344,7 +357,7 @@ This feature is applicable across any replicated data-parallel axis for DDP and Note that `replicate_mesh_grad_sync=True` results in *decoupled momentum*. The optimizer's internal momentum states will diverge across data-parallel processes. (Model weight updates always remain identical.) Before saving a checkpoint, you must explicitly tell Dion to synchronize internal states. See the [Checkpointing](#checkpointing) section for more details. -### Usage with HSDP +#### Usage with HSDP Typically, hybrid sharding with `fully_shard()` uses a 2D device mesh. To use with Dion's compressed gradient synchronization, pass only the sharded sub-mesh to `fully_shard()`. @@ -357,7 +370,9 @@ Note that if we choose to disable Dion's compressed gradient synchronization, we | Dion syncs compressed states | 1D shard sub-mesh | `True` | Decoupled | Always synchronous | | FSDP syncs full gradients | 2D hybrid-shard mesh | `False` | Synchronous | Always synchronous | -### Example Code +#### Example Codes + +We provide example codes for compressed-DP sync under HSDP scenarios. ```python # ------------------------------------------------------------ @@ -389,8 +404,6 @@ opt = Dion( ) ``` -### Usage with DDP - To use compressed gradient synchronization with DDP, always run the model with the `no_sync()` context. ```python @@ -413,9 +426,9 @@ for data in dataloader: model.zero_grad() ``` -### Checkpointing +#### Checkpointing under Compressed-DP Sync -Dion requires synchronizing optimizer state before saving a checkpoint. Because of Dion's decoupled momentum, internal optimizer states will be different across the replicate mesh. Call the `synchronize_for_checkpoint()` function to explicitly perform an all-reduce of optimizer states. 
This ensures the consistency of distributed checkpoints, since typically each state will only be saved by one process along the replicated data-parallel mesh. This function will be a no-op if `replicate_mesh_grad_sync=False` or no replicate mesh is used. +Dion when `replicate_mesh_grad_sync = True` requires synchronizing optimizer state before saving a checkpoint. This is because of Dion's decoupled momentum, where internal optimizer states will be different across the replicate mesh. Call the `synchronize_for_checkpoint()` function to explicitly perform an all-reduce of optimizer states. This ensures the consistency of distributed checkpoints, since typically each state will only be saved by one process along the replicated data-parallel mesh. This function will be a no-op if `replicate_mesh_grad_sync=False` or no replicate mesh is used. If model parameters are `DTensor` type, the optimizer states will also be `DTensor`s. Checkpoints should be saved using [torch.distributed.checkpoint](https://docs.pytorch.org/docs/stable/distributed.checkpoint.html). @@ -437,6 +450,8 @@ optimizer.step() model.zero_grad() # Call this before checkpointing +# This is only needed for Dion with `replicate_mesh_grad_sync=True`. +# For other optimizers, this is not required. optimizer.synchronize_for_checkpoint() # Save a distributed checkpoint @@ -445,18 +460,6 @@ checkpoint = { "model": model_state_dict, "optimizer": opt_state_dict } dcp.save(checkpoint, ...) ``` - -## Best Practices - -* **Dion rank fraction:** The most important Dion-specific hyperparameter is the *rank fraction*, which controls the amount of low-rank compression. Setting `rank_fraction=1.0` resulting in full-rank updates without any compression, similar to Muon. Empirically, it appears that larger models are more tolerant of low-rank compression. 
At 3B parameters, `rank_fraction=0.25` (1/4 rank) achieves nearly equivalent performance as full-rank, and we expect that 1/8, 1/16, and perhaps lower rank fractions will work well at 10B+ scale. -* **Lion vs. AdamW:** We have found that Lion performs better than AdamW for optimizing scalar parameters when used with Dion/Muon for orthonormal matrix updates. -* **2D sharding:** If weights are sharded with both FSDP and TP, it is required that the sharding methods are applied to different matrix dimensions. The TP sharding dimension is controlled via `RowwiseParallel` and `ColwiseParallel`, but the FSDP sharding dimension needs to be manually specified when applied on top of TP. See `models/gpt_model.py` for an example of explicitly providing `fully_shard()` with per-parameter shard dimensions. Double-sharded matrices along the same dimension will raise an error in Dion. -* **Learning rate scaling:** Dion will automatically scale the provided learning rate by `sqrt(d_out / d_in)` for matrix parameters. Muon will apply the same scaling by default, but also supports the `0.2 * sqrt(max(d_in, d_out))` scale factor recommended by Moonshot AI. Our default scale factor is intended to induce a consistent change to activation vector values, which enables learning rate transfer across model size. See [Deriving Muon](https://jeremybernste.in/writing/deriving-muon) for more information. -* **Nesterov momentum:** In Muon, we set Nesterov momentum to `False` by default, as we observed better performance without it. Dion does not implement Nesterov momentum. - - -## Experimental Features - ### Mixed Precision Dion By default, Dion will initialize its optimizer states to use the same data type as the model's parameters. The `DionMixedPrecisionConfig` class may be used to specify custom data types. In preliminary experiments, we have found that using `torch.bfloat16` for Dion's optimizer states can reduce memory use and speed up computation with no impact on training stability. 
@@ -475,18 +478,7 @@ optimizer = Dion( ... ) ``` - -### Faster Dion for lower ranks - -After a few warmup iterations, the expensive QR decomposition can be replaced with the Cholesky QR (CQR) algorithm, leading to **2X** optimization step speedups. CQR is faster but less numerically stable. We have found that after some initial warmup period, the input matrix for orthogonalization becomes relatively well-conditioned. If Cholesky decomposition fails, we fall back to the standard QR decomposition procedure. - -To try out the CQR accelerated configuration: -```bash -torchrun --standalone --nproc_per_node=8 train.py --config configs/dion_efficient_160m.yaml -``` - -After the training you should be able to reproduce the second plot in [validation curves for GPT-small](https://microsoft-research.wandb.io/t-gmagakyan/dion-exp/reports/Validation-curves-for-GPT-small--VmlldzoxNjk5OA?accessToken=52e6z4d18yfkewz1bawlkmwc2m91al9ssa7rpwvnx1f1xa66j15lr7x315wj2kys). - + ### Triton Kernels for Muon Newton-Schulz Muon's Newton-Schulz iteration involves multiplying a matrix by its own transpose. The result is symmetric, so we can accelerate this computation by only computing half of the output and mirroring the result across the diagonal. We implemented this technique with Triton kernels in `optimizers/newton_schulz_triton.py`. From d3eeab34daa53b7b993a04f22afec12ec8663b20 Mon Sep 17 00:00:00 2001 From: Kwangjun Ahn Date: Thu, 18 Dec 2025 10:34:08 -0800 Subject: [PATCH 12/14] readme clean up --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 99916ef..bc2d68e 100644 --- a/README.md +++ b/README.md @@ -426,7 +426,7 @@ for data in dataloader: model.zero_grad() ``` -#### Checkpointing under Compressed-DP Sync +### Checkpointing Dion when `replicate_mesh_grad_sync = True` requires synchronizing optimizer state before saving a checkpoint. 
This is because of Dion's decoupled momentum, where internal optimizer states will be different across the replicate mesh. Call the `synchronize_for_checkpoint()` function to explicitly perform an all-reduce of optimizer states. This ensures the consistency of distributed checkpoints, since typically each state will only be saved by one process along the replicated data-parallel mesh. This function will be a no-op if `replicate_mesh_grad_sync=False` or no replicate mesh is used. From 1b61ff5323c7a6e6bb3d0ae4a069914221427c37 Mon Sep 17 00:00:00 2001 From: Kwangjun Ahn Date: Thu, 18 Dec 2025 11:07:57 -0800 Subject: [PATCH 13/14] black format --- configs/dion2_160m.yaml | 2 +- dion/__init__.py | 2 +- dion/dion2.py | 29 +++++++++++++++-------------- train.py | 7 +++---- 4 files changed, 20 insertions(+), 20 deletions(-) diff --git a/configs/dion2_160m.yaml b/configs/dion2_160m.yaml index 17cfdc5..55b78d3 100644 --- a/configs/dion2_160m.yaml +++ b/configs/dion2_160m.yaml @@ -6,7 +6,7 @@ n_head: 6 sequence_length: 1024 # — Batching & Training — -batch_size: 128 +batch_size: 1024 device_batch_size: 32 num_iterations: 3000 diff --git a/dion/__init__.py b/dion/__init__.py index a3d18f4..34894e6 100644 --- a/dion/__init__.py +++ b/dion/__init__.py @@ -3,6 +3,6 @@ from .dion_simple import Dion as DionSimple from .dion_reference import Dion as DionReference from .muon import Muon -from .muon_reference import Muon as MuonReference +from .muon_reference import Muon as MuonReference from .dion2 import Dion2 from .normuon import NorMuon diff --git a/dion/dion2.py b/dion/dion2.py index 274fa7f..b682626 100644 --- a/dion/dion2.py +++ b/dion/dion2.py @@ -400,16 +400,16 @@ def dion2_update_batch_async( X: List[Tensor], # Model weights (modified in place) G: List[Tensor], # Gradient M: List[Tensor], # Momentum buffer (modified in place) - lr: Tensor, # Learning rate (scalar tensor) - ef_decay: Tensor, # Error-feedback factor (scalar tensor) - fraction: float, # Fraction of submatrix to 
orthogonalize (0 < fraction <= 1) - weight_decay: Tensor, # Weight decay (scalar tensor) - epsilon: Tensor, # Epsilon (scalar tensor) + lr: Tensor, # Learning rate (scalar tensor) + ef_decay: Tensor, # Error-feedback factor (scalar tensor) + fraction: float, # Fraction of submatrix to orthogonalize (0 < fraction <= 1) + weight_decay: Tensor, # Weight decay (scalar tensor) + epsilon: Tensor, # Epsilon (scalar tensor) flatten: bool, # Whether to flatten 3D+ tensors to 2D - adjust_lr: Optional[str], # How to adjust learning rate - device_rank: int, # Rank of the current device - world_size: int, # Total number of devices to parallelize over - shard_dim: Optional[int] = None, # Shard dimension for DTensor (if applicable) + adjust_lr: Optional[str], # How to adjust learning rate + device_rank: int, # Rank of the current device + world_size: int, # Total number of devices to parallelize over + shard_dim: Optional[int] = None, # Shard dimension for DTensor (if applicable) process_group: Optional[ProcessGroup] = None, newton_schulz_func: Optional[Callable] = None, verbose: bool = False, @@ -420,7 +420,7 @@ def dion2_update_batch_async( Identical hyperparameters are used for all tensors in the batch. """ assert len(X) == len(G) - assert len(X) == len(M) + assert len(X) == len(M) # Determine selection dimension based on sharding and tensor shape: # For sharded matrices, we align select_dim with shard_dim @@ -462,12 +462,12 @@ def dion2_update_batch_async( assert ( process_group is not None ), "process_group must be provided for sharded DTensors" - assert isinstance(X[0], DTensor), "X should contain DTensors" + assert isinstance(X[0], DTensor), "X should contain DTensors" assert ( X[0].size(shard_dim) % world_size == 0 ), f"Shard dimension {shard_dim} size {X[0].size(shard_dim)} is not divisible by world size {world_size}." 
- # Allocate buffers to receive shards of one whole submatrix from other devices + # Allocate buffers to receive shards of one whole submatrix from other devices recv_shards = [torch.empty_like(u) for u in U_selected] work = dist.all_to_all( recv_shards, U_selected, group=process_group, async_op=True @@ -477,7 +477,7 @@ def dion2_update_batch_async( # Concatentate shards to form a whole matrix to orthogonalize # Only submatrix is orthogonalized! - full_submatrix = torch.cat(recv_shards, dim=select_dim) + full_submatrix = torch.cat(recv_shards, dim=select_dim) full_submatrix = muon_update_newton_schulz( full_submatrix, newton_schulz_func, flatten=flatten, epsilon=epsilon ) @@ -503,7 +503,7 @@ def dion2_update_batch_async( single_matrix = U_selected[device_rank] assert not isinstance(single_matrix, DTensor) - + single_ortho = muon_update_newton_schulz( single_matrix, newton_schulz_func, @@ -648,6 +648,7 @@ def dion2_post_orthogonalize( # It only prints once `verbose` is set True _printed_configs: set = set() + def _print_selection_choice( shape: torch.Size, shard_dim: Optional[int], diff --git a/train.py b/train.py index f6ab86f..81654d1 100644 --- a/train.py +++ b/train.py @@ -267,13 +267,12 @@ def init_distributed(dp_size, fs_size, tp_size) -> Optional[DeviceMesh]: print0(f"World size: {world_size}") else: - # Use device mesh for distributed training + # Use device mesh for distributed training # Fill None values with 1 dp_size = dp_size if dp_size is not None else 1 fs_size = fs_size if fs_size is not None else 1 tp_size = tp_size if tp_size is not None else 1 - # Check if we have the right number of GPUs total_gpus = dp_size * fs_size * tp_size assert world_size == total_gpus, ( @@ -812,9 +811,9 @@ def get_lr(it): optimizer_name = hp.optimizer if "dion" in hp.optimizer or "dion2" in hp.optimizer: optimizer_name = f"{hp.ortho_fraction}-{hp.optimizer}" - + run_name = f"({optimizer_name}+{hp.scalar_opt})" - + if device_mesh is not None: dp, fs, tp = 
device_mesh.size(0), device_mesh.size(1), device_mesh.size(2) run_name += f"_(dp={dp}, fs={fs}, tp={tp})" From df933c35542dfb102d4f52c58e8e9dc86301a7dc Mon Sep 17 00:00:00 2001 From: Kwangjun Ahn Date: Tue, 23 Dec 2025 04:48:02 -0800 Subject: [PATCH 14/14] comment about L1 norm --- dion/dion2.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dion/dion2.py b/dion/dion2.py index b682626..68cea3b 100644 --- a/dion/dion2.py +++ b/dion/dion2.py @@ -567,7 +567,8 @@ def dion2_pre_orthogonalize( Update momentum with gradient and compute the input to orthogonalization. More specifically, it does the following steps: - updates the momentum with gradient - - computes the top-k indices to determine submatrices + - computes the top-k indices (according to L1 norm) to determine submatrices + - (other norms can be used such as L2 norm) - does in-place error-feedback decay on the selected submatrices - output submatrices and indices Inputs and outputs should be lists of regular Tensor, not DTensor.