3 changes: 2 additions & 1 deletion .gitignore
@@ -7,4 +7,5 @@ wandb/
 output/
 .venv/
 submit.ipynb
-aztool/
+aztool/
+dion.egg-info/
242 changes: 126 additions & 116 deletions README.md

Large diffs are not rendered by default.

21 changes: 9 additions & 12 deletions configs/dion2_160m.yaml
@@ -14,32 +14,29 @@ num_iterations: 3000
warmup_ratio: 0.0
warmdown_ratio: 0.2

# — Optimizer & Hyperparameters —
mu: 0.95
weight_decay: 0.01

# — Validation & Checkpointing —
val_loss_every: 125
val_tokens: 10485760
save_every: 0
val_tokens: 10485760

# — Weights & Biases logging —
wandb_project_name: gpt-train
wandb_job_name: null
no_wandb: false

# — Distributed training —
dp_size: null # data‐parallel size
fs_size: null # FSDP size
tp_size: null # DO NOT USE TP for Dion2

# — Miscellaneous flags —
debug: false
no_compile: false
no_triton: false

# — Distributed training —
dp_size: null # data‐parallel size
fs_size: null # FSDP size

# — Optimizer & Hyperparameters —
optimizer: dion2
rank_fraction: 0.5
scalar_opt: lion
mu: 0.95
weight_decay: 0.01
ortho_fraction: 0.5
adjust_lr: spectral_norm
lr: 0.02
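
These configs are flat YAML key/value files, so they can be inspected outside the training harness. A minimal sketch, assuming only PyYAML (the repo's actual config loader is not part of this diff):

```python
# Minimal sketch: load a flat YAML config such as configs/dion2_160m.yaml.
# Assumes only PyYAML; the training repo's real loader is not shown here.
import yaml

with open("configs/dion2_160m.yaml") as f:
    cfg = yaml.safe_load(f)  # flat dict of hyperparameters

# YAML null becomes Python None, so unset sizes like dp_size read back as None.
print(cfg["optimizer"], cfg["lr"], cfg["rank_fraction"])
assert cfg.get("tp_size") is None, "TP must stay disabled for Dion2 (per the config comment)"
```
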
21 changes: 9 additions & 12 deletions configs/dion_160m.yaml
@@ -14,33 +14,30 @@ num_iterations: 3000
warmup_ratio: 0.0
warmdown_ratio: 0.2

# — Optimizer & Hyperparameters —
rank_fraction: 0.125
mu: 0.95
weight_decay: 0.01
oversample: 1.25

# — Validation & Checkpointing —
val_loss_every: 125
val_tokens: 10485760
save_every: 0
val_tokens: 10485760

# — Weights & Biases logging —
wandb_project_name: gpt-train
wandb_job_name: null
no_wandb: false

# — Miscellaneous flags —
debug: false
no_compile: false

# — Distributed training —
dp_size: null # data‐parallel size
fs_size: null # FSDP size
tp_size: null # tensor‐parallel size
replicate_mesh_grad_sync: true

# — Miscellaneous flags —
debug: false
no_compile: false

# — Optimizer & Hyperparameters —
optimizer: dion
scalar_opt: lion
mu: 0.95
weight_decay: 0.01
ortho_fraction: 0.5
lr: 0.02
mixed_precision: true
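
For a sense of scale, rank_fraction: 0.125 keeps a low-rank momentum approximation whose rank is a fraction of the smaller dimension of each weight matrix. A hedged sketch of the arithmetic, assuming a hidden size of 768 for the 160M model and a simple rounding rule (neither is stated in this diff):

```python
# Hedged sketch: convert rank_fraction into an integer rank per weight matrix.
# The repo's exact rounding rule and the 768 hidden size are assumptions here.
def low_rank_dim(shape: tuple[int, int], rank_fraction: float) -> int:
    return max(1, round(rank_fraction * min(shape)))

print(low_rank_dim((768, 3072), 0.125))  # -> 96 for a hypothetical MLP weight
print(low_rank_dim((768, 768), 0.125))   # -> 96 for a square attention weight
```
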
13 changes: 5 additions & 8 deletions configs/muon_160m.yaml
@@ -14,14 +14,9 @@ num_iterations: 3000
warmup_ratio: 0.0
warmdown_ratio: 0.2

# — Optimizer & Hyperparameters —
mu: 0.95
weight_decay: 0.01

# — Validation & Checkpointing —
val_loss_every: 125
val_tokens: 10485760
save_every: 0
val_tokens: 10485760

# — Weights & Biases logging —
wandb_project_name: gpt-train
@@ -30,15 +25,17 @@ no_wandb: false

# — Distributed training —
dp_size: null # data‐parallel size
fs_size: null # FSDP size
tp_size: null # DO NOT USE TP for Muon
fs_size: null # FSDP size

# — Miscellaneous flags —
debug: false
no_compile: false
no_triton: false

# — Optimizer & Hyperparameters —
optimizer: muon
scalar_opt: lion
adjust_lr: spectral_norm
mu: 0.95
weight_decay: 0.01
lr: 0.02
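
The adjust_lr: spectral_norm flag points at the per-matrix learning-rate correction used by Muon-style optimizers, where differently shaped layers need differently scaled steps. One common convention scales the base lr by sqrt(max(1, fan_out / fan_in)); whether that is exactly the rule selected here is an assumption:

```python
import math

# Hedged sketch of a common Muon-style LR adjustment; whether
# `adjust_lr: spectral_norm` selects exactly this rule is an assumption.
def adjusted_lr(base_lr: float, fan_out: int, fan_in: int) -> float:
    return base_lr * math.sqrt(max(1.0, fan_out / fan_in))

print(adjusted_lr(0.02, 3072, 768))  # taller matrices get a larger step
```
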
13 changes: 5 additions & 8 deletions configs/normuon_160m.yaml
@@ -14,14 +14,9 @@ num_iterations: 3000
warmup_ratio: 0.0
warmdown_ratio: 0.2

# — Optimizer & Hyperparameters —
mu: 0.95
weight_decay: 0.01

# — Validation & Checkpointing —
val_loss_every: 125
val_tokens: 10485760
save_every: 0
val_tokens: 10485760

# — Weights & Biases logging —
wandb_project_name: gpt-train
@@ -30,15 +25,17 @@ no_wandb: false

# — Distributed training —
dp_size: null # data‐parallel size
fs_size: null # FSDP size
tp_size: null # DO NOT USE TP for NorMuon
fs_size: null # FSDP size

# — Miscellaneous flags —
debug: false
no_compile: false
no_triton: false

# — Optimizer & Hyperparameters —
optimizer: normuon
scalar_opt: lion
adjust_lr: spectral_norm
lr: 0.02
mu: 0.95
weight_decay: 0.01
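
Taken together, the four configs share one invariant called out in their comments: tp_size must remain null for Dion2, Muon, and NorMuon, while only the plain Dion config leaves tensor parallelism available. A small sketch of how one might lint that, assuming only PyYAML and the configs/ layout shown above:

```python
# Hedged sketch: lint the invariant spelled out in the config comments,
# i.e. tp_size must stay null for Dion2, Muon, and NorMuon.
import glob
import yaml

NO_TP = {"dion2", "muon", "normuon"}

for path in sorted(glob.glob("configs/*_160m.yaml")):
    with open(path) as f:
        cfg = yaml.safe_load(f)
    if cfg.get("optimizer") in NO_TP and cfg.get("tp_size") is not None:
        raise ValueError(f"{path}: tp_size must be null for {cfg['optimizer']}")
    print(path, "ok")
```
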