From 93b9ef720d72563767b5b6c921093b82d061543d Mon Sep 17 00:00:00 2001
From: Wenxuan Li
Date: Mon, 9 Jun 2025 10:26:25 +0000
Subject: [PATCH 01/12] Init mtraining folder

---
 .../Phi-3-mini-4k-instruct-LongRoPE-128k.json | 6210 +++++++++++++++++
 minference/configs/Qwen2.5_3B_flex_0.90.json | 616 ++
 minference/configs/Qwen2.5_3B_flex_0.95.json | 616 ++
 ...n2.5_3B_kv_out_v32_fit_o_best_pattern.json | 3530 ++++++++++
 minference/dist_ops/__init__.py | 0
 minference/dist_ops/dr_striped_attention.py | 584 ++
 .../dist_ops/minfer_dr_stripe_triton.py | 404 ++
 minference/dist_ops/minfer_dr_striped.py | 471 ++
 minference/dist_ops/minfer_striped.py | 411 ++
 minference/dist_ops/minfer_striped_triton.py | 284 +
 minference/dist_ops/minfer_zigzag.py | 331 +
 minference/dist_ops/moba_zigzag.py | 1066 +++
 minference/dist_ops/op_utils/__init__.py | 0
 minference/dist_ops/op_utils/moba_utils.py | 636 ++
 minference/dist_ops/op_utils/xattn_utils.py | 521 ++
 minference/dist_ops/ring_attention.py | 267 +
 minference/dist_ops/striped_attention.py | 404 ++
 minference/dist_ops/utils.py | 523 ++
 minference/dist_ops/xattn_zigzag.py | 562 ++
 minference/dist_ops/zigzag_attention.py | 412 ++
 minference/ops/minference_attn.py | 881 +++
 minference/ops/minference_attn_triton.py | 1230 ++++
 minference/ops/utils.py | 932 +++
 mtraining/.gitignore | 12 +
 mtraining/README.md | 1 +
 mtraining/__init__.py | 0
 mtraining/models/__init__.py | 29 +
 .../models/active_param_configs/attn_only.txt | 1 +
 .../active_param_configs/qk_proj_only.txt | 2 +
 mtraining/models/phi3/__init__.py | 5 +
 mtraining/models/phi3/configuration_phi3.py | 227 +
 .../phi3/lc_config/configuration_phi3.py | 227 +
 .../phi3/lc_config_mini/configuration_phi3.py | 227 +
 mtraining/models/phi3/modelling_phi.py | 1185 ++++
 mtraining/models/phi3/modelling_phi_legacy.py | 1568 +++++
 mtraining/models/qwen2/__init__.py | 7 +
 mtraining/models/qwen2/configuration_qwen2.py | 196 +
 .../qwen2/lc_config/configuration_qwen2.py | 196 +
 .../lc_config_mini/configuration_qwen2.py | 196 +
 .../qwen2/mi_config/configuration_qwen2.py | 185 +
 .../models/qwen2/mi_config/modeling_qwen2.py | 1490 ++++
 mtraining/models/qwen2/modeling_qwen2.py | 1136 +++
 mtraining/models/qwen2/vllm_sparse_qwen2.py | 465 ++
 mtraining/models/sparse_ops/.gitignore | 7 +
 .../mtraining_sparse_ops/__init__.py | 2 +
 .../mtraining_sparse_ops/minference_config.py | 23 +
 mtraining/models/sparse_ops/setup.py | 27 +
 mtraining/requirements.txt | 18 +
 mtraining/setup.py | 15 +
 mtraining/setup.sh | 34 +
 mtraining/train.py | 625 ++
 mtraining/trainer.py | 678 ++
 mtraining/utils/__init__.py | 2 +
 mtraining/utils/custom_parallel.py | 467 ++
 mtraining/utils/data_utils/__init__.py | 0
 mtraining/utils/data_utils/bookcorpus.py | 70 +
 mtraining/utils/general.py | 137 +
 mtraining/utils/loss.py | 68 +
 mtraining/utils/paths.py | 33 +
 59 files changed, 30452 insertions(+)
 create mode 100644 minference/configs/Phi-3-mini-4k-instruct-LongRoPE-128k.json
 create mode 100644 minference/configs/Qwen2.5_3B_flex_0.90.json
 create mode 100644 minference/configs/Qwen2.5_3B_flex_0.95.json
 create mode 100644 minference/configs/Qwen2.5_3B_kv_out_v32_fit_o_best_pattern.json
 create mode 100644 minference/dist_ops/__init__.py
 create mode 100644 minference/dist_ops/dr_striped_attention.py
 create mode 100644 minference/dist_ops/minfer_dr_stripe_triton.py
 create mode 100644 minference/dist_ops/minfer_dr_striped.py
 create mode 100644 minference/dist_ops/minfer_striped.py
 create mode 100644 minference/dist_ops/minfer_striped_triton.py
 create mode 100644 minference/dist_ops/minfer_zigzag.py
 create mode 100644 minference/dist_ops/moba_zigzag.py
 create mode 100644 minference/dist_ops/op_utils/__init__.py
 create mode 100644 minference/dist_ops/op_utils/moba_utils.py
 create mode 100644 minference/dist_ops/op_utils/xattn_utils.py
 create mode 100644 minference/dist_ops/ring_attention.py
 create mode 100644 minference/dist_ops/striped_attention.py
 create mode 100644 minference/dist_ops/utils.py
 create mode 100644 minference/dist_ops/xattn_zigzag.py
 create mode 100644 minference/dist_ops/zigzag_attention.py
 create mode 100644 minference/ops/minference_attn.py
 create mode 100644 minference/ops/minference_attn_triton.py
 create mode 100644 minference/ops/utils.py
 create mode 100644 mtraining/.gitignore
 create mode 100644 mtraining/README.md
 create mode 100644 mtraining/__init__.py
 create mode 100644 mtraining/models/__init__.py
 create mode 100644 mtraining/models/active_param_configs/attn_only.txt
 create mode 100644 mtraining/models/active_param_configs/qk_proj_only.txt
 create mode 100644 mtraining/models/phi3/__init__.py
 create mode 100644 mtraining/models/phi3/configuration_phi3.py
 create mode 100644 mtraining/models/phi3/lc_config/configuration_phi3.py
 create mode 100644 mtraining/models/phi3/lc_config_mini/configuration_phi3.py
 create mode 100644 mtraining/models/phi3/modelling_phi.py
 create mode 100644 mtraining/models/phi3/modelling_phi_legacy.py
 create mode 100644 mtraining/models/qwen2/__init__.py
 create mode 100644 mtraining/models/qwen2/configuration_qwen2.py
 create mode 100644 mtraining/models/qwen2/lc_config/configuration_qwen2.py
 create mode 100644 mtraining/models/qwen2/lc_config_mini/configuration_qwen2.py
 create mode 100644 mtraining/models/qwen2/mi_config/configuration_qwen2.py
 create mode 100644 mtraining/models/qwen2/mi_config/modeling_qwen2.py
 create mode 100644 mtraining/models/qwen2/modeling_qwen2.py
 create mode 100644 mtraining/models/qwen2/vllm_sparse_qwen2.py
 create mode 100644 mtraining/models/sparse_ops/.gitignore
 create mode 100644 mtraining/models/sparse_ops/mtraining_sparse_ops/__init__.py
 create mode 100644 mtraining/models/sparse_ops/mtraining_sparse_ops/minference_config.py
 create mode 100644 mtraining/models/sparse_ops/setup.py
 create mode 100644 mtraining/requirements.txt
 create mode 100644 mtraining/setup.py
 create mode 100755 mtraining/setup.sh
 create mode 100644 mtraining/train.py
 create mode 100644 mtraining/trainer.py
 create mode 100644 mtraining/utils/__init__.py
 create mode 100644 mtraining/utils/custom_parallel.py
 create mode 100644 mtraining/utils/data_utils/__init__.py
 create mode 100644 mtraining/utils/data_utils/bookcorpus.py
 create mode 100644 mtraining/utils/general.py
 create mode 100644 mtraining/utils/loss.py
 create mode 100644 mtraining/utils/paths.py

diff --git a/minference/configs/Phi-3-mini-4k-instruct-LongRoPE-128k.json b/minference/configs/Phi-3-mini-4k-instruct-LongRoPE-128k.json
new file mode 100644
index 0000000..deab2fa
--- /dev/null
+++ b/minference/configs/Phi-3-mini-4k-instruct-LongRoPE-128k.json
@@ -0,0 +1,6210 @@
+[ + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.3332791030406952 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.4416683614253998 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.4923180937767029 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.3880477547645569 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.38015398383140564 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.34974828362464905 + ], + "6": [ + 
"vertical_and_slash", + 1000, + 6096, + 0.4002125859260559 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.8959001898765564 + ], + "8": [ + "vertical_and_slash", + 1000, + 6096, + 0.44538143277168274 + ], + "9": [ + "vertical_and_slash", + 1000, + 6096, + 0.364785760641098 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.4276016652584076 + ], + "11": [ + "vertical_and_slash", + 1000, + 6096, + 0.3688332438468933 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.46328139305114746 + ], + "13": [ + "vertical_and_slash", + 1000, + 6096, + 0.3348214626312256 + ], + "14": [ + "vertical_and_slash", + 1000, + 6096, + 0.392171710729599 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.34772568941116333 + ], + "16": [ + "vertical_and_slash", + 1000, + 6096, + 0.4066217541694641 + ], + "17": [ + "vertical_and_slash", + 1000, + 6096, + 0.360747754573822 + ], + "18": [ + "vertical_and_slash", + 1000, + 6096, + 0.34830111265182495 + ], + "19": [ + "vertical_and_slash", + 1000, + 6096, + 0.47614070773124695 + ], + "20": [ + "vertical_and_slash", + 1000, + 6096, + 0.47739607095718384 + ], + "21": [ + "vertical_and_slash", + 1000, + 6096, + 0.376796156167984 + ], + "22": [ + "vertical_and_slash", + 1000, + 6096, + 0.5052061080932617 + ], + "23": [ + "vertical_and_slash", + 1000, + 6096, + 0.7984429001808167 + ], + "24": [ + "vertical_and_slash", + 1000, + 6096, + 0.4089720547199249 + ], + "25": [ + "vertical_and_slash", + 1000, + 6096, + 0.512876033782959 + ], + "26": [ + "vertical_and_slash", + 1000, + 6096, + 0.45735129714012146 + ], + "27": [ + "vertical_and_slash", + 1000, + 6096, + 0.8220791220664978 + ], + "28": [ + "vertical_and_slash", + 1000, + 6096, + 0.9416881203651428 + ], + "29": [ + "vertical_and_slash", + 1000, + 6096, + 0.4253023564815521 + ], + "30": [ + "vertical_and_slash", + 1000, + 6096, + 0.6658170819282532 + ], + "31": [ + "vertical_and_slash", + 1000, + 6096, + 0.5853910446166992 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.5987299084663391 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.5018280148506165 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.4191092550754547 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.5811007618904114 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.5408463478088379 + ], + "5": [ + "vertical_and_slash", + 3500, + 100, + 0.47298333048820496 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.786112368106842 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.6012060642242432 + ], + "8": [ + "vertical_and_slash", + 1000, + 6096, + 0.5791159272193909 + ], + "9": [ + "vertical_and_slash", + 1000, + 6096, + 0.9051483869552612 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.58222496509552 + ], + "11": [ + "vertical_and_slash", + 1000, + 6096, + 0.48059725761413574 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.7023431062698364 + ], + "13": [ + "vertical_and_slash", + 1000, + 6096, + 0.43349939584732056 + ], + "14": [ + "vertical_and_slash", + 1000, + 6096, + 0.5206228494644165 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.6523349285125732 + ], + "16": [ + "vertical_and_slash", + 1000, + 6096, + 0.605652391910553 + ], + "17": [ + "vertical_and_slash", + 1000, + 6096, + 0.4064832925796509 + ], + "18": [ + "vertical_and_slash", + 1000, + 6096, + 0.5884730219841003 + ], + "19": [ + "vertical_and_slash", + 1000, + 6096, + 0.8767980337142944 + ], + "20": [ + "vertical_and_slash", + 1000, + 6096, + 0.9045116305351257 + ], 
+ "21": [ + "vertical_and_slash", + 1000, + 6096, + 0.7437348365783691 + ], + "22": [ + "vertical_and_slash", + 1000, + 6096, + 0.48772233724594116 + ], + "23": [ + "vertical_and_slash", + 1000, + 6096, + 0.46409618854522705 + ], + "24": [ + "vertical_and_slash", + 1000, + 6096, + 0.49005016684532166 + ], + "25": [ + "vertical_and_slash", + 1000, + 6096, + 0.6720733046531677 + ], + "26": [ + "vertical_and_slash", + 1000, + 6096, + 0.9864497184753418 + ], + "27": [ + "vertical_and_slash", + 1000, + 6096, + 0.8722768425941467 + ], + "28": [ + "vertical_and_slash", + 1000, + 6096, + 0.4877939820289612 + ], + "29": [ + "vertical_and_slash", + 1000, + 6096, + 0.575534462928772 + ], + "30": [ + "vertical_and_slash", + 1000, + 6096, + 0.42707663774490356 + ], + "31": [ + "vertical_and_slash", + 1000, + 6096, + 0.6121442317962646 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.7372177839279175 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.7894041538238525 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.5279066562652588 + ], + "3": [ + "vertical_and_slash", + 3500, + 100, + 0.4761664569377899 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.4725611209869385 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.8257285952568054 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.5990859866142273 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.6145595908164978 + ], + "8": [ + "vertical_and_slash", + 1000, + 6096, + 0.7928207516670227 + ], + "9": [ + "vertical_and_slash", + 1000, + 6096, + 0.707308292388916 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.7120367884635925 + ], + "11": [ + "vertical_and_slash", + 1000, + 6096, + 0.687991201877594 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.673988401889801 + ], + "13": [ + "vertical_and_slash", + 1000, + 6096, + 0.8554307222366333 + ], + "14": [ + "vertical_and_slash", + 1000, + 6096, + 0.755358874797821 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.7254285216331482 + ], + "16": [ + "vertical_and_slash", + 1000, + 6096, + 0.7025570273399353 + ], + "17": [ + "vertical_and_slash", + 1000, + 6096, + 0.9539948105812073 + ], + "18": [ + "vertical_and_slash", + 1000, + 6096, + 0.6157228946685791 + ], + "19": [ + "vertical_and_slash", + 1000, + 6096, + 0.733711838722229 + ], + "20": [ + "vertical_and_slash", + 1000, + 6096, + 0.5779848694801331 + ], + "21": [ + "vertical_and_slash", + 1000, + 6096, + 0.6323117017745972 + ], + "22": [ + "vertical_and_slash", + 1000, + 6096, + 0.5607789754867554 + ], + "23": [ + "vertical_and_slash", + 1000, + 6096, + 0.6604608297348022 + ], + "24": [ + "vertical_and_slash", + 1000, + 6096, + 0.8311918377876282 + ], + "25": [ + "vertical_and_slash", + 1000, + 6096, + 0.5358595848083496 + ], + "26": [ + "vertical_and_slash", + 1000, + 6096, + 0.8005751967430115 + ], + "27": [ + "vertical_and_slash", + 1000, + 6096, + 0.5391202569007874 + ], + "28": [ + "vertical_and_slash", + 1000, + 6096, + 0.7308750152587891 + ], + "29": [ + "vertical_and_slash", + 1000, + 6096, + 0.6740477085113525 + ], + "30": [ + "vertical_and_slash", + 1000, + 6096, + 0.5843774676322937 + ], + "31": [ + "vertical_and_slash", + 1000, + 6096, + 0.9640750885009766 + ] + }, + { + "0": [ + "vertical_and_slash", + 100, + 800, + 1.0 + ], + "1": [ + "vertical_and_slash", + 100, + 800, + 1.0 + ], + "2": [ + "vertical_and_slash", + 3500, + 100, + 0.9977744221687317 + ], + "3": [ + "vertical_and_slash", + 100, + 800, + 1.0 + ], + "4": [ + "vertical_and_slash", + 
100, + 800, + 1.0 + ], + "5": [ + "vertical_and_slash", + 30, + 800, + 0.997894287109375 + ], + "6": [ + "vertical_and_slash", + 100, + 800, + 1.0 + ], + "7": [ + "vertical_and_slash", + 500, + 700, + 0.9970235824584961 + ], + "8": [ + "vertical_and_slash", + 100, + 800, + 0.9453125 + ], + "9": [ + "vertical_and_slash", + 30, + 800, + 0.9917697310447693 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.9911014437675476 + ], + "11": [ + "vertical_and_slash", + 30, + 800, + 0.9628551602363586 + ], + "12": [ + "vertical_and_slash", + 30, + 800, + 0.9096153378486633 + ], + "13": [ + "vertical_and_slash", + 30, + 800, + 0.9951955080032349 + ], + "14": [ + "vertical_and_slash", + 500, + 700, + 0.9851846098899841 + ], + "15": [ + "vertical_and_slash", + 30, + 800, + 0.996168315410614 + ], + "16": [ + "vertical_and_slash", + 100, + 800, + 0.96484375 + ], + "17": [ + "vertical_and_slash", + 1000, + 6096, + 0.9974663853645325 + ], + "18": [ + "vertical_and_slash", + 1000, + 6096, + 0.995905339717865 + ], + "19": [ + "vertical_and_slash", + 30, + 800, + 0.998296320438385 + ], + "20": [ + "vertical_and_slash", + 500, + 700, + 0.9968391060829163 + ], + "21": [ + "vertical_and_slash", + 3500, + 100, + 0.9973934888839722 + ], + "22": [ + "vertical_and_slash", + 30, + 800, + 0.998639702796936 + ], + "23": [ + "vertical_and_slash", + 30, + 800, + 0.9961214661598206 + ], + "24": [ + "vertical_and_slash", + 100, + 800, + 0.97265625 + ], + "25": [ + "vertical_and_slash", + 30, + 800, + 0.9963303804397583 + ], + "26": [ + "vertical_and_slash", + 100, + 800, + 0.96484375 + ], + "27": [ + "vertical_and_slash", + 100, + 800, + 1.0 + ], + "28": [ + "vertical_and_slash", + 30, + 800, + 0.9944915175437927 + ], + "29": [ + "vertical_and_slash", + 100, + 800, + 0.9453125 + ], + "30": [ + "vertical_and_slash", + 3500, + 100, + 0.7727152705192566 + ], + "31": [ + "vertical_and_slash", + 3500, + 100, + 0.9912976622581482 + ] + }, + { + "0": [ + "vertical_and_slash", + 100, + 800, + 0.9765625 + ], + "1": [ + "vertical_and_slash", + 3500, + 100, + 0.9993456602096558 + ], + "2": [ + "vertical_and_slash", + 3500, + 100, + 1.0000362396240234 + ], + "3": [ + "vertical_and_slash", + 500, + 700, + 0.9966914057731628 + ], + "4": [ + "vertical_and_slash", + 30, + 800, + 0.8796641826629639 + ], + "5": [ + "vertical_and_slash", + 100, + 800, + 1.0 + ], + "6": [ + "vertical_and_slash", + 100, + 800, + 1.0 + ], + "7": [ + "vertical_and_slash", + 3500, + 100, + 0.9982466697692871 + ], + "8": [ + "vertical_and_slash", + 100, + 800, + 0.96875 + ], + "9": [ + "vertical_and_slash", + 100, + 800, + 1.0 + ], + "10": [ + "vertical_and_slash", + 30, + 800, + 0.9937134385108948 + ], + "11": [ + "vertical_and_slash", + 100, + 800, + 1.0 + ], + "12": [ + "vertical_and_slash", + 3500, + 100, + 0.9967017769813538 + ], + "13": [ + "vertical_and_slash", + 3500, + 100, + 0.9959616661071777 + ], + "14": [ + "vertical_and_slash", + 30, + 800, + 0.9792157411575317 + ], + "15": [ + "vertical_and_slash", + 100, + 800, + 1.0 + ], + "16": [ + "vertical_and_slash", + 3500, + 100, + 0.9846665859222412 + ], + "17": [ + "vertical_and_slash", + 1000, + 6096, + 0.9296243190765381 + ], + "18": [ + "vertical_and_slash", + 30, + 800, + 0.9900624752044678 + ], + "19": [ + "vertical_and_slash", + 100, + 800, + 1.0 + ], + "20": [ + "vertical_and_slash", + 30, + 800, + 0.9983267188072205 + ], + "21": [ + "vertical_and_slash", + 100, + 800, + 1.0 + ], + "22": [ + "vertical_and_slash", + 1000, + 6096, + 0.9900224804878235 + ], + "23": [ + "vertical_and_slash", + 
3500, + 100, + 0.997448205947876 + ], + "24": [ + "vertical_and_slash", + 30, + 800, + 0.9974275827407837 + ], + "25": [ + "vertical_and_slash", + 1000, + 6096, + 0.9921515583992004 + ], + "26": [ + "vertical_and_slash", + 30, + 800, + 0.9961837530136108 + ], + "27": [ + "vertical_and_slash", + 100, + 800, + 0.99609375 + ], + "28": [ + "vertical_and_slash", + 3500, + 100, + 0.9986920952796936 + ], + "29": [ + "vertical_and_slash", + 3500, + 100, + 0.9839296936988831 + ], + "30": [ + "vertical_and_slash", + 3500, + 100, + 0.9926077127456665 + ], + "31": [ + "vertical_and_slash", + 100, + 800, + 1.0 + ] + }, + { + "0": [ + "vertical_and_slash", + 100, + 800, + 0.99609375 + ], + "1": [ + "vertical_and_slash", + 500, + 700, + 0.9920685291290283 + ], + "2": [ + "vertical_and_slash", + 500, + 700, + 0.9854402542114258 + ], + "3": [ + "vertical_and_slash", + 3500, + 100, + 0.9949761033058167 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.8777214288711548 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.9941083192825317 + ], + "6": [ + "vertical_and_slash", + 500, + 700, + 0.9981333613395691 + ], + "7": [ + "vertical_and_slash", + 500, + 700, + 0.9964705109596252 + ], + "8": [ + "vertical_and_slash", + 3500, + 100, + 0.9998161196708679 + ], + "9": [ + "vertical_and_slash", + 3500, + 100, + 0.9989191293716431 + ], + "10": [ + "vertical_and_slash", + 3500, + 100, + 0.9850229620933533 + ], + "11": [ + "vertical_and_slash", + 500, + 700, + 0.9929664731025696 + ], + "12": [ + "vertical_and_slash", + 500, + 700, + 0.9970735311508179 + ], + "13": [ + "vertical_and_slash", + 1000, + 6096, + 0.9854613542556763 + ], + "14": [ + "vertical_and_slash", + 100, + 800, + 1.0 + ], + "15": [ + "vertical_and_slash", + 3500, + 100, + 0.998539924621582 + ], + "16": [ + "vertical_and_slash", + 3500, + 100, + 0.9988572597503662 + ], + "17": [ + "vertical_and_slash", + 100, + 800, + 1.0 + ], + "18": [ + "vertical_and_slash", + 100, + 800, + 1.0 + ], + "19": [ + "vertical_and_slash", + 3500, + 100, + 0.9937215447425842 + ], + "20": [ + "vertical_and_slash", + 3500, + 100, + 0.9997386336326599 + ], + "21": [ + "vertical_and_slash", + 500, + 700, + 0.9956527352333069 + ], + "22": [ + "vertical_and_slash", + 100, + 800, + 1.0 + ], + "23": [ + "vertical_and_slash", + 500, + 700, + 0.9980509877204895 + ], + "24": [ + "vertical_and_slash", + 500, + 700, + 0.9594733119010925 + ], + "25": [ + "vertical_and_slash", + 500, + 700, + 0.9973872900009155 + ], + "26": [ + "vertical_and_slash", + 100, + 800, + 1.0 + ], + "27": [ + "vertical_and_slash", + 1000, + 6096, + 0.8722200989723206 + ], + "28": [ + "vertical_and_slash", + 3500, + 100, + 0.9793213605880737 + ], + "29": [ + "vertical_and_slash", + 500, + 700, + 0.9947999119758606 + ], + "30": [ + "vertical_and_slash", + 500, + 700, + 0.999260425567627 + ], + "31": [ + "vertical_and_slash", + 3500, + 100, + 0.991969883441925 + ] + }, + { + "0": [ + "vertical_and_slash", + 30, + 800, + 0.9714643955230713 + ], + "1": [ + "vertical_and_slash", + 500, + 700, + 0.9889455437660217 + ], + "2": [ + "vertical_and_slash", + 3500, + 100, + 0.9993591904640198 + ], + "3": [ + "vertical_and_slash", + 100, + 800, + 1.0 + ], + "4": [ + "vertical_and_slash", + 500, + 700, + 0.9979987740516663 + ], + "5": [ + "vertical_and_slash", + 3500, + 100, + 0.9925215244293213 + ], + "6": [ + "vertical_and_slash", + 30, + 800, + 0.9974944591522217 + ], + "7": [ + "vertical_and_slash", + 100, + 800, + 1.0 + ], + "8": [ + "vertical_and_slash", + 1000, + 6096, + 0.733367919921875 + ], + "9": [ + 
"vertical_and_slash", + 500, + 700, + 0.9970583319664001 + ], + "10": [ + "vertical_and_slash", + 30, + 800, + 0.997306227684021 + ], + "11": [ + "vertical_and_slash", + 100, + 800, + 1.0 + ], + "12": [ + "vertical_and_slash", + 100, + 800, + 0.99609375 + ], + "13": [ + "vertical_and_slash", + 3500, + 100, + 0.9741608500480652 + ], + "14": [ + "vertical_and_slash", + 1000, + 6096, + 0.9867183566093445 + ], + "15": [ + "vertical_and_slash", + 500, + 700, + 0.9974902868270874 + ], + "16": [ + "vertical_and_slash", + 1000, + 6096, + 0.8923004865646362 + ], + "17": [ + "vertical_and_slash", + 30, + 800, + 0.9976084232330322 + ], + "18": [ + "vertical_and_slash", + 3500, + 100, + 0.9930179119110107 + ], + "19": [ + "vertical_and_slash", + 3500, + 100, + 0.9756211638450623 + ], + "20": [ + "vertical_and_slash", + 3500, + 100, + 0.9988991022109985 + ], + "21": [ + "vertical_and_slash", + 3500, + 100, + 0.9600447416305542 + ], + "22": [ + "vertical_and_slash", + 500, + 700, + 0.9966569542884827 + ], + "23": [ + "vertical_and_slash", + 1000, + 6096, + 0.9705502986907959 + ], + "24": [ + "vertical_and_slash", + 3500, + 100, + 0.996631383895874 + ], + "25": [ + "vertical_and_slash", + 1000, + 6096, + 0.7431671023368835 + ], + "26": [ + "vertical_and_slash", + 3500, + 100, + 0.9833155274391174 + ], + "27": [ + "vertical_and_slash", + 3500, + 100, + 0.995357096195221 + ], + "28": [ + "vertical_and_slash", + 3500, + 100, + 0.9988372921943665 + ], + "29": [ + "vertical_and_slash", + 1000, + 6096, + 0.6221311688423157 + ], + "30": [ + "vertical_and_slash", + 1000, + 6096, + 0.962378740310669 + ], + "31": [ + "vertical_and_slash", + 100, + 800, + 1.0 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.8013869524002075 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.9569835662841797 + ], + "2": [ + "vertical_and_slash", + 30, + 800, + 0.9765567779541016 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.9659812450408936 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.893607497215271 + ], + "5": [ + "vertical_and_slash", + 3500, + 100, + 0.9978155493736267 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.9035964608192444 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.9964177012443542 + ], + "8": [ + "vertical_and_slash", + 500, + 700, + 0.9885393381118774 + ], + "9": [ + "vertical_and_slash", + 3500, + 100, + 0.9989112019538879 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.9980699419975281 + ], + "11": [ + "vertical_and_slash", + 1000, + 6096, + 0.9978247284889221 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.9864377379417419 + ], + "13": [ + "vertical_and_slash", + 30, + 800, + 0.9934346675872803 + ], + "14": [ + "vertical_and_slash", + 1000, + 6096, + 0.8716491460800171 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.9965083003044128 + ], + "16": [ + "vertical_and_slash", + 1000, + 6096, + 0.9953898191452026 + ], + "17": [ + "vertical_and_slash", + 30, + 800, + 0.906114399433136 + ], + "18": [ + "vertical_and_slash", + 1000, + 6096, + 0.9365297555923462 + ], + "19": [ + "vertical_and_slash", + 30, + 800, + 0.9918379783630371 + ], + "20": [ + "vertical_and_slash", + 500, + 700, + 0.998047411441803 + ], + "21": [ + "vertical_and_slash", + 500, + 700, + 0.9964088797569275 + ], + "22": [ + "vertical_and_slash", + 1000, + 6096, + 0.9905833601951599 + ], + "23": [ + "vertical_and_slash", + 500, + 700, + 0.9906750917434692 + ], + "24": [ + "vertical_and_slash", + 100, + 800, + 1.0 + ], + "25": [ + 
"vertical_and_slash", + 500, + 700, + 0.9951276183128357 + ], + "26": [ + "vertical_and_slash", + 100, + 800, + 0.99609375 + ], + "27": [ + "vertical_and_slash", + 500, + 700, + 0.9964181780815125 + ], + "28": [ + "vertical_and_slash", + 1000, + 6096, + 0.9454283118247986 + ], + "29": [ + "vertical_and_slash", + 500, + 700, + 0.98893141746521 + ], + "30": [ + "vertical_and_slash", + 3500, + 100, + 0.9951130747795105 + ], + "31": [ + "vertical_and_slash", + 1000, + 6096, + 0.997049868106842 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.706658124923706 + ], + "1": [ + "vertical_and_slash", + 100, + 750, + 0.9980204701423645 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.6844134330749512 + ], + "3": [ + "vertical_and_slash", + 500, + 700, + 0.9990592002868652 + ], + "4": [ + "vertical_and_slash", + 3500, + 100, + 0.8842496871948242 + ], + "5": [ + "vertical_and_slash", + 3500, + 100, + 0.9998780488967896 + ], + "6": [ + "vertical_and_slash", + 500, + 700, + 0.9981557130813599 + ], + "7": [ + "vertical_and_slash", + 500, + 700, + 0.9960853457450867 + ], + "8": [ + "vertical_and_slash", + 3500, + 100, + 0.9998636841773987 + ], + "9": [ + "vertical_and_slash", + 3500, + 100, + 0.9711626172065735 + ], + "10": [ + "vertical_and_slash", + 500, + 700, + 0.9983590245246887 + ], + "11": [ + "vertical_and_slash", + 3500, + 100, + 0.9998638033866882 + ], + "12": [ + "vertical_and_slash", + 100, + 750, + 0.8798794150352478 + ], + "13": [ + "vertical_and_slash", + 3500, + 100, + 0.9990912079811096 + ], + "14": [ + "vertical_and_slash", + 3500, + 100, + 0.9776530265808105 + ], + "15": [ + "vertical_and_slash", + 3500, + 100, + 0.9992070198059082 + ], + "16": [ + "vertical_and_slash", + 3500, + 100, + 0.999956488609314 + ], + "17": [ + "vertical_and_slash", + 3500, + 100, + 0.9974350929260254 + ], + "18": [ + "vertical_and_slash", + 3500, + 100, + 0.9929025173187256 + ], + "19": [ + "vertical_and_slash", + 3500, + 100, + 0.9999449253082275 + ], + "20": [ + "vertical_and_slash", + 3500, + 100, + 0.9973562359809875 + ], + "21": [ + "vertical_and_slash", + 3500, + 100, + 0.9999447464942932 + ], + "22": [ + "vertical_and_slash", + 500, + 700, + 0.9998283982276917 + ], + "23": [ + "vertical_and_slash", + 500, + 700, + 0.9996301531791687 + ], + "24": [ + "vertical_and_slash", + 500, + 700, + 0.99957674741745 + ], + "25": [ + "vertical_and_slash", + 1000, + 6096, + 0.8067362904548645 + ], + "26": [ + "vertical_and_slash", + 3500, + 100, + 0.9952501654624939 + ], + "27": [ + "vertical_and_slash", + 500, + 700, + 0.987522304058075 + ], + "28": [ + "vertical_and_slash", + 500, + 700, + 0.9994516968727112 + ], + "29": [ + "vertical_and_slash", + 500, + 700, + 0.9991733431816101 + ], + "30": [ + "vertical_and_slash", + 500, + 700, + 0.9994316697120667 + ], + "31": [ + "vertical_and_slash", + 500, + 700, + 0.9992942214012146 + ] + }, + { + "0": [ + "vertical_and_slash", + 500, + 700, + 0.9960434436798096 + ], + "1": [ + "vertical_and_slash", + 500, + 700, + 0.9954621195793152 + ], + "2": [ + "vertical_and_slash", + 3500, + 100, + 0.9846372604370117 + ], + "3": [ + "vertical_and_slash", + 3500, + 100, + 0.866407036781311 + ], + "4": [ + "vertical_and_slash", + 500, + 700, + 0.9988003373146057 + ], + "5": [ + "vertical_and_slash", + 500, + 700, + 0.9943948984146118 + ], + "6": [ + "vertical_and_slash", + 500, + 700, + 0.9850127696990967 + ], + "7": [ + "vertical_and_slash", + 3500, + 100, + 0.9895413517951965 + ], + "8": [ + "vertical_and_slash", + 3500, + 100, + 0.8989573121070862 + ], + 
"9": [ + "vertical_and_slash", + 500, + 700, + 0.9728281497955322 + ], + "10": [ + "vertical_and_slash", + 3500, + 100, + 0.9660746455192566 + ], + "11": [ + "vertical_and_slash", + 3500, + 100, + 0.9935301542282104 + ], + "12": [ + "vertical_and_slash", + 500, + 700, + 0.9088181257247925 + ], + "13": [ + "vertical_and_slash", + 100, + 750, + 0.9997004270553589 + ], + "14": [ + "vertical_and_slash", + 500, + 700, + 0.9992449879646301 + ], + "15": [ + "vertical_and_slash", + 500, + 700, + 0.9945877194404602 + ], + "16": [ + "vertical_and_slash", + 500, + 700, + 0.9919179677963257 + ], + "17": [ + "vertical_and_slash", + 100, + 750, + 0.9887120127677917 + ], + "18": [ + "vertical_and_slash", + 3500, + 100, + 0.8854116797447205 + ], + "19": [ + "vertical_and_slash", + 3500, + 100, + 0.9977339506149292 + ], + "20": [ + "vertical_and_slash", + 500, + 700, + 0.9949434995651245 + ], + "21": [ + "vertical_and_slash", + 3500, + 100, + 0.9963594675064087 + ], + "22": [ + "vertical_and_slash", + 1000, + 6096, + 0.8057686686515808 + ], + "23": [ + "vertical_and_slash", + 1000, + 6096, + 0.7062434554100037 + ], + "24": [ + "vertical_and_slash", + 3500, + 100, + 0.714179515838623 + ], + "25": [ + "vertical_and_slash", + 3500, + 100, + 0.9638258814811707 + ], + "26": [ + "vertical_and_slash", + 500, + 700, + 0.9974403381347656 + ], + "27": [ + "vertical_and_slash", + 100, + 750, + 0.9426991939544678 + ], + "28": [ + "vertical_and_slash", + 500, + 700, + 0.9981192946434021 + ], + "29": [ + "vertical_and_slash", + 500, + 700, + 0.997894287109375 + ], + "30": [ + "vertical_and_slash", + 500, + 700, + 0.9979783892631531 + ], + "31": [ + "vertical_and_slash", + 3500, + 100, + 0.9956291317939758 + ] + }, + { + "0": [ + "vertical_and_slash", + 500, + 700, + 0.9942532777786255 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.9382870197296143 + ], + "2": [ + "vertical_and_slash", + 500, + 700, + 0.9994586706161499 + ], + "3": [ + "vertical_and_slash", + 500, + 700, + 0.9935740232467651 + ], + "4": [ + "vertical_and_slash", + 500, + 700, + 0.9681561589241028 + ], + "5": [ + "vertical_and_slash", + 100, + 750, + 0.7282125949859619 + ], + "6": [ + "vertical_and_slash", + 500, + 700, + 0.9943098425865173 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.6868635416030884 + ], + "8": [ + "vertical_and_slash", + 3500, + 100, + 0.9987292885780334 + ], + "9": [ + "vertical_and_slash", + 3500, + 100, + 0.8973380923271179 + ], + "10": [ + "vertical_and_slash", + 3500, + 100, + 0.995192289352417 + ], + "11": [ + "vertical_and_slash", + 1000, + 6096, + 0.6858950257301331 + ], + "12": [ + "vertical_and_slash", + 500, + 700, + 0.9975180625915527 + ], + "13": [ + "vertical_and_slash", + 3500, + 100, + 0.9925879836082458 + ], + "14": [ + "vertical_and_slash", + 100, + 750, + 0.9210449457168579 + ], + "15": [ + "vertical_and_slash", + 3500, + 100, + 0.9997319579124451 + ], + "16": [ + "vertical_and_slash", + 3500, + 100, + 0.9917876720428467 + ], + "17": [ + "vertical_and_slash", + 1000, + 6096, + 0.6827043890953064 + ], + "18": [ + "vertical_and_slash", + 100, + 750, + 0.9853598475456238 + ], + "19": [ + "vertical_and_slash", + 500, + 700, + 0.9982296824455261 + ], + "20": [ + "vertical_and_slash", + 500, + 700, + 0.9968504905700684 + ], + "21": [ + "vertical_and_slash", + 3500, + 100, + 0.9502132534980774 + ], + "22": [ + "vertical_and_slash", + 3500, + 100, + 0.9775718450546265 + ], + "23": [ + "vertical_and_slash", + 1000, + 6096, + 0.7796767354011536 + ], + "24": [ + "vertical_and_slash", + 100, + 800, + 
0.76953125 + ], + "25": [ + "vertical_and_slash", + 3500, + 100, + 0.9922406673431396 + ], + "26": [ + "vertical_and_slash", + 3500, + 100, + 0.9383352994918823 + ], + "27": [ + "vertical_and_slash", + 3500, + 100, + 0.8227579593658447 + ], + "28": [ + "vertical_and_slash", + 500, + 700, + 0.9977031946182251 + ], + "29": [ + "vertical_and_slash", + 500, + 700, + 0.9920535087585449 + ], + "30": [ + "vertical_and_slash", + 500, + 700, + 0.9953324198722839 + ], + "31": [ + "vertical_and_slash", + 1000, + 6096, + 0.6449344754219055 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.652643620967865 + ], + "1": [ + "vertical_and_slash", + 3500, + 100, + 0.9490343928337097 + ], + "2": [ + "vertical_and_slash", + 3500, + 100, + 0.6628352403640747 + ], + "3": [ + "vertical_and_slash", + 3500, + 100, + 0.9904823899269104 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.7459101676940918 + ], + "5": [ + "vertical_and_slash", + 100, + 750, + 0.8223817348480225 + ], + "6": [ + "vertical_and_slash", + 500, + 700, + 0.8061914443969727 + ], + "7": [ + "vertical_and_slash", + 500, + 700, + 0.9979249238967896 + ], + "8": [ + "vertical_and_slash", + 3500, + 100, + 0.956091582775116 + ], + "9": [ + "vertical_and_slash", + 500, + 700, + 0.9958360195159912 + ], + "10": [ + "vertical_and_slash", + 100, + 750, + 0.6650176644325256 + ], + "11": [ + "vertical_and_slash", + 500, + 700, + 0.946822464466095 + ], + "12": [ + "vertical_and_slash", + 3500, + 100, + 0.8161380887031555 + ], + "13": [ + "vertical_and_slash", + 500, + 700, + 0.9891030192375183 + ], + "14": [ + "vertical_and_slash", + 500, + 700, + 0.9935292601585388 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.7558044195175171 + ], + "16": [ + "vertical_and_slash", + 500, + 700, + 0.9654762744903564 + ], + "17": [ + "vertical_and_slash", + 1000, + 6096, + 0.57756507396698 + ], + "18": [ + "vertical_and_slash", + 500, + 700, + 0.9946773648262024 + ], + "19": [ + "vertical_and_slash", + 100, + 750, + 0.6871338486671448 + ], + "20": [ + "vertical_and_slash", + 500, + 700, + 0.9582348465919495 + ], + "21": [ + "vertical_and_slash", + 3500, + 100, + 0.8474758267402649 + ], + "22": [ + "vertical_and_slash", + 100, + 750, + 0.9841222763061523 + ], + "23": [ + "vertical_and_slash", + 3500, + 100, + 0.9883868098258972 + ], + "24": [ + "vertical_and_slash", + 500, + 700, + 0.8085432052612305 + ], + "25": [ + "vertical_and_slash", + 100, + 750, + 0.8961470127105713 + ], + "26": [ + "vertical_and_slash", + 3500, + 100, + 0.7980116009712219 + ], + "27": [ + "vertical_and_slash", + 500, + 700, + 0.9630904197692871 + ], + "28": [ + "vertical_and_slash", + 100, + 750, + 0.9312359094619751 + ], + "29": [ + "vertical_and_slash", + 500, + 700, + 0.7192952036857605 + ], + "30": [ + "vertical_and_slash", + 500, + 700, + 0.9928464889526367 + ], + "31": [ + "vertical_and_slash", + 1000, + 6096, + 0.6763450503349304 + ] + }, + { + "0": [ + "vertical_and_slash", + 3500, + 100, + 0.9938258528709412 + ], + "1": [ + "vertical_and_slash", + 3500, + 100, + 0.9992395043373108 + ], + "2": [ + "vertical_and_slash", + 500, + 700, + 0.9985423684120178 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.5909091234207153 + ], + "4": [ + "vertical_and_slash", + 100, + 750, + 0.9045535326004028 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.5327474474906921 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.7697188258171082 + ], + "7": [ + "vertical_and_slash", + 500, + 700, + 0.9967195987701416 + ], + "8": [ + "vertical_and_slash", + 
1000, + 6096, + 0.7490115761756897 + ], + "9": [ + "vertical_and_slash", + 500, + 700, + 0.9701148271560669 + ], + "10": [ + "vertical_and_slash", + 3500, + 100, + 0.9769464135169983 + ], + "11": [ + "vertical_and_slash", + 3500, + 100, + 0.9512021541595459 + ], + "12": [ + "vertical_and_slash", + 100, + 750, + 0.945134162902832 + ], + "13": [ + "vertical_and_slash", + 3500, + 100, + 0.9831423759460449 + ], + "14": [ + "vertical_and_slash", + 1000, + 6096, + 0.5774247050285339 + ], + "15": [ + "vertical_and_slash", + 100, + 750, + 0.992691695690155 + ], + "16": [ + "vertical_and_slash", + 3500, + 100, + 0.9788296222686768 + ], + "17": [ + "vertical_and_slash", + 500, + 700, + 0.9921766519546509 + ], + "18": [ + "vertical_and_slash", + 3500, + 100, + 0.9580382108688354 + ], + "19": [ + "vertical_and_slash", + 500, + 700, + 0.8484612703323364 + ], + "20": [ + "vertical_and_slash", + 500, + 700, + 0.9827266931533813 + ], + "21": [ + "vertical_and_slash", + 1000, + 6096, + 0.6860563158988953 + ], + "22": [ + "vertical_and_slash", + 500, + 700, + 0.9984534978866577 + ], + "23": [ + "vertical_and_slash", + 3500, + 100, + 0.7839770913124084 + ], + "24": [ + "vertical_and_slash", + 1000, + 6096, + 0.7702468633651733 + ], + "25": [ + "vertical_and_slash", + 500, + 700, + 0.9819672107696533 + ], + "26": [ + "vertical_and_slash", + 1000, + 6096, + 0.7506661415100098 + ], + "27": [ + "vertical_and_slash", + 500, + 700, + 0.8615745902061462 + ], + "28": [ + "vertical_and_slash", + 3500, + 100, + 0.9768449068069458 + ], + "29": [ + "vertical_and_slash", + 3500, + 100, + 0.9400221109390259 + ], + "30": [ + "vertical_and_slash", + 500, + 700, + 0.99903404712677 + ], + "31": [ + "vertical_and_slash", + 1000, + 6096, + 0.814863383769989 + ] + }, + { + "0": [ + "vertical_and_slash", + 500, + 700, + 0.9924616813659668 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.7962270975112915 + ], + "2": [ + "vertical_and_slash", + 100, + 750, + 0.8446550965309143 + ], + "3": [ + "vertical_and_slash", + 3500, + 100, + 0.986892819404602 + ], + "4": [ + "vertical_and_slash", + 500, + 700, + 0.9911617636680603 + ], + "5": [ + "vertical_and_slash", + 100, + 750, + 0.9796133637428284 + ], + "6": [ + "vertical_and_slash", + 500, + 700, + 0.9944987893104553 + ], + "7": [ + "vertical_and_slash", + 500, + 700, + 0.989579975605011 + ], + "8": [ + "vertical_and_slash", + 3500, + 100, + 0.9553632736206055 + ], + "9": [ + "vertical_and_slash", + 500, + 700, + 0.9968169331550598 + ], + "10": [ + "vertical_and_slash", + 3500, + 100, + 0.9912428855895996 + ], + "11": [ + "vertical_and_slash", + 500, + 700, + 0.9946438074111938 + ], + "12": [ + "vertical_and_slash", + 100, + 750, + 0.960080087184906 + ], + "13": [ + "vertical_and_slash", + 3500, + 100, + 0.9448127150535583 + ], + "14": [ + "vertical_and_slash", + 3500, + 100, + 0.992832362651825 + ], + "15": [ + "vertical_and_slash", + 500, + 700, + 0.9532816410064697 + ], + "16": [ + "vertical_and_slash", + 1000, + 6096, + 0.9048861265182495 + ], + "17": [ + "vertical_and_slash", + 500, + 700, + 0.994652509689331 + ], + "18": [ + "vertical_and_slash", + 100, + 750, + 0.9620020389556885 + ], + "19": [ + "vertical_and_slash", + 1000, + 6096, + 0.6745208501815796 + ], + "20": [ + "vertical_and_slash", + 1000, + 6096, + 0.5609733462333679 + ], + "21": [ + "vertical_and_slash", + 3500, + 100, + 0.8874548673629761 + ], + "22": [ + "vertical_and_slash", + 1000, + 6096, + 0.6557503342628479 + ], + "23": [ + "vertical_and_slash", + 3500, + 100, + 0.6937400102615356 + ], + "24": [ + 
"vertical_and_slash", + 3500, + 100, + 0.9920935034751892 + ], + "25": [ + "vertical_and_slash", + 500, + 700, + 0.9974891543388367 + ], + "26": [ + "vertical_and_slash", + 1000, + 6096, + 0.771544873714447 + ], + "27": [ + "vertical_and_slash", + 500, + 700, + 0.9795010685920715 + ], + "28": [ + "vertical_and_slash", + 3500, + 100, + 0.9977417588233948 + ], + "29": [ + "vertical_and_slash", + 100, + 750, + 0.9876147508621216 + ], + "30": [ + "vertical_and_slash", + 3500, + 100, + 0.7052534222602844 + ], + "31": [ + "vertical_and_slash", + 100, + 750, + 0.9702804684638977 + ] + }, + { + "0": [ + "vertical_and_slash", + 100, + 750, + 0.8113299608230591 + ], + "1": [ + "vertical_and_slash", + 100, + 750, + 0.75982666015625 + ], + "2": [ + "vertical_and_slash", + 500, + 700, + 0.9722777009010315 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.6454917788505554 + ], + "4": [ + "vertical_and_slash", + 500, + 700, + 0.9591151475906372 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.9245743751525879 + ], + "6": [ + "vertical_and_slash", + 100, + 750, + 0.672450602054596 + ], + "7": [ + "vertical_and_slash", + 100, + 750, + 0.8690704703330994 + ], + "8": [ + "vertical_and_slash", + 100, + 800, + 0.77734375 + ], + "9": [ + "vertical_and_slash", + 100, + 750, + 0.9322983622550964 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.6070910692214966 + ], + "11": [ + "vertical_and_slash", + 100, + 750, + 0.8282676935195923 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.7981728315353394 + ], + "13": [ + "vertical_and_slash", + 100, + 750, + 0.9669613838195801 + ], + "14": [ + "vertical_and_slash", + 100, + 750, + 0.7259189486503601 + ], + "15": [ + "vertical_and_slash", + 100, + 750, + 0.7247616052627563 + ], + "16": [ + "vertical_and_slash", + 100, + 750, + 0.927749514579773 + ], + "17": [ + "vertical_and_slash", + 500, + 700, + 0.9843335747718811 + ], + "18": [ + "vertical_and_slash", + 1000, + 6096, + 0.939363956451416 + ], + "19": [ + "vertical_and_slash", + 500, + 700, + 0.9057216048240662 + ], + "20": [ + "vertical_and_slash", + 1000, + 6096, + 0.8851215839385986 + ], + "21": [ + "vertical_and_slash", + 500, + 700, + 0.9476007223129272 + ], + "22": [ + "vertical_and_slash", + 100, + 750, + 0.8492496013641357 + ], + "23": [ + "vertical_and_slash", + 100, + 750, + 0.9673789143562317 + ], + "24": [ + "vertical_and_slash", + 100, + 750, + 0.9385042190551758 + ], + "25": [ + "vertical_and_slash", + 100, + 750, + 0.9418122172355652 + ], + "26": [ + "vertical_and_slash", + 3500, + 100, + 0.890839159488678 + ], + "27": [ + "vertical_and_slash", + 500, + 700, + 0.9094856977462769 + ], + "28": [ + "vertical_and_slash", + 100, + 750, + 0.9846156239509583 + ], + "29": [ + "vertical_and_slash", + 100, + 750, + 0.926313042640686 + ], + "30": [ + "vertical_and_slash", + 100, + 750, + 0.8919520974159241 + ], + "31": [ + "vertical_and_slash", + 1000, + 6096, + 0.7341398596763611 + ] + }, + { + "0": [ + "vertical_and_slash", + 500, + 700, + 0.9748796820640564 + ], + "1": [ + "vertical_and_slash", + 100, + 750, + 0.918652355670929 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.5971909165382385 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.7076131105422974 + ], + "4": [ + "vertical_and_slash", + 100, + 750, + 0.84599369764328 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.548804521560669 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.6747919917106628 + ], + "7": [ + "vertical_and_slash", + 100, + 750, + 0.9607553482055664 + ], + "8": [ 
+ "vertical_and_slash", + 500, + 700, + 0.9529062509536743 + ], + "9": [ + "vertical_and_slash", + 100, + 750, + 0.9745543003082275 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.5900142788887024 + ], + "11": [ + "vertical_and_slash", + 1000, + 6096, + 0.6382952332496643 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.8348908424377441 + ], + "13": [ + "vertical_and_slash", + 100, + 750, + 0.8871970772743225 + ], + "14": [ + "vertical_and_slash", + 1000, + 6096, + 0.6412845849990845 + ], + "15": [ + "vertical_and_slash", + 3500, + 100, + 0.9772831201553345 + ], + "16": [ + "vertical_and_slash", + 1000, + 6096, + 0.7201339602470398 + ], + "17": [ + "vertical_and_slash", + 100, + 750, + 0.9578301310539246 + ], + "18": [ + "vertical_and_slash", + 100, + 750, + 0.97627192735672 + ], + "19": [ + "vertical_and_slash", + 100, + 750, + 0.802989661693573 + ], + "20": [ + "vertical_and_slash", + 500, + 700, + 0.9934332370758057 + ], + "21": [ + "vertical_and_slash", + 100, + 750, + 0.988477885723114 + ], + "22": [ + "vertical_and_slash", + 3500, + 100, + 0.977401614189148 + ], + "23": [ + "vertical_and_slash", + 1000, + 6096, + 0.8183449506759644 + ], + "24": [ + "vertical_and_slash", + 100, + 750, + 0.8428875803947449 + ], + "25": [ + "vertical_and_slash", + 1000, + 6096, + 0.9019399285316467 + ], + "26": [ + "vertical_and_slash", + 500, + 700, + 0.850577175617218 + ], + "27": [ + "vertical_and_slash", + 100, + 750, + 0.8797270655632019 + ], + "28": [ + "vertical_and_slash", + 100, + 750, + 0.9850448369979858 + ], + "29": [ + "vertical_and_slash", + 100, + 750, + 0.7459467649459839 + ], + "30": [ + "vertical_and_slash", + 1000, + 6096, + 0.6797484159469604 + ], + "31": [ + "vertical_and_slash", + 3500, + 100, + 0.7080639004707336 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.684343695640564 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.6700385808944702 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.6516374945640564 + ], + "3": [ + "vertical_and_slash", + 100, + 750, + 0.9504920840263367 + ], + "4": [ + "vertical_and_slash", + 100, + 750, + 0.9940043687820435 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.8019426465034485 + ], + "6": [ + "vertical_and_slash", + 3500, + 100, + 0.67058926820755 + ], + "7": [ + "vertical_and_slash", + 500, + 700, + 0.9486842751502991 + ], + "8": [ + "vertical_and_slash", + 3500, + 100, + 0.9951285719871521 + ], + "9": [ + "vertical_and_slash", + 100, + 750, + 0.7573692202568054 + ], + "10": [ + "vertical_and_slash", + 100, + 750, + 0.9824524521827698 + ], + "11": [ + "vertical_and_slash", + 1000, + 6096, + 0.6431483030319214 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.7401513457298279 + ], + "13": [ + "vertical_and_slash", + 500, + 700, + 0.9602683186531067 + ], + "14": [ + "vertical_and_slash", + 1000, + 6096, + 0.5100213289260864 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.620374321937561 + ], + "16": [ + "vertical_and_slash", + 100, + 750, + 0.9879629015922546 + ], + "17": [ + "vertical_and_slash", + 3500, + 100, + 0.9683711528778076 + ], + "18": [ + "vertical_and_slash", + 3500, + 100, + 0.9918504953384399 + ], + "19": [ + "vertical_and_slash", + 3500, + 100, + 0.9663371443748474 + ], + "20": [ + "vertical_and_slash", + 100, + 750, + 0.9958233833312988 + ], + "21": [ + "vertical_and_slash", + 1000, + 6096, + 0.731646716594696 + ], + "22": [ + "vertical_and_slash", + 3500, + 100, + 0.9967772364616394 + ], + "23": [ + "vertical_and_slash", + 1000, + 6096, + 
0.7126162052154541 + ], + "24": [ + "vertical_and_slash", + 1000, + 6096, + 0.73592609167099 + ], + "25": [ + "vertical_and_slash", + 500, + 700, + 0.9757681488990784 + ], + "26": [ + "vertical_and_slash", + 100, + 750, + 0.821488082408905 + ], + "27": [ + "vertical_and_slash", + 500, + 700, + 0.940662145614624 + ], + "28": [ + "vertical_and_slash", + 100, + 750, + 0.7413780689239502 + ], + "29": [ + "vertical_and_slash", + 100, + 750, + 0.855651319026947 + ], + "30": [ + "vertical_and_slash", + 100, + 750, + 0.957796573638916 + ], + "31": [ + "vertical_and_slash", + 100, + 750, + 0.9116591215133667 + ] + }, + { + "0": [ + "vertical_and_slash", + 500, + 700, + 0.9933403730392456 + ], + "1": [ + "vertical_and_slash", + 3500, + 100, + 0.9730038642883301 + ], + "2": [ + "vertical_and_slash", + 3500, + 100, + 0.9607142210006714 + ], + "3": [ + "vertical_and_slash", + 500, + 700, + 0.9766934514045715 + ], + "4": [ + "vertical_and_slash", + 100, + 750, + 0.8421128988265991 + ], + "5": [ + "vertical_and_slash", + 500, + 700, + 0.9946009516716003 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.6585829257965088 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.8131377100944519 + ], + "8": [ + "vertical_and_slash", + 1000, + 6096, + 0.9548000693321228 + ], + "9": [ + "vertical_and_slash", + 100, + 750, + 0.9604054093360901 + ], + "10": [ + "vertical_and_slash", + 100, + 750, + 0.9921189546585083 + ], + "11": [ + "vertical_and_slash", + 3500, + 100, + 0.9823326468467712 + ], + "12": [ + "vertical_and_slash", + 100, + 750, + 0.9664893746376038 + ], + "13": [ + "vertical_and_slash", + 100, + 750, + 0.991500735282898 + ], + "14": [ + "vertical_and_slash", + 3500, + 100, + 0.9875308275222778 + ], + "15": [ + "vertical_and_slash", + 3500, + 100, + 0.9863171577453613 + ], + "16": [ + "vertical_and_slash", + 3500, + 100, + 0.9744452238082886 + ], + "17": [ + "vertical_and_slash", + 1000, + 6096, + 0.8115602731704712 + ], + "18": [ + "vertical_and_slash", + 500, + 700, + 0.8509863018989563 + ], + "19": [ + "vertical_and_slash", + 1000, + 6096, + 0.7890191674232483 + ], + "20": [ + "vertical_and_slash", + 3500, + 100, + 0.9685430526733398 + ], + "21": [ + "vertical_and_slash", + 500, + 700, + 0.9879798293113708 + ], + "22": [ + "vertical_and_slash", + 1000, + 6096, + 0.8662706017494202 + ], + "23": [ + "vertical_and_slash", + 3500, + 100, + 0.9373542666435242 + ], + "24": [ + "vertical_and_slash", + 100, + 750, + 0.9943336248397827 + ], + "25": [ + "vertical_and_slash", + 100, + 750, + 0.9302772283554077 + ], + "26": [ + "vertical_and_slash", + 500, + 700, + 0.977992832660675 + ], + "27": [ + "vertical_and_slash", + 3500, + 100, + 0.932984471321106 + ], + "28": [ + "vertical_and_slash", + 100, + 750, + 0.9221320152282715 + ], + "29": [ + "vertical_and_slash", + 1000, + 6096, + 0.6031616926193237 + ], + "30": [ + "vertical_and_slash", + 1000, + 6096, + 0.9101414680480957 + ], + "31": [ + "vertical_and_slash", + 1000, + 6096, + 0.9133468270301819 + ] + }, + { + "0": [ + "vertical_and_slash", + 3500, + 100, + 0.9161027669906616 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.8264802694320679 + ], + "2": [ + "vertical_and_slash", + 100, + 750, + 0.9612565636634827 + ], + "3": [ + "vertical_and_slash", + 100, + 750, + 0.8154476881027222 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.870027482509613 + ], + "5": [ + "vertical_and_slash", + 500, + 700, + 0.8858229517936707 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.732231616973877 + ], + "7": [ + 
"vertical_and_slash", + 1000, + 6096, + 0.8734871745109558 + ], + "8": [ + "vertical_and_slash", + 1000, + 6096, + 0.7596957683563232 + ], + "9": [ + "vertical_and_slash", + 100, + 750, + 0.8638753890991211 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.8574302792549133 + ], + "11": [ + "vertical_and_slash", + 1000, + 6096, + 0.7212074398994446 + ], + "12": [ + "vertical_and_slash", + 3500, + 100, + 0.8646799921989441 + ], + "13": [ + "vertical_and_slash", + 1000, + 6096, + 0.6482542753219604 + ], + "14": [ + "vertical_and_slash", + 500, + 700, + 0.83883136510849 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.8585593104362488 + ], + "16": [ + "vertical_and_slash", + 1000, + 6096, + 0.9678143858909607 + ], + "17": [ + "vertical_and_slash", + 1000, + 6096, + 0.7909027934074402 + ], + "18": [ + "vertical_and_slash", + 1000, + 6096, + 0.8701796531677246 + ], + "19": [ + "vertical_and_slash", + 1000, + 6096, + 0.7488895654678345 + ], + "20": [ + "vertical_and_slash", + 1000, + 6096, + 0.8781315088272095 + ], + "21": [ + "vertical_and_slash", + 1000, + 6096, + 0.7022296190261841 + ], + "22": [ + "vertical_and_slash", + 1000, + 6096, + 0.7967407703399658 + ], + "23": [ + "vertical_and_slash", + 1000, + 6096, + 0.7607766389846802 + ], + "24": [ + "vertical_and_slash", + 100, + 750, + 0.5860172510147095 + ], + "25": [ + "vertical_and_slash", + 1000, + 6096, + 0.8311766982078552 + ], + "26": [ + "vertical_and_slash", + 1000, + 6096, + 0.9960868954658508 + ], + "27": [ + "vertical_and_slash", + 100, + 750, + 0.9676917195320129 + ], + "28": [ + "vertical_and_slash", + 3500, + 100, + 0.9716269373893738 + ], + "29": [ + "vertical_and_slash", + 3500, + 100, + 0.7922632098197937 + ], + "30": [ + "vertical_and_slash", + 1000, + 6096, + 0.9029964804649353 + ], + "31": [ + "vertical_and_slash", + 1000, + 6096, + 0.7617987990379333 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.7531226873397827 + ], + "1": [ + "vertical_and_slash", + 3500, + 100, + 0.9829939007759094 + ], + "2": [ + "vertical_and_slash", + 3500, + 100, + 0.9867648482322693 + ], + "3": [ + "vertical_and_slash", + 100, + 750, + 0.7495420575141907 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.8411062359809875 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.9238793253898621 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.8017065525054932 + ], + "7": [ + "vertical_and_slash", + 100, + 750, + 0.949862539768219 + ], + "8": [ + "vertical_and_slash", + 1000, + 6096, + 0.7137445211410522 + ], + "9": [ + "vertical_and_slash", + 1000, + 6096, + 0.8886378407478333 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.6042924523353577 + ], + "11": [ + "vertical_and_slash", + 500, + 700, + 0.9610161781311035 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.6790781617164612 + ], + "13": [ + "vertical_and_slash", + 1000, + 6096, + 0.8115764856338501 + ], + "14": [ + "vertical_and_slash", + 1000, + 6096, + 0.8262994885444641 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.5916704535484314 + ], + "16": [ + "vertical_and_slash", + 1000, + 6096, + 0.817267894744873 + ], + "17": [ + "vertical_and_slash", + 1000, + 6096, + 0.8543776869773865 + ], + "18": [ + "vertical_and_slash", + 1000, + 6096, + 0.875899612903595 + ], + "19": [ + "vertical_and_slash", + 1000, + 6096, + 0.804548442363739 + ], + "20": [ + "vertical_and_slash", + 1000, + 6096, + 0.7843059301376343 + ], + "21": [ + "vertical_and_slash", + 100, + 750, + 0.9910483360290527 + ], + "22": [ + 
"vertical_and_slash", + 100, + 750, + 0.7706062197685242 + ], + "23": [ + "vertical_and_slash", + 100, + 750, + 0.8355559706687927 + ], + "24": [ + "vertical_and_slash", + 500, + 700, + 0.8440632820129395 + ], + "25": [ + "vertical_and_slash", + 1000, + 6096, + 0.8883750438690186 + ], + "26": [ + "vertical_and_slash", + 1000, + 6096, + 0.5762335062026978 + ], + "27": [ + "vertical_and_slash", + 500, + 700, + 0.6402088403701782 + ], + "28": [ + "vertical_and_slash", + 3500, + 100, + 0.9510595202445984 + ], + "29": [ + "vertical_and_slash", + 1000, + 6096, + 0.6832257509231567 + ], + "30": [ + "vertical_and_slash", + 100, + 750, + 0.8940309882164001 + ], + "31": [ + "vertical_and_slash", + 1000, + 6096, + 0.6904930472373962 + ] + }, + { + "0": [ + "vertical_and_slash", + 3500, + 100, + 0.9712978005409241 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.715650200843811 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.9092068076133728 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.5512199401855469 + ], + "4": [ + "vertical_and_slash", + 3500, + 100, + 0.9947230815887451 + ], + "5": [ + "vertical_and_slash", + 100, + 750, + 0.942888617515564 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.8079208135604858 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.9934179186820984 + ], + "8": [ + "vertical_and_slash", + 1000, + 6096, + 0.9797422885894775 + ], + "9": [ + "vertical_and_slash", + 1000, + 6096, + 0.9305815100669861 + ], + "10": [ + "vertical_and_slash", + 500, + 700, + 0.988544762134552 + ], + "11": [ + "vertical_and_slash", + 3500, + 100, + 0.9259185791015625 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.9055561423301697 + ], + "13": [ + "vertical_and_slash", + 100, + 750, + 0.9950724244117737 + ], + "14": [ + "vertical_and_slash", + 3500, + 100, + 0.9936328530311584 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.7676860094070435 + ], + "16": [ + "vertical_and_slash", + 100, + 750, + 0.9986890554428101 + ], + "17": [ + "vertical_and_slash", + 3500, + 100, + 0.9774577021598816 + ], + "18": [ + "vertical_and_slash", + 100, + 750, + 0.9445663094520569 + ], + "19": [ + "vertical_and_slash", + 100, + 750, + 0.9051742553710938 + ], + "20": [ + "vertical_and_slash", + 500, + 700, + 0.9963957071304321 + ], + "21": [ + "vertical_and_slash", + 1000, + 6096, + 0.6996874809265137 + ], + "22": [ + "vertical_and_slash", + 100, + 750, + 0.9648709297180176 + ], + "23": [ + "vertical_and_slash", + 500, + 700, + 0.8904076814651489 + ], + "24": [ + "vertical_and_slash", + 1000, + 6096, + 0.6945720314979553 + ], + "25": [ + "vertical_and_slash", + 500, + 700, + 0.9790938496589661 + ], + "26": [ + "vertical_and_slash", + 3500, + 100, + 0.9644208550453186 + ], + "27": [ + "vertical_and_slash", + 500, + 700, + 0.9994762539863586 + ], + "28": [ + "vertical_and_slash", + 3500, + 100, + 0.8030149936676025 + ], + "29": [ + "vertical_and_slash", + 100, + 750, + 0.9580996632575989 + ], + "30": [ + "vertical_and_slash", + 500, + 700, + 0.9973891973495483 + ], + "31": [ + "vertical_and_slash", + 100, + 750, + 0.9492622017860413 + ] + }, + { + "0": [ + "vertical_and_slash", + 3500, + 100, + 0.9596408009529114 + ], + "1": [ + "vertical_and_slash", + 3500, + 100, + 0.9903782606124878 + ], + "2": [ + "vertical_and_slash", + 100, + 750, + 0.985963761806488 + ], + "3": [ + "vertical_and_slash", + 3500, + 100, + 0.8823582530021667 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.78907310962677 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 
0.7436686754226685 + ], + "6": [ + "vertical_and_slash", + 3500, + 100, + 0.9900132417678833 + ], + "7": [ + "vertical_and_slash", + 500, + 700, + 0.9172103404998779 + ], + "8": [ + "vertical_and_slash", + 100, + 750, + 0.9941288232803345 + ], + "9": [ + "vertical_and_slash", + 500, + 700, + 0.9992177486419678 + ], + "10": [ + "vertical_and_slash", + 100, + 750, + 0.9879436492919922 + ], + "11": [ + "vertical_and_slash", + 500, + 700, + 0.9889773726463318 + ], + "12": [ + "vertical_and_slash", + 100, + 750, + 0.9961833953857422 + ], + "13": [ + "vertical_and_slash", + 3500, + 100, + 0.992790699005127 + ], + "14": [ + "vertical_and_slash", + 100, + 750, + 0.879410445690155 + ], + "15": [ + "vertical_and_slash", + 3500, + 100, + 0.9506082534790039 + ], + "16": [ + "vertical_and_slash", + 1000, + 6096, + 0.9580295085906982 + ], + "17": [ + "vertical_and_slash", + 3500, + 100, + 0.9992238879203796 + ], + "18": [ + "vertical_and_slash", + 1000, + 6096, + 0.8985172510147095 + ], + "19": [ + "vertical_and_slash", + 3500, + 100, + 0.9653815031051636 + ], + "20": [ + "vertical_and_slash", + 1000, + 6096, + 0.8116862177848816 + ], + "21": [ + "vertical_and_slash", + 1000, + 6096, + 0.8597212433815002 + ], + "22": [ + "vertical_and_slash", + 1000, + 6096, + 0.8310312032699585 + ], + "23": [ + "vertical_and_slash", + 100, + 750, + 0.9763169288635254 + ], + "24": [ + "vertical_and_slash", + 100, + 750, + 0.9825117588043213 + ], + "25": [ + "vertical_and_slash", + 1000, + 6096, + 0.9893282651901245 + ], + "26": [ + "vertical_and_slash", + 3500, + 100, + 0.9984802007675171 + ], + "27": [ + "vertical_and_slash", + 1000, + 6096, + 0.9936201572418213 + ], + "28": [ + "vertical_and_slash", + 100, + 750, + 0.966774582862854 + ], + "29": [ + "vertical_and_slash", + 3500, + 100, + 0.9496166110038757 + ], + "30": [ + "vertical_and_slash", + 3500, + 100, + 0.9060468077659607 + ], + "31": [ + "vertical_and_slash", + 100, + 750, + 0.9787821769714355 + ] + }, + { + "0": [ + "vertical_and_slash", + 3500, + 100, + 0.98621666431427 + ], + "1": [ + "vertical_and_slash", + 500, + 700, + 0.9871621131896973 + ], + "2": [ + "vertical_and_slash", + 500, + 700, + 0.998849630355835 + ], + "3": [ + "vertical_and_slash", + 100, + 750, + 0.9951006174087524 + ], + "4": [ + "vertical_and_slash", + 3500, + 100, + 0.9684059619903564 + ], + "5": [ + "vertical_and_slash", + 100, + 750, + 0.9988571405410767 + ], + "6": [ + "vertical_and_slash", + 3500, + 100, + 0.9973605275154114 + ], + "7": [ + "vertical_and_slash", + 3500, + 100, + 0.9547178745269775 + ], + "8": [ + "vertical_and_slash", + 3500, + 100, + 0.9439932107925415 + ], + "9": [ + "vertical_and_slash", + 3500, + 100, + 0.761074960231781 + ], + "10": [ + "vertical_and_slash", + 3500, + 100, + 0.9970728158950806 + ], + "11": [ + "vertical_and_slash", + 500, + 700, + 0.9856483340263367 + ], + "12": [ + "vertical_and_slash", + 100, + 750, + 0.9948339462280273 + ], + "13": [ + "vertical_and_slash", + 3500, + 100, + 0.9889204502105713 + ], + "14": [ + "vertical_and_slash", + 100, + 750, + 0.991661548614502 + ], + "15": [ + "vertical_and_slash", + 100, + 750, + 0.9836190938949585 + ], + "16": [ + "vertical_and_slash", + 3500, + 100, + 0.9960607290267944 + ], + "17": [ + "vertical_and_slash", + 3500, + 100, + 0.9184228181838989 + ], + "18": [ + "vertical_and_slash", + 100, + 750, + 0.9532352685928345 + ], + "19": [ + "vertical_and_slash", + 3500, + 100, + 0.9791851043701172 + ], + "20": [ + "vertical_and_slash", + 500, + 700, + 0.9993872046470642 + ], + "21": [ + "vertical_and_slash", 
+ 3500, + 100, + 0.8049466013908386 + ], + "22": [ + "vertical_and_slash", + 3500, + 100, + 0.9487335681915283 + ], + "23": [ + "vertical_and_slash", + 3500, + 100, + 0.9942803382873535 + ], + "24": [ + "vertical_and_slash", + 500, + 700, + 0.9795239567756653 + ], + "25": [ + "vertical_and_slash", + 3500, + 100, + 0.9796655774116516 + ], + "26": [ + "vertical_and_slash", + 100, + 750, + 0.9073926210403442 + ], + "27": [ + "vertical_and_slash", + 1000, + 6096, + 0.8279032707214355 + ], + "28": [ + "vertical_and_slash", + 100, + 750, + 0.928718090057373 + ], + "29": [ + "vertical_and_slash", + 3500, + 100, + 0.9749822616577148 + ], + "30": [ + "vertical_and_slash", + 3500, + 100, + 0.8956699371337891 + ], + "31": [ + "vertical_and_slash", + 3500, + 100, + 0.7362069487571716 + ] + }, + { + "0": [ + "vertical_and_slash", + 100, + 750, + 0.948788583278656 + ], + "1": [ + "vertical_and_slash", + 100, + 750, + 0.9491262435913086 + ], + "2": [ + "vertical_and_slash", + 100, + 750, + 0.8548357486724854 + ], + "3": [ + "vertical_and_slash", + 100, + 750, + 0.9972718954086304 + ], + "4": [ + "vertical_and_slash", + 3500, + 100, + 0.995877742767334 + ], + "5": [ + "vertical_and_slash", + 100, + 750, + 0.9962888360023499 + ], + "6": [ + "vertical_and_slash", + 3500, + 100, + 0.950820803642273 + ], + "7": [ + "vertical_and_slash", + 3500, + 100, + 0.940021276473999 + ], + "8": [ + "vertical_and_slash", + 100, + 750, + 0.97650146484375 + ], + "9": [ + "vertical_and_slash", + 500, + 700, + 0.9919331669807434 + ], + "10": [ + "vertical_and_slash", + 100, + 750, + 0.7328993082046509 + ], + "11": [ + "vertical_and_slash", + 3500, + 100, + 0.992123007774353 + ], + "12": [ + "vertical_and_slash", + 3500, + 100, + 0.9887773394584656 + ], + "13": [ + "vertical_and_slash", + 100, + 750, + 0.9808080792427063 + ], + "14": [ + "vertical_and_slash", + 100, + 750, + 0.9803031086921692 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.7681567072868347 + ], + "16": [ + "vertical_and_slash", + 100, + 750, + 0.9663677215576172 + ], + "17": [ + "vertical_and_slash", + 100, + 750, + 0.9978005886077881 + ], + "18": [ + "vertical_and_slash", + 100, + 750, + 0.9956686496734619 + ], + "19": [ + "vertical_and_slash", + 3500, + 100, + 0.9876529574394226 + ], + "20": [ + "vertical_and_slash", + 3500, + 100, + 0.9083545207977295 + ], + "21": [ + "vertical_and_slash", + 500, + 700, + 0.9825258851051331 + ], + "22": [ + "vertical_and_slash", + 500, + 700, + 0.9929331541061401 + ], + "23": [ + "vertical_and_slash", + 3500, + 100, + 0.9937247037887573 + ], + "24": [ + "vertical_and_slash", + 3500, + 100, + 0.9849582314491272 + ], + "25": [ + "vertical_and_slash", + 3500, + 100, + 0.9303218722343445 + ], + "26": [ + "vertical_and_slash", + 100, + 750, + 0.9582874774932861 + ], + "27": [ + "vertical_and_slash", + 100, + 750, + 0.9650914669036865 + ], + "28": [ + "vertical_and_slash", + 3500, + 100, + 0.9943416714668274 + ], + "29": [ + "vertical_and_slash", + 3500, + 100, + 0.9150344133377075 + ], + "30": [ + "vertical_and_slash", + 100, + 750, + 0.8997454047203064 + ], + "31": [ + "vertical_and_slash", + 3500, + 100, + 0.9907521605491638 + ] + }, + { + "0": [ + "vertical_and_slash", + 100, + 750, + 0.9935974478721619 + ], + "1": [ + "vertical_and_slash", + 500, + 700, + 0.9944353103637695 + ], + "2": [ + "vertical_and_slash", + 100, + 750, + 0.9993149638175964 + ], + "3": [ + "vertical_and_slash", + 100, + 750, + 0.9968475103378296 + ], + "4": [ + "vertical_and_slash", + 100, + 750, + 0.7440210580825806 + ], + "5": [ + 
"vertical_and_slash", + 100, + 750, + 0.99021315574646 + ], + "6": [ + "vertical_and_slash", + 3500, + 100, + 0.9938840270042419 + ], + "7": [ + "vertical_and_slash", + 100, + 750, + 0.93567955493927 + ], + "8": [ + "vertical_and_slash", + 3500, + 100, + 0.9877127408981323 + ], + "9": [ + "vertical_and_slash", + 100, + 750, + 0.9873458743095398 + ], + "10": [ + "vertical_and_slash", + 500, + 700, + 0.9998465776443481 + ], + "11": [ + "vertical_and_slash", + 3500, + 100, + 0.9908269047737122 + ], + "12": [ + "vertical_and_slash", + 3500, + 100, + 0.9990859031677246 + ], + "13": [ + "vertical_and_slash", + 3500, + 100, + 0.8770437836647034 + ], + "14": [ + "vertical_and_slash", + 500, + 700, + 0.9989800453186035 + ], + "15": [ + "vertical_and_slash", + 100, + 750, + 0.9958862662315369 + ], + "16": [ + "vertical_and_slash", + 100, + 750, + 0.9739276766777039 + ], + "17": [ + "vertical_and_slash", + 3500, + 100, + 0.9918177723884583 + ], + "18": [ + "vertical_and_slash", + 100, + 750, + 0.9817363023757935 + ], + "19": [ + "vertical_and_slash", + 100, + 750, + 0.9980490207672119 + ], + "20": [ + "vertical_and_slash", + 500, + 700, + 0.9854499101638794 + ], + "21": [ + "vertical_and_slash", + 100, + 750, + 0.9956621527671814 + ], + "22": [ + "vertical_and_slash", + 3500, + 100, + 0.9646536111831665 + ], + "23": [ + "vertical_and_slash", + 3500, + 100, + 0.8399244546890259 + ], + "24": [ + "vertical_and_slash", + 3500, + 100, + 0.9599056243896484 + ], + "25": [ + "vertical_and_slash", + 100, + 750, + 0.9969561100006104 + ], + "26": [ + "vertical_and_slash", + 3500, + 100, + 0.8741656541824341 + ], + "27": [ + "vertical_and_slash", + 3500, + 100, + 0.9881818890571594 + ], + "28": [ + "vertical_and_slash", + 100, + 750, + 0.9986366629600525 + ], + "29": [ + "vertical_and_slash", + 100, + 750, + 0.8203835487365723 + ], + "30": [ + "vertical_and_slash", + 3500, + 100, + 0.916657567024231 + ], + "31": [ + "vertical_and_slash", + 100, + 750, + 0.9909099340438843 + ] + }, + { + "0": [ + "vertical_and_slash", + 100, + 750, + 0.9891338348388672 + ], + "1": [ + "vertical_and_slash", + 100, + 750, + 0.9982934594154358 + ], + "2": [ + "vertical_and_slash", + 3500, + 100, + 0.7269482016563416 + ], + "3": [ + "vertical_and_slash", + 100, + 750, + 0.97837895154953 + ], + "4": [ + "vertical_and_slash", + 3500, + 100, + 0.9985787868499756 + ], + "5": [ + "vertical_and_slash", + 100, + 750, + 0.9378053545951843 + ], + "6": [ + "vertical_and_slash", + 3500, + 100, + 0.8923203945159912 + ], + "7": [ + "vertical_and_slash", + 3500, + 100, + 0.9933837056159973 + ], + "8": [ + "vertical_and_slash", + 3500, + 100, + 0.9831008911132812 + ], + "9": [ + "vertical_and_slash", + 3500, + 100, + 0.9890069961547852 + ], + "10": [ + "vertical_and_slash", + 500, + 700, + 0.9977155923843384 + ], + "11": [ + "vertical_and_slash", + 100, + 750, + 0.9636794328689575 + ], + "12": [ + "vertical_and_slash", + 100, + 750, + 0.9993752837181091 + ], + "13": [ + "vertical_and_slash", + 100, + 750, + 0.9918390512466431 + ], + "14": [ + "vertical_and_slash", + 3500, + 100, + 0.9965466856956482 + ], + "15": [ + "vertical_and_slash", + 100, + 750, + 0.979774534702301 + ], + "16": [ + "vertical_and_slash", + 100, + 750, + 0.9794058799743652 + ], + "17": [ + "vertical_and_slash", + 100, + 750, + 0.998450517654419 + ], + "18": [ + "vertical_and_slash", + 3500, + 100, + 0.7310967445373535 + ], + "19": [ + "vertical_and_slash", + 100, + 750, + 0.9953096508979797 + ], + "20": [ + "vertical_and_slash", + 3500, + 100, + 0.9857947826385498 + ], + "21": 
[ + "vertical_and_slash", + 3500, + 100, + 0.987230658531189 + ], + "22": [ + "vertical_and_slash", + 500, + 700, + 0.9985311031341553 + ], + "23": [ + "vertical_and_slash", + 100, + 750, + 0.9923253655433655 + ], + "24": [ + "vertical_and_slash", + 100, + 750, + 0.9921882152557373 + ], + "25": [ + "vertical_and_slash", + 100, + 750, + 0.9417669773101807 + ], + "26": [ + "vertical_and_slash", + 100, + 750, + 0.9951248168945312 + ], + "27": [ + "vertical_and_slash", + 3500, + 100, + 0.9957342743873596 + ], + "28": [ + "vertical_and_slash", + 1000, + 6096, + 0.8214721083641052 + ], + "29": [ + "vertical_and_slash", + 3500, + 100, + 0.9924106001853943 + ], + "30": [ + "vertical_and_slash", + 500, + 700, + 0.9996931552886963 + ], + "31": [ + "vertical_and_slash", + 3500, + 100, + 0.9912320375442505 + ] + }, + { + "0": [ + "vertical_and_slash", + 500, + 700, + 0.9739575982093811 + ], + "1": [ + "vertical_and_slash", + 100, + 750, + 0.9948337078094482 + ], + "2": [ + "vertical_and_slash", + 500, + 700, + 0.9912586808204651 + ], + "3": [ + "vertical_and_slash", + 100, + 750, + 0.9957679510116577 + ], + "4": [ + "vertical_and_slash", + 3500, + 100, + 0.9885770678520203 + ], + "5": [ + "vertical_and_slash", + 100, + 750, + 0.9814969301223755 + ], + "6": [ + "vertical_and_slash", + 3500, + 100, + 0.9366181492805481 + ], + "7": [ + "vertical_and_slash", + 100, + 750, + 0.9948270320892334 + ], + "8": [ + "vertical_and_slash", + 100, + 750, + 0.9994694590568542 + ], + "9": [ + "vertical_and_slash", + 3500, + 100, + 0.9889045357704163 + ], + "10": [ + "vertical_and_slash", + 3500, + 100, + 0.9929386973381042 + ], + "11": [ + "vertical_and_slash", + 100, + 750, + 0.9881108999252319 + ], + "12": [ + "vertical_and_slash", + 500, + 700, + 0.9796789288520813 + ], + "13": [ + "vertical_and_slash", + 100, + 750, + 0.9889234304428101 + ], + "14": [ + "vertical_and_slash", + 3500, + 100, + 0.9873424768447876 + ], + "15": [ + "vertical_and_slash", + 3500, + 100, + 0.9901317358016968 + ], + "16": [ + "vertical_and_slash", + 3500, + 100, + 0.9893797039985657 + ], + "17": [ + "vertical_and_slash", + 100, + 750, + 0.9819779992103577 + ], + "18": [ + "vertical_and_slash", + 3500, + 100, + 0.989522397518158 + ], + "19": [ + "vertical_and_slash", + 3500, + 100, + 0.9819537997245789 + ], + "20": [ + "vertical_and_slash", + 100, + 750, + 0.9925962686538696 + ], + "21": [ + "vertical_and_slash", + 3500, + 100, + 0.9989944696426392 + ], + "22": [ + "vertical_and_slash", + 100, + 750, + 0.9997721314430237 + ], + "23": [ + "vertical_and_slash", + 100, + 750, + 0.9876223802566528 + ], + "24": [ + "vertical_and_slash", + 3500, + 100, + 0.9952347874641418 + ], + "25": [ + "vertical_and_slash", + 3500, + 100, + 0.9843642711639404 + ], + "26": [ + "vertical_and_slash", + 3500, + 100, + 0.9960111975669861 + ], + "27": [ + "vertical_and_slash", + 3500, + 100, + 0.6954624652862549 + ], + "28": [ + "vertical_and_slash", + 3500, + 100, + 0.970451295375824 + ], + "29": [ + "vertical_and_slash", + 100, + 750, + 0.991379976272583 + ], + "30": [ + "vertical_and_slash", + 100, + 750, + 0.8738142848014832 + ], + "31": [ + "vertical_and_slash", + 3500, + 100, + 0.9786747694015503 + ] + }, + { + "0": [ + "vertical_and_slash", + 3500, + 100, + 0.997829258441925 + ], + "1": [ + "vertical_and_slash", + 500, + 700, + 0.990919291973114 + ], + "2": [ + "vertical_and_slash", + 100, + 750, + 0.9968075156211853 + ], + "3": [ + "vertical_and_slash", + 3500, + 100, + 0.9982627630233765 + ], + "4": [ + "vertical_and_slash", + 3500, + 100, + 
0.9957785606384277 + ], + "5": [ + "vertical_and_slash", + 100, + 750, + 0.974946141242981 + ], + "6": [ + "vertical_and_slash", + 3500, + 100, + 0.995257556438446 + ], + "7": [ + "vertical_and_slash", + 100, + 750, + 0.8554062247276306 + ], + "8": [ + "vertical_and_slash", + 100, + 750, + 0.9880555272102356 + ], + "9": [ + "vertical_and_slash", + 500, + 700, + 0.9956945776939392 + ], + "10": [ + "vertical_and_slash", + 3500, + 100, + 0.9789673089981079 + ], + "11": [ + "vertical_and_slash", + 3500, + 100, + 0.991595447063446 + ], + "12": [ + "vertical_and_slash", + 100, + 750, + 0.9686179161071777 + ], + "13": [ + "vertical_and_slash", + 100, + 750, + 0.9943218231201172 + ], + "14": [ + "vertical_and_slash", + 3500, + 100, + 0.9565165042877197 + ], + "15": [ + "vertical_and_slash", + 100, + 750, + 0.9816807508468628 + ], + "16": [ + "vertical_and_slash", + 500, + 700, + 0.9969928860664368 + ], + "17": [ + "vertical_and_slash", + 100, + 750, + 0.9579240679740906 + ], + "18": [ + "vertical_and_slash", + 1000, + 6096, + 0.7692553400993347 + ], + "19": [ + "vertical_and_slash", + 3500, + 100, + 0.9934751987457275 + ], + "20": [ + "vertical_and_slash", + 100, + 750, + 0.996086597442627 + ], + "21": [ + "vertical_and_slash", + 3500, + 100, + 0.9795346260070801 + ], + "22": [ + "vertical_and_slash", + 3500, + 100, + 0.9099371433258057 + ], + "23": [ + "vertical_and_slash", + 100, + 750, + 0.9606084823608398 + ], + "24": [ + "vertical_and_slash", + 3500, + 100, + 0.9944002032279968 + ], + "25": [ + "vertical_and_slash", + 3500, + 100, + 0.9969326257705688 + ], + "26": [ + "vertical_and_slash", + 3500, + 100, + 0.943459689617157 + ], + "27": [ + "vertical_and_slash", + 3500, + 100, + 0.9907713532447815 + ], + "28": [ + "vertical_and_slash", + 3500, + 100, + 0.9855557084083557 + ], + "29": [ + "vertical_and_slash", + 100, + 750, + 0.911635160446167 + ], + "30": [ + "vertical_and_slash", + 3500, + 100, + 0.9951326847076416 + ], + "31": [ + "vertical_and_slash", + 3500, + 100, + 0.9821126461029053 + ] + }, + { + "0": [ + "vertical_and_slash", + 500, + 700, + 0.9654005765914917 + ], + "1": [ + "vertical_and_slash", + 500, + 700, + 0.999093770980835 + ], + "2": [ + "vertical_and_slash", + 100, + 750, + 0.99101722240448 + ], + "3": [ + "vertical_and_slash", + 500, + 700, + 0.9934068918228149 + ], + "4": [ + "vertical_and_slash", + 3500, + 100, + 0.9982517957687378 + ], + "5": [ + "vertical_and_slash", + 500, + 700, + 0.9996331334114075 + ], + "6": [ + "vertical_and_slash", + 500, + 700, + 0.9952266216278076 + ], + "7": [ + "vertical_and_slash", + 100, + 750, + 0.880059003829956 + ], + "8": [ + "vertical_and_slash", + 3500, + 100, + 0.9872965216636658 + ], + "9": [ + "vertical_and_slash", + 3500, + 100, + 0.9972522854804993 + ], + "10": [ + "vertical_and_slash", + 100, + 750, + 0.9663414359092712 + ], + "11": [ + "vertical_and_slash", + 500, + 700, + 0.9989503622055054 + ], + "12": [ + "vertical_and_slash", + 500, + 700, + 0.9980217218399048 + ], + "13": [ + "vertical_and_slash", + 3500, + 100, + 0.9978732466697693 + ], + "14": [ + "vertical_and_slash", + 3500, + 100, + 0.990183413028717 + ], + "15": [ + "vertical_and_slash", + 100, + 750, + 0.9975069761276245 + ], + "16": [ + "vertical_and_slash", + 100, + 750, + 0.9249386787414551 + ], + "17": [ + "vertical_and_slash", + 500, + 700, + 0.9966362118721008 + ], + "18": [ + "vertical_and_slash", + 500, + 700, + 0.9951748847961426 + ], + "19": [ + "vertical_and_slash", + 100, + 750, + 0.9986919164657593 + ], + "20": [ + "vertical_and_slash", + 3500, + 100, 
+ 0.9912869930267334 + ], + "21": [ + "vertical_and_slash", + 100, + 750, + 0.9970594644546509 + ], + "22": [ + "vertical_and_slash", + 100, + 750, + 0.998475193977356 + ], + "23": [ + "vertical_and_slash", + 500, + 700, + 0.9993215799331665 + ], + "24": [ + "vertical_and_slash", + 3500, + 100, + 0.9980448484420776 + ], + "25": [ + "vertical_and_slash", + 100, + 750, + 0.9916543364524841 + ], + "26": [ + "vertical_and_slash", + 100, + 750, + 0.980556845664978 + ], + "27": [ + "vertical_and_slash", + 500, + 700, + 0.9921435117721558 + ], + "28": [ + "vertical_and_slash", + 3500, + 100, + 0.9989830255508423 + ], + "29": [ + "vertical_and_slash", + 3500, + 100, + 0.9973907470703125 + ], + "30": [ + "vertical_and_slash", + 500, + 700, + 0.9833565354347229 + ], + "31": [ + "vertical_and_slash", + 100, + 750, + 0.9759599566459656 + ] + }, + { + "0": [ + "vertical_and_slash", + 100, + 750, + 0.9991001486778259 + ], + "1": [ + "vertical_and_slash", + 3500, + 100, + 0.9909220933914185 + ], + "2": [ + "vertical_and_slash", + 100, + 750, + 0.9492173194885254 + ], + "3": [ + "vertical_and_slash", + 500, + 700, + 0.9961900115013123 + ], + "4": [ + "vertical_and_slash", + 100, + 750, + 0.8713532090187073 + ], + "5": [ + "vertical_and_slash", + 3500, + 100, + 0.9998117089271545 + ], + "6": [ + "vertical_and_slash", + 3500, + 100, + 0.998404324054718 + ], + "7": [ + "vertical_and_slash", + 500, + 700, + 0.9998683333396912 + ], + "8": [ + "vertical_and_slash", + 100, + 750, + 0.9000679850578308 + ], + "9": [ + "vertical_and_slash", + 500, + 700, + 0.9757489562034607 + ], + "10": [ + "vertical_and_slash", + 500, + 700, + 0.9937180876731873 + ], + "11": [ + "vertical_and_slash", + 100, + 750, + 0.9938128590583801 + ], + "12": [ + "vertical_and_slash", + 3500, + 100, + 0.9991152882575989 + ], + "13": [ + "vertical_and_slash", + 500, + 700, + 0.9996658563613892 + ], + "14": [ + "vertical_and_slash", + 100, + 750, + 0.9966751933097839 + ], + "15": [ + "vertical_and_slash", + 500, + 700, + 0.999762773513794 + ], + "16": [ + "vertical_and_slash", + 3500, + 100, + 0.9985853433609009 + ], + "17": [ + "vertical_and_slash", + 100, + 750, + 0.7297303080558777 + ], + "18": [ + "vertical_and_slash", + 3500, + 100, + 0.9985169768333435 + ], + "19": [ + "vertical_and_slash", + 3500, + 100, + 0.9998067617416382 + ], + "20": [ + "vertical_and_slash", + 100, + 750, + 0.8747161030769348 + ], + "21": [ + "vertical_and_slash", + 3500, + 100, + 0.9923343658447266 + ], + "22": [ + "vertical_and_slash", + 500, + 700, + 0.9940261840820312 + ], + "23": [ + "vertical_and_slash", + 100, + 750, + 0.9998047351837158 + ], + "24": [ + "vertical_and_slash", + 3500, + 100, + 0.9748029112815857 + ], + "25": [ + "vertical_and_slash", + 3500, + 100, + 0.9991946816444397 + ], + "26": [ + "vertical_and_slash", + 100, + 750, + 0.8475115299224854 + ], + "27": [ + "vertical_and_slash", + 100, + 750, + 0.9997408390045166 + ], + "28": [ + "vertical_and_slash", + 500, + 700, + 0.9990043044090271 + ], + "29": [ + "vertical_and_slash", + 100, + 750, + 0.8996012806892395 + ], + "30": [ + "vertical_and_slash", + 100, + 750, + 0.9358092546463013 + ], + "31": [ + "vertical_and_slash", + 3500, + 100, + 0.9944736361503601 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.8248796463012695 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.7729880213737488 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.7208629250526428 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.7399256825447083 + ], + "4": [ + 
"vertical_and_slash", + 1000, + 6096, + 0.7422590851783752 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.949510395526886 + ], + "6": [ + "vertical_and_slash", + 100, + 750, + 0.8432893753051758 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.9128398299217224 + ], + "8": [ + "vertical_and_slash", + 3500, + 100, + 0.8389910459518433 + ], + "9": [ + "vertical_and_slash", + 1000, + 6096, + 0.8408676981925964 + ], + "10": [ + "vertical_and_slash", + 500, + 700, + 0.8649052381515503 + ], + "11": [ + "vertical_and_slash", + 1000, + 6096, + 0.9523816704750061 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.7945832014083862 + ], + "13": [ + "vertical_and_slash", + 3500, + 100, + 0.8890054225921631 + ], + "14": [ + "vertical_and_slash", + 100, + 750, + 0.755133330821991 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.9407838582992554 + ], + "16": [ + "vertical_and_slash", + 1000, + 6096, + 0.978395938873291 + ], + "17": [ + "vertical_and_slash", + 1000, + 6096, + 0.841075599193573 + ], + "18": [ + "vertical_and_slash", + 1000, + 6096, + 0.7965966463088989 + ], + "19": [ + "vertical_and_slash", + 500, + 700, + 0.7598339319229126 + ], + "20": [ + "vertical_and_slash", + 100, + 750, + 0.7436662316322327 + ], + "21": [ + "vertical_and_slash", + 1000, + 6096, + 0.8721699714660645 + ], + "22": [ + "vertical_and_slash", + 100, + 750, + 0.872313916683197 + ], + "23": [ + "vertical_and_slash", + 100, + 750, + 0.9902216792106628 + ], + "24": [ + "vertical_and_slash", + 1000, + 6096, + 0.7798812985420227 + ], + "25": [ + "vertical_and_slash", + 1000, + 6096, + 0.9631245732307434 + ], + "26": [ + "vertical_and_slash", + 100, + 750, + 0.845567524433136 + ], + "27": [ + "vertical_and_slash", + 1000, + 6096, + 0.8043644428253174 + ], + "28": [ + "vertical_and_slash", + 1000, + 6096, + 0.9540744423866272 + ], + "29": [ + "vertical_and_slash", + 1000, + 6096, + 0.9055390357971191 + ], + "30": [ + "vertical_and_slash", + 1000, + 6096, + 0.9507457613945007 + ], + "31": [ + "vertical_and_slash", + 3500, + 100, + 0.827296793460846 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.9915655255317688 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.9915629029273987 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.9245554804801941 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.9129937291145325 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.9958124160766602 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.806128203868866 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.8036627769470215 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.9017021656036377 + ], + "8": [ + "vertical_and_slash", + 1000, + 6096, + 0.9740359783172607 + ], + "9": [ + "vertical_and_slash", + 1000, + 6096, + 0.968455970287323 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.9159662127494812 + ], + "11": [ + "vertical_and_slash", + 1000, + 6096, + 0.8428217172622681 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.9581736326217651 + ], + "13": [ + "vertical_and_slash", + 1000, + 6096, + 0.970429539680481 + ], + "14": [ + "vertical_and_slash", + 1000, + 6096, + 0.9745091199874878 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.894728422164917 + ], + "16": [ + "vertical_and_slash", + 1000, + 6096, + 0.9496183395385742 + ], + "17": [ + "vertical_and_slash", + 1000, + 6096, + 0.9325363039970398 + ], + "18": [ + "vertical_and_slash", + 1000, + 6096, + 0.9315022826194763 + ], + "19": [ + 
"vertical_and_slash", + 1000, + 6096, + 0.8017199635505676 + ], + "20": [ + "vertical_and_slash", + 1000, + 6096, + 0.9268004298210144 + ], + "21": [ + "vertical_and_slash", + 1000, + 6096, + 0.8929623365402222 + ], + "22": [ + "vertical_and_slash", + 1000, + 6096, + 0.8346715569496155 + ], + "23": [ + "vertical_and_slash", + 1000, + 6096, + 0.8660512566566467 + ], + "24": [ + "vertical_and_slash", + 1000, + 6096, + 0.9183820486068726 + ], + "25": [ + "vertical_and_slash", + 100, + 750, + 0.8379555344581604 + ], + "26": [ + "vertical_and_slash", + 1000, + 6096, + 0.911184549331665 + ], + "27": [ + "vertical_and_slash", + 1000, + 6096, + 0.8829504251480103 + ], + "28": [ + "vertical_and_slash", + 1000, + 6096, + 0.9138942956924438 + ], + "29": [ + "vertical_and_slash", + 1000, + 6096, + 0.872784435749054 + ], + "30": [ + "vertical_and_slash", + 1000, + 6096, + 0.9097738862037659 + ], + "31": [ + "vertical_and_slash", + 1000, + 6096, + 0.9275451898574829 + ] + } +] \ No newline at end of file diff --git a/minference/configs/Qwen2.5_3B_flex_0.90.json b/minference/configs/Qwen2.5_3B_flex_0.90.json new file mode 100644 index 0000000..36f48de --- /dev/null +++ b/minference/configs/Qwen2.5_3B_flex_0.90.json @@ -0,0 +1,616 @@ +[ + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0] + }, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": 
["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 
4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": 
["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 
4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": 
["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 
4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": 
["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", 
[0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + 
"6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", 
[0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + 
"15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0] + } +] \ No newline at end of file diff --git a/minference/configs/Qwen2.5_3B_flex_0.95.json b/minference/configs/Qwen2.5_3B_flex_0.95.json new file mode 100644 index 0000000..2cacba4 --- /dev/null +++ b/minference/configs/Qwen2.5_3B_flex_0.95.json @@ -0,0 +1,616 @@ +[ + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0] + }, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 
4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", 
[0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", 
[0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", 
[0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", 
[0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", 
[0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": 
["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": 
["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": 
["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": 
["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": 
["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0] + } +] \ No newline at end of file diff --git a/minference/configs/Qwen2.5_3B_kv_out_v32_fit_o_best_pattern.json b/minference/configs/Qwen2.5_3B_kv_out_v32_fit_o_best_pattern.json new file mode 100644 index 0000000..fdb2922 --- /dev/null +++ b/minference/configs/Qwen2.5_3B_kv_out_v32_fit_o_best_pattern.json @@ -0,0 +1,3530 @@ +[ + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.9872207641601562 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.9784929752349854 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.7595849633216858 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.5381054878234863 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.9863664507865906 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.9912353157997131 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.7160804867744446 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.9073030352592468 + ], + "8": [ + "vertical_and_slash", + 1000, + 6096, + 0.9161261916160583 + ], + "9": [ + "vertical_and_slash", + 1000, + 
6096, + 0.9784228205680847 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.9789554476737976 + ], + "11": [ + "vertical_and_slash", + 1000, + 6096, + 0.7867575883865356 + ], + "12": [ + "vertical_and_slash", + 30, + 800, + 1.0000020265579224 + ], + "13": [ + "vertical_and_slash", + 1000, + 6096, + 0.9307636618614197 + ], + "14": [ + "vertical_and_slash", + 1000, + 6096, + 0.6895971298217773 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.7491968870162964 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.9951784610748291 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.9995638132095337 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.9881429076194763 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.9951925277709961 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.928062379360199 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.9995735883712769 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.7747997045516968 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.9893845319747925 + ], + "8": [ + "vertical_and_slash", + 1000, + 6096, + 0.9328423738479614 + ], + "9": [ + "vertical_and_slash", + 1000, + 6096, + 0.7227432131767273 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.6669939160346985 + ], + "11": [ + "vertical_and_slash", + 1000, + 6096, + 0.955822765827179 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.6157850623130798 + ], + "13": [ + "vertical_and_slash", + 1000, + 6096, + 0.8225603103637695 + ], + "14": [ + "vertical_and_slash", + 1000, + 6096, + 0.6094294786453247 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.7056097388267517 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.8943619728088379 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.6963416337966919 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.9629008173942566 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.866447389125824 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.5282332897186279 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.9982369542121887 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.9979943633079529 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.9979172945022583 + ], + "8": [ + "vertical_and_slash", + 1000, + 6096, + 0.942166268825531 + ], + "9": [ + "vertical_and_slash", + 1000, + 6096, + 0.9923297166824341 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.9751147031784058 + ], + "11": [ + "vertical_and_slash", + 1000, + 6096, + 0.8978350758552551 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.8243312239646912 + ], + "13": [ + "vertical_and_slash", + 1000, + 6096, + 0.9721394181251526 + ], + "14": [ + "vertical_and_slash", + 1000, + 6096, + 0.93731689453125 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.9794054627418518 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.7129239439964294 + ], + "1": [ + "vertical_and_slash", + 30, + 800, + 0.9804909229278564 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.9835291504859924 + ], + "3": [ + "vertical_and_slash", + 30, + 800, + 0.9893701076507568 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.9409563541412354 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.8059223890304565 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.6498631238937378 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.984248697757721 + ], + "8": [ + "vertical_and_slash", + 1000, + 
6096, + 0.7962363362312317 + ], + "9": [ + "vertical_and_slash", + 1000, + 6096, + 0.868658721446991 + ], + "10": [ + "vertical_and_slash", + 3500, + 100, + 0.8754917979240417 + ], + "11": [ + "vertical_and_slash", + 3500, + 100, + 0.8955696821212769 + ], + "12": [ + "vertical_and_slash", + 3500, + 100, + 0.9082641005516052 + ], + "13": [ + "vertical_and_slash", + 1000, + 6096, + 0.8178426623344421 + ], + "14": [ + "vertical_and_slash", + 3500, + 100, + 0.8291682004928589 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.8030994534492493 + ] + }, + { + "0": [ + "vertical_and_slash", + 30, + 800, + 0.9872316122055054 + ], + "1": [ + "vertical_and_slash", + 30, + 800, + 0.9957523345947266 + ], + "2": [ + "vertical_and_slash", + 30, + 800, + 0.9542893171310425 + ], + "3": [ + "vertical_and_slash", + 30, + 800, + 0.9896659255027771 + ], + "4": [ + "vertical_and_slash", + 30, + 800, + 0.9950734376907349 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.8253392577171326 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.8497561812400818 + ], + "7": [ + "vertical_and_slash", + 30, + 800, + 0.9906441569328308 + ], + "8": [ + "vertical_and_slash", + 1000, + 6096, + 0.9813032746315002 + ], + "9": [ + "vertical_and_slash", + 30, + 800, + 0.9744712114334106 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.809856116771698 + ], + "11": [ + "vertical_and_slash", + 1000, + 6096, + 0.9735696911811829 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.8704012036323547 + ], + "13": [ + "vertical_and_slash", + 30, + 800, + 0.965289294719696 + ], + "14": [ + "vertical_and_slash", + 1000, + 6096, + 0.9198070168495178 + ], + "15": [ + "vertical_and_slash", + 30, + 800, + 0.9896268844604492 + ] + }, + { + "0": [ + "vertical_and_slash", + 30, + 800, + 0.9742131233215332 + ], + "1": [ + "vertical_and_slash", + 30, + 800, + 0.9894583821296692 + ], + "2": [ + "vertical_and_slash", + 30, + 800, + 0.9873966574668884 + ], + "3": [ + "vertical_and_slash", + 30, + 800, + 0.9833617210388184 + ], + "4": [ + "vertical_and_slash", + 30, + 800, + 0.9105245471000671 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.734259843826294 + ], + "6": [ + "vertical_and_slash", + 30, + 800, + 0.9877724051475525 + ], + "7": [ + "vertical_and_slash", + 30, + 800, + 0.9896732568740845 + ], + "8": [ + "vertical_and_slash", + 1000, + 6096, + 0.8814374208450317 + ], + "9": [ + "vertical_and_slash", + 1000, + 6096, + 0.9992178678512573 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.9980509877204895 + ], + "11": [ + "vertical_and_slash", + 1000, + 6096, + 0.9092496037483215 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.8476247191429138 + ], + "13": [ + "vertical_and_slash", + 100, + 800, + 1.0 + ], + "14": [ + "vertical_and_slash", + 1000, + 6096, + 0.9594801664352417 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.8697033524513245 + ] + }, + { + "0": [ + "vertical_and_slash", + 30, + 800, + 0.9915944933891296 + ], + "1": [ + "vertical_and_slash", + 100, + 800, + 0.95703125 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.5575676560401917 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.9087390303611755 + ], + "4": [ + "vertical_and_slash", + 100, + 800, + 0.9765625 + ], + "5": [ + "vertical_and_slash", + 30, + 800, + 0.9809911847114563 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.5668226480484009 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.978988766670227 + ], + "8": [ + "vertical_and_slash", + 30, + 800, + 
0.9430385828018188 + ], + "9": [ + "vertical_and_slash", + 100, + 800, + 0.7421875 + ], + "10": [ + "vertical_and_slash", + 30, + 800, + 0.9403589963912964 + ], + "11": [ + "vertical_and_slash", + 30, + 800, + 0.9878966808319092 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.9328376650810242 + ], + "13": [ + "vertical_and_slash", + 30, + 800, + 0.8343550562858582 + ], + "14": [ + "vertical_and_slash", + 30, + 800, + 0.959410548210144 + ], + "15": [ + "vertical_and_slash", + 30, + 800, + 0.9758256673812866 + ] + }, + { + "0": [ + "vertical_and_slash", + 30, + 800, + 0.933525562286377 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.6686467528343201 + ], + "2": [ + "vertical_and_slash", + 30, + 800, + 0.9911661744117737 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.5714142322540283 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.8095535635948181 + ], + "5": [ + "vertical_and_slash", + 30, + 800, + 0.9295337796211243 + ], + "6": [ + "vertical_and_slash", + 30, + 800, + 0.965142011642456 + ], + "7": [ + "vertical_and_slash", + 100, + 800, + 0.984375 + ], + "8": [ + "vertical_and_slash", + 1000, + 6096, + 0.6681234240531921 + ], + "9": [ + "vertical_and_slash", + 1000, + 6096, + 0.8945913910865784 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.6786034107208252 + ], + "11": [ + "vertical_and_slash", + 1000, + 6096, + 0.9802227020263672 + ], + "12": [ + "vertical_and_slash", + 100, + 800, + 0.97265625 + ], + "13": [ + "vertical_and_slash", + 1000, + 6096, + 0.9555180072784424 + ], + "14": [ + "vertical_and_slash", + 1000, + 6096, + 0.9749921560287476 + ], + "15": [ + "vertical_and_slash", + 100, + 800, + 0.9921875 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.8633137345314026 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.9670861959457397 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.8114507794380188 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.9675626158714294 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.9122037291526794 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.9735010862350464 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.9719548225402832 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.9510305523872375 + ], + "8": [ + "vertical_and_slash", + 1000, + 6096, + 0.72588711977005 + ], + "9": [ + "vertical_and_slash", + 1000, + 6096, + 0.9162058234214783 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.924541175365448 + ], + "11": [ + "vertical_and_slash", + 1000, + 6096, + 0.9764450788497925 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.9914652705192566 + ], + "13": [ + "vertical_and_slash", + 1000, + 6096, + 0.7905170321464539 + ], + "14": [ + "vertical_and_slash", + 1000, + 6096, + 0.9880024790763855 + ], + "15": [ + "vertical_and_slash", + 30, + 800, + 0.9844828844070435 + ] + }, + { + "0": [ + "vertical_and_slash", + 30, + 800, + 0.9420890212059021 + ], + "1": [ + "vertical_and_slash", + 30, + 800, + 0.986860454082489 + ], + "2": [ + "vertical_and_slash", + 30, + 800, + 0.9866741299629211 + ], + "3": [ + "vertical_and_slash", + 30, + 800, + 0.7924719452857971 + ], + "4": [ + "vertical_and_slash", + 100, + 800, + 0.890625 + ], + "5": [ + "vertical_and_slash", + 30, + 800, + 0.9799174666404724 + ], + "6": [ + "vertical_and_slash", + 500, + 700, + 0.9916488528251648 + ], + "7": [ + "vertical_and_slash", + 30, + 800, + 0.993992030620575 + ], + "8": [ + "vertical_and_slash", + 1000, + 6096, + 0.9875913262367249 + ], + 
"9": [ + "vertical_and_slash", + 1000, + 6096, + 0.980666995048523 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.9952725172042847 + ], + "11": [ + "vertical_and_slash", + 30, + 800, + 0.9930256605148315 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.9982009530067444 + ], + "13": [ + "vertical_and_slash", + 1000, + 6096, + 0.9283435940742493 + ], + "14": [ + "vertical_and_slash", + 1000, + 6096, + 0.9926015138626099 + ], + "15": [ + "vertical_and_slash", + 30, + 800, + 0.9950297474861145 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.989981472492218 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.9963133335113525 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.9934775233268738 + ], + "3": [ + "vertical_and_slash", + 30, + 800, + 0.9720687866210938 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.957918107509613 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.8732808828353882 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.9542519450187683 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.9934373497962952 + ], + "8": [ + "vertical_and_slash", + 1000, + 6096, + 0.9688447713851929 + ], + "9": [ + "vertical_and_slash", + 3500, + 100, + 0.9413594007492065 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.9796752333641052 + ], + "11": [ + "vertical_and_slash", + 1000, + 6096, + 0.9452784657478333 + ], + "12": [ + "vertical_and_slash", + 3500, + 100, + 0.9403716921806335 + ], + "13": [ + "vertical_and_slash", + 1000, + 6096, + 0.9615768790245056 + ], + "14": [ + "vertical_and_slash", + 500, + 700, + 0.760350227355957 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.8651421666145325 + ] + }, + { + "0": [ + "vertical_and_slash", + 30, + 800, + 0.9908298254013062 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.818602979183197 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.9621649980545044 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.7571771144866943 + ], + "4": [ + "vertical_and_slash", + 500, + 700, + 0.9852563738822937 + ], + "5": [ + "vertical_and_slash", + 30, + 800, + 0.9928317070007324 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.7078589797019958 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.9356410503387451 + ], + "8": [ + "vertical_and_slash", + 1000, + 6096, + 0.8330709338188171 + ], + "9": [ + "vertical_and_slash", + 1000, + 6096, + 0.9761743545532227 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.9746711254119873 + ], + "11": [ + "vertical_and_slash", + 3500, + 100, + 0.9730107188224792 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.9346276521682739 + ], + "13": [ + "vertical_and_slash", + 1000, + 6096, + 0.9564436674118042 + ], + "14": [ + "vertical_and_slash", + 3500, + 100, + 0.9630159735679626 + ], + "15": [ + "vertical_and_slash", + 30, + 800, + 0.9073519706726074 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.8938577175140381 + ], + "1": [ + "vertical_and_slash", + 30, + 800, + 0.9779987931251526 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.9370880126953125 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.8135497570037842 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.9516273736953735 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.874685525894165 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.8642981648445129 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.9647448658943176 + ], + "8": [ + 
"vertical_and_slash", + 3500, + 100, + 0.7395755052566528 + ], + "9": [ + "vertical_and_slash", + 1000, + 6096, + 0.9181990027427673 + ], + "10": [ + "vertical_and_slash", + 3500, + 100, + 0.7171043157577515 + ], + "11": [ + "vertical_and_slash", + 1000, + 6096, + 0.8813474774360657 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.9861239790916443 + ], + "13": [ + "vertical_and_slash", + 1000, + 6096, + 0.9770793318748474 + ], + "14": [ + "vertical_and_slash", + 1000, + 6096, + 0.9610176086425781 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.9338920712471008 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.9416564106941223 + ], + "1": [ + "vertical_and_slash", + 30, + 800, + 0.991210401058197 + ], + "2": [ + "vertical_and_slash", + 30, + 800, + 0.9882159233093262 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.736187219619751 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.8281543254852295 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.9498741030693054 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.9149094820022583 + ], + "7": [ + "vertical_and_slash", + 30, + 800, + 0.9877381920814514 + ], + "8": [ + "vertical_and_slash", + 30, + 800, + 0.963042140007019 + ], + "9": [ + "vertical_and_slash", + 30, + 800, + 0.9886980652809143 + ], + "10": [ + "vertical_and_slash", + 30, + 800, + 0.9890977740287781 + ], + "11": [ + "vertical_and_slash", + 30, + 800, + 0.9927736520767212 + ], + "12": [ + "vertical_and_slash", + 30, + 800, + 0.992214024066925 + ], + "13": [ + "vertical_and_slash", + 30, + 800, + 0.9973360300064087 + ], + "14": [ + "vertical_and_slash", + 500, + 700, + 0.7854554653167725 + ], + "15": [ + "vertical_and_slash", + 30, + 800, + 0.9945513606071472 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.9653095602989197 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.8165479898452759 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.9046997427940369 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.8763824701309204 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.8182893991470337 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.9363685846328735 + ], + "6": [ + "vertical_and_slash", + 100, + 800, + 0.734375 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.9775119423866272 + ], + "8": [ + "vertical_and_slash", + 30, + 800, + 0.9895726442337036 + ], + "9": [ + "vertical_and_slash", + 30, + 800, + 0.9737070202827454 + ], + "10": [ + "vertical_and_slash", + 30, + 800, + 0.996068000793457 + ], + "11": [ + "vertical_and_slash", + 30, + 800, + 0.997230589389801 + ], + "12": [ + "vertical_and_slash", + 30, + 800, + 0.9761516451835632 + ], + "13": [ + "vertical_and_slash", + 30, + 800, + 0.9949500560760498 + ], + "14": [ + "vertical_and_slash", + 30, + 800, + 0.9936166405677795 + ], + "15": [ + "vertical_and_slash", + 30, + 800, + 0.9869099855422974 + ] + }, + { + "0": [ + "vertical_and_slash", + 30, + 800, + 0.9887930154800415 + ], + "1": [ + "vertical_and_slash", + 30, + 800, + 0.9879409670829773 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.957123875617981 + ], + "3": [ + "vertical_and_slash", + 100, + 800, + 0.98828125 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.4647713899612427 + ], + "5": [ + "vertical_and_slash", + 30, + 800, + 0.9909580945968628 + ], + "6": [ + "vertical_and_slash", + 30, + 800, + 0.9757564067840576 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.5521421432495117 + ], + "8": [ + 
"vertical_and_slash", + 1000, + 6096, + 0.7780187129974365 + ], + "9": [ + "vertical_and_slash", + 30, + 800, + 0.9887256026268005 + ], + "10": [ + "vertical_and_slash", + 30, + 800, + 0.9927332401275635 + ], + "11": [ + "vertical_and_slash", + 30, + 800, + 0.9805054664611816 + ], + "12": [ + "vertical_and_slash", + 30, + 800, + 0.9525687098503113 + ], + "13": [ + "vertical_and_slash", + 30, + 800, + 0.9362225532531738 + ], + "14": [ + "vertical_and_slash", + 30, + 800, + 0.9488365054130554 + ], + "15": [ + "vertical_and_slash", + 30, + 800, + 0.9525135159492493 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.9934394955635071 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.9532470703125 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.9188738465309143 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.9849047660827637 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.9228449463844299 + ], + "5": [ + "vertical_and_slash", + 100, + 800, + 0.9765625 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.9707450866699219 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.9929892420768738 + ], + "8": [ + "vertical_and_slash", + 1000, + 6096, + 0.964901864528656 + ], + "9": [ + "vertical_and_slash", + 1000, + 6096, + 0.911367654800415 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.9818339943885803 + ], + "11": [ + "vertical_and_slash", + 1000, + 6096, + 0.9837478399276733 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.9615333080291748 + ], + "13": [ + "vertical_and_slash", + 1000, + 6096, + 0.9666763544082642 + ], + "14": [ + "vertical_and_slash", + 1000, + 6096, + 0.9545288681983948 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.9649417400360107 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.9158878326416016 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.9285635948181152 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.9884502291679382 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.8363761901855469 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.9531059265136719 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.9884499907493591 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.9524633884429932 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.9358732104301453 + ], + "8": [ + "vertical_and_slash", + 1000, + 6096, + 0.8582175374031067 + ], + "9": [ + "vertical_and_slash", + 1000, + 6096, + 0.3922925889492035 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.5313135385513306 + ], + "11": [ + "vertical_and_slash", + 100, + 800, + 0.953125 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.8960450291633606 + ], + "13": [ + "vertical_and_slash", + 100, + 800, + 0.90234375 + ], + "14": [ + "vertical_and_slash", + 1000, + 6096, + 0.6443539261817932 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.829773485660553 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.5914504528045654 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.8983972668647766 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.930306077003479 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.8418871164321899 + ], + "4": [ + "vertical_and_slash", + 100, + 800, + 0.9140625 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.8820360898971558 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.7236220836639404 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.5575733184814453 
+ ], + "8": [ + "vertical_and_slash", + 30, + 800, + 0.9963138699531555 + ], + "9": [ + "vertical_and_slash", + 30, + 800, + 0.9883040189743042 + ], + "10": [ + "vertical_and_slash", + 30, + 800, + 0.9783397912979126 + ], + "11": [ + "vertical_and_slash", + 30, + 800, + 0.9933704733848572 + ], + "12": [ + "vertical_and_slash", + 30, + 800, + 0.9880709648132324 + ], + "13": [ + "vertical_and_slash", + 30, + 800, + 0.9847649931907654 + ], + "14": [ + "vertical_and_slash", + 30, + 800, + 0.9817938804626465 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.9038569927215576 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.9423895478248596 + ], + "1": [ + "vertical_and_slash", + 100, + 800, + 0.99609375 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.9898338913917542 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.8854114413261414 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.9820103049278259 + ], + "5": [ + "vertical_and_slash", + 100, + 800, + 0.98828125 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.6622527837753296 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.5836654901504517 + ], + "8": [ + "vertical_and_slash", + 1000, + 6096, + 0.7537979483604431 + ], + "9": [ + "vertical_and_slash", + 1000, + 6096, + 0.9120598435401917 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.9560258388519287 + ], + "11": [ + "vertical_and_slash", + 100, + 800, + 1.0 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.9937338829040527 + ], + "13": [ + "vertical_and_slash", + 1000, + 6096, + 0.9111098051071167 + ], + "14": [ + "vertical_and_slash", + 1000, + 6096, + 0.9436591863632202 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.9429888129234314 + ] + }, + { + "0": [ + "vertical_and_slash", + 100, + 800, + 0.8515625 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.620391845703125 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.6681154370307922 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.9479513764381409 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.5289033651351929 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.43187281489372253 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.8812884092330933 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.79611736536026 + ], + "8": [ + "vertical_and_slash", + 30, + 800, + 0.9973558783531189 + ], + "9": [ + "vertical_and_slash", + 100, + 800, + 0.734375 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.6380698680877686 + ], + "11": [ + "vertical_and_slash", + 100, + 800, + 0.9609375 + ], + "12": [ + "vertical_and_slash", + 100, + 800, + 0.8359375 + ], + "13": [ + "vertical_and_slash", + 1000, + 6096, + 0.8900502324104309 + ], + "14": [ + "vertical_and_slash", + 100, + 800, + 0.8984375 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.7215483784675598 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.8019538521766663 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.8427147269248962 + ], + "2": [ + "vertical_and_slash", + 3500, + 100, + 0.6292986273765564 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.9548527002334595 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.8857505321502686 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.8712131381034851 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.8540765643119812 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.5264020562171936 + ], + "8": [ + 
"vertical_and_slash", + 1000, + 6096, + 0.8968150615692139 + ], + "9": [ + "vertical_and_slash", + 1000, + 6096, + 0.6485419273376465 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.8069987893104553 + ], + "11": [ + "vertical_and_slash", + 1000, + 6096, + 0.8020429015159607 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.8054234981536865 + ], + "13": [ + "vertical_and_slash", + 1000, + 6096, + 0.725652813911438 + ], + "14": [ + "vertical_and_slash", + 1000, + 6096, + 0.46037647128105164 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.9644275903701782 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.7860593199729919 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.8588574528694153 + ], + "2": [ + "vertical_and_slash", + 3500, + 100, + 0.9157812595367432 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.8626066446304321 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.8797851800918579 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.7836940884590149 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.8759902715682983 + ], + "7": [ + "vertical_and_slash", + 3500, + 100, + 0.8079484701156616 + ], + "8": [ + "vertical_and_slash", + 100, + 800, + 0.7734375 + ], + "9": [ + "vertical_and_slash", + 100, + 800, + 0.80859375 + ], + "10": [ + "vertical_and_slash", + 100, + 800, + 0.91796875 + ], + "11": [ + "vertical_and_slash", + 100, + 800, + 0.93359375 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.5005961656570435 + ], + "13": [ + "vertical_and_slash", + 1000, + 6096, + 0.8044103384017944 + ], + "14": [ + "vertical_and_slash", + 3500, + 100, + 0.6477628946304321 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.5467575192451477 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.9432184100151062 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.9644391536712646 + ], + "2": [ + "vertical_and_slash", + 100, + 800, + 0.83203125 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.9638855457305908 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.9378725290298462 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.8560249209403992 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.9667811989784241 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.9581833481788635 + ], + "8": [ + "vertical_and_slash", + 30, + 800, + 0.9911800622940063 + ], + "9": [ + "vertical_and_slash", + 100, + 800, + 0.7890625 + ], + "10": [ + "vertical_and_slash", + 30, + 800, + 0.9970740675926208 + ], + "11": [ + "vertical_and_slash", + 100, + 800, + 0.9453125 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.8719268441200256 + ], + "13": [ + "vertical_and_slash", + 30, + 800, + 0.9973757863044739 + ], + "14": [ + "vertical_and_slash", + 30, + 800, + 0.9817723631858826 + ], + "15": [ + "vertical_and_slash", + 30, + 800, + 0.9943931102752686 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.7833037376403809 + ], + "1": [ + "vertical_and_slash", + 30, + 800, + 0.985275149345398 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.8808413147926331 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.6342004537582397 + ], + "4": [ + "vertical_and_slash", + 3500, + 100, + 0.6783415675163269 + ], + "5": [ + "vertical_and_slash", + 3500, + 100, + 0.8690900206565857 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.9161447286605835 + ], + "7": [ + "vertical_and_slash", + 500, + 700, + 0.9136660099029541 + ], + "8": [ + 
"vertical_and_slash", + 30, + 800, + 0.9813501834869385 + ], + "9": [ + "vertical_and_slash", + 500, + 700, + 0.9898714423179626 + ], + "10": [ + "vertical_and_slash", + 30, + 800, + 0.9806197285652161 + ], + "11": [ + "vertical_and_slash", + 1000, + 6096, + 0.8513939380645752 + ], + "12": [ + "vertical_and_slash", + 30, + 800, + 0.9925402402877808 + ], + "13": [ + "vertical_and_slash", + 100, + 800, + 0.98046875 + ], + "14": [ + "vertical_and_slash", + 3500, + 100, + 0.9726189374923706 + ], + "15": [ + "vertical_and_slash", + 30, + 800, + 0.9964751601219177 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.8871311545372009 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.6332476735115051 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.6504296660423279 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.8371340036392212 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.5707467198371887 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.49299511313438416 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.6483507752418518 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.9506110548973083 + ], + "8": [ + "vertical_and_slash", + 3500, + 100, + 0.8480803966522217 + ], + "9": [ + "vertical_and_slash", + 100, + 800, + 0.82421875 + ], + "10": [ + "vertical_and_slash", + 3500, + 100, + 0.9129166007041931 + ], + "11": [ + "vertical_and_slash", + 1000, + 6096, + 0.5975048542022705 + ], + "12": [ + "vertical_and_slash", + 3500, + 100, + 0.8599441647529602 + ], + "13": [ + "vertical_and_slash", + 100, + 800, + 0.94140625 + ], + "14": [ + "vertical_and_slash", + 3500, + 100, + 0.7465125918388367 + ], + "15": [ + "vertical_and_slash", + 500, + 700, + 0.9268635511398315 + ] + }, + { + "0": [ + "vertical_and_slash", + 30, + 800, + 0.9838484525680542 + ], + "1": [ + "vertical_and_slash", + 100, + 800, + 0.96875 + ], + "2": [ + "vertical_and_slash", + 500, + 700, + 0.9804697632789612 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.9868910908699036 + ], + "4": [ + "vertical_and_slash", + 100, + 800, + 0.953125 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.8383246660232544 + ], + "6": [ + "vertical_and_slash", + 100, + 800, + 0.97265625 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.7157800793647766 + ], + "8": [ + "vertical_and_slash", + 3500, + 100, + 0.9093014001846313 + ], + "9": [ + "vertical_and_slash", + 100, + 800, + 0.82421875 + ], + "10": [ + "vertical_and_slash", + 3500, + 100, + 0.951353132724762 + ], + "11": [ + "vertical_and_slash", + 100, + 800, + 0.98046875 + ], + "12": [ + "vertical_and_slash", + 3500, + 100, + 0.5795325040817261 + ], + "13": [ + "vertical_and_slash", + 3500, + 100, + 0.9869730472564697 + ], + "14": [ + "vertical_and_slash", + 100, + 800, + 0.953125 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.6627390384674072 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.9301889538764954 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.9623850584030151 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.925979495048523 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.9278422594070435 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.9768559336662292 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.9585739374160767 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.9396020174026489 + ], + "7": [ + "vertical_and_slash", + 3500, + 100, + 0.8997651934623718 + ], + "8": [ + "vertical_and_slash", + 30, + 800, + 
1.000001072883606 + ], + "9": [ + "vertical_and_slash", + 100, + 800, + 0.89453125 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.9783903360366821 + ], + "11": [ + "vertical_and_slash", + 1000, + 6096, + 0.9423955678939819 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.6527382135391235 + ], + "13": [ + "vertical_and_slash", + 1000, + 6096, + 0.9103218913078308 + ], + "14": [ + "vertical_and_slash", + 1000, + 6096, + 0.829939067363739 + ], + "15": [ + "vertical_and_slash", + 3500, + 100, + 0.982806921005249 + ] + }, + { + "0": [ + "vertical_and_slash", + 500, + 700, + 0.9865686893463135 + ], + "1": [ + "vertical_and_slash", + 500, + 700, + 0.9607634544372559 + ], + "2": [ + "vertical_and_slash", + 30, + 800, + 0.9966762065887451 + ], + "3": [ + "vertical_and_slash", + 100, + 800, + 0.75 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.8039277791976929 + ], + "5": [ + "vertical_and_slash", + 500, + 700, + 0.9704546332359314 + ], + "6": [ + "vertical_and_slash", + 3500, + 100, + 0.9315423965454102 + ], + "7": [ + "vertical_and_slash", + 500, + 700, + 0.8952665328979492 + ], + "8": [ + "vertical_and_slash", + 1000, + 6096, + 0.8250367045402527 + ], + "9": [ + "vertical_and_slash", + 1000, + 6096, + 0.953521728515625 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.8773703575134277 + ], + "11": [ + "vertical_and_slash", + 1000, + 6096, + 0.9780517816543579 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.9766848683357239 + ], + "13": [ + "vertical_and_slash", + 1000, + 6096, + 0.9704862236976624 + ], + "14": [ + "vertical_and_slash", + 3500, + 100, + 0.9836187958717346 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.9668617844581604 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.9616305828094482 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.9781763553619385 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.9224084615707397 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.9049281477928162 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.9796241521835327 + ], + "5": [ + "vertical_and_slash", + 3500, + 100, + 0.9083701968193054 + ], + "6": [ + "vertical_and_slash", + 3500, + 100, + 0.9923230409622192 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.9486480951309204 + ], + "8": [ + "vertical_and_slash", + 500, + 700, + 0.9276089668273926 + ], + "9": [ + "vertical_and_slash", + 3500, + 100, + 0.9481350183486938 + ], + "10": [ + "vertical_and_slash", + 3500, + 100, + 0.9749733805656433 + ], + "11": [ + "vertical_and_slash", + 3500, + 100, + 0.9456705451011658 + ], + "12": [ + "vertical_and_slash", + 3500, + 100, + 0.9752734303474426 + ], + "13": [ + "vertical_and_slash", + 100, + 800, + 0.97265625 + ], + "14": [ + "vertical_and_slash", + 30, + 800, + 0.955278754234314 + ], + "15": [ + "vertical_and_slash", + 3500, + 100, + 0.9254546761512756 + ] + }, + { + "0": [ + "vertical_and_slash", + 500, + 700, + 0.9746971726417542 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.9719505310058594 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.8964815735816956 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.8442646265029907 + ], + "4": [ + "vertical_and_slash", + 3500, + 100, + 0.9568673968315125 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.94114750623703 + ], + "6": [ + "vertical_and_slash", + 100, + 800, + 0.96875 + ], + "7": [ + "vertical_and_slash", + 500, + 700, + 0.9370027780532837 + ], + "8": [ + "vertical_and_slash", + 100, + 800, + 
0.91015625 + ], + "9": [ + "vertical_and_slash", + 3500, + 100, + 0.9786306023597717 + ], + "10": [ + "vertical_and_slash", + 3500, + 100, + 0.7740182876586914 + ], + "11": [ + "vertical_and_slash", + 1000, + 6096, + 0.9646586179733276 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.832630455493927 + ], + "13": [ + "vertical_and_slash", + 100, + 800, + 0.94921875 + ], + "14": [ + "vertical_and_slash", + 1000, + 6096, + 0.5829520225524902 + ], + "15": [ + "vertical_and_slash", + 3500, + 100, + 0.9215387105941772 + ] + }, + { + "0": [ + "vertical_and_slash", + 3500, + 100, + 0.9220553040504456 + ], + "1": [ + "vertical_and_slash", + 3500, + 100, + 0.9111120700836182 + ], + "2": [ + "vertical_and_slash", + 3500, + 100, + 0.8448940515518188 + ], + "3": [ + "vertical_and_slash", + 3500, + 100, + 0.9479627013206482 + ], + "4": [ + "vertical_and_slash", + 3500, + 100, + 0.9360203146934509 + ], + "5": [ + "vertical_and_slash", + 3500, + 100, + 0.806208074092865 + ], + "6": [ + "vertical_and_slash", + 3500, + 100, + 0.8499483466148376 + ], + "7": [ + "vertical_and_slash", + 3500, + 100, + 0.9351169466972351 + ], + "8": [ + "vertical_and_slash", + 3500, + 100, + 0.887176513671875 + ], + "9": [ + "vertical_and_slash", + 500, + 700, + 0.9972283244132996 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.7344715595245361 + ], + "11": [ + "vertical_and_slash", + 100, + 750, + 0.9290981888771057 + ], + "12": [ + "vertical_and_slash", + 3500, + 100, + 0.9191303849220276 + ], + "13": [ + "vertical_and_slash", + 3500, + 100, + 0.9607530832290649 + ], + "14": [ + "vertical_and_slash", + 1000, + 6096, + 0.6650398373603821 + ], + "15": [ + "vertical_and_slash", + 3500, + 100, + 0.9041045308113098 + ] + }, + { + "0": [ + "vertical_and_slash", + 3500, + 100, + 0.9288094639778137 + ], + "1": [ + "vertical_and_slash", + 3500, + 100, + 0.922483503818512 + ], + "2": [ + "vertical_and_slash", + 3500, + 100, + 0.8560412526130676 + ], + "3": [ + "vertical_and_slash", + 3500, + 100, + 0.9515807628631592 + ], + "4": [ + "vertical_and_slash", + 3500, + 100, + 0.910995364189148 + ], + "5": [ + "vertical_and_slash", + 3500, + 100, + 0.9611515402793884 + ], + "6": [ + "vertical_and_slash", + 3500, + 100, + 0.9130395650863647 + ], + "7": [ + "vertical_and_slash", + 3500, + 100, + 0.8693947196006775 + ], + "8": [ + "vertical_and_slash", + 3500, + 100, + 0.9321247935295105 + ], + "9": [ + "vertical_and_slash", + 3500, + 100, + 0.7624196410179138 + ], + "10": [ + "vertical_and_slash", + 3500, + 100, + 0.9113157391548157 + ], + "11": [ + "vertical_and_slash", + 3500, + 100, + 0.8822183012962341 + ], + "12": [ + "vertical_and_slash", + 3500, + 100, + 0.940976083278656 + ], + "13": [ + "vertical_and_slash", + 3500, + 100, + 0.9429124593734741 + ], + "14": [ + "vertical_and_slash", + 3500, + 100, + 0.9557855129241943 + ], + "15": [ + "vertical_and_slash", + 3500, + 100, + 0.7963366508483887 + ] + }, + { + "0": [ + "vertical_and_slash", + 3500, + 100, + 0.9815316796302795 + ], + "1": [ + "vertical_and_slash", + 3500, + 100, + 0.9569784998893738 + ], + "2": [ + "vertical_and_slash", + 3500, + 100, + 0.9643719792366028 + ], + "3": [ + "vertical_and_slash", + 3500, + 100, + 0.9694581627845764 + ], + "4": [ + "vertical_and_slash", + 3500, + 100, + 0.9616712927818298 + ], + "5": [ + "vertical_and_slash", + 3500, + 100, + 0.9564018249511719 + ], + "6": [ + "vertical_and_slash", + 3500, + 100, + 0.9240798354148865 + ], + "7": [ + "vertical_and_slash", + 3500, + 100, + 0.9618653059005737 + ], + "8": [ + 
"vertical_and_slash", + 3500, + 100, + 0.9393341541290283 + ], + "9": [ + "vertical_and_slash", + 3500, + 100, + 0.9590299129486084 + ], + "10": [ + "vertical_and_slash", + 3500, + 100, + 0.9623062014579773 + ], + "11": [ + "vertical_and_slash", + 3500, + 100, + 0.9482530355453491 + ], + "12": [ + "vertical_and_slash", + 3500, + 100, + 0.9658593535423279 + ], + "13": [ + "vertical_and_slash", + 3500, + 100, + 0.9724211096763611 + ], + "14": [ + "vertical_and_slash", + 3500, + 100, + 0.9585490822792053 + ], + "15": [ + "vertical_and_slash", + 3500, + 100, + 0.9729295969009399 + ] + }, + { + "0": [ + "vertical_and_slash", + 3500, + 100, + 0.9484567046165466 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.7411888241767883 + ], + "2": [ + "vertical_and_slash", + 3500, + 100, + 0.8387667536735535 + ], + "3": [ + "vertical_and_slash", + 3500, + 100, + 0.7810403108596802 + ], + "4": [ + "vertical_and_slash", + 3500, + 100, + 0.8578725457191467 + ], + "5": [ + "vertical_and_slash", + 3500, + 100, + 0.8868502974510193 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.7327648401260376 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.9077197313308716 + ], + "8": [ + "vertical_and_slash", + 3500, + 100, + 0.9110754728317261 + ], + "9": [ + "vertical_and_slash", + 1000, + 6096, + 0.8923226594924927 + ], + "10": [ + "vertical_and_slash", + 3500, + 100, + 0.7206286191940308 + ], + "11": [ + "vertical_and_slash", + 3500, + 100, + 0.8617163300514221 + ], + "12": [ + "vertical_and_slash", + 3500, + 100, + 0.8827745914459229 + ], + "13": [ + "vertical_and_slash", + 3500, + 100, + 0.7372896075248718 + ], + "14": [ + "vertical_and_slash", + 3500, + 100, + 0.9529178142547607 + ], + "15": [ + "vertical_and_slash", + 3500, + 100, + 0.8600607514381409 + ] + }, + { + "0": [ + "vertical_and_slash", + 3500, + 100, + 0.9317379593849182 + ], + "1": [ + "vertical_and_slash", + 3500, + 100, + 0.9660685658454895 + ], + "2": [ + "vertical_and_slash", + 3500, + 100, + 0.932490885257721 + ], + "3": [ + "vertical_and_slash", + 3500, + 100, + 0.9033127427101135 + ], + "4": [ + "vertical_and_slash", + 3500, + 100, + 0.9372511506080627 + ], + "5": [ + "vertical_and_slash", + 3500, + 100, + 0.9416565299034119 + ], + "6": [ + "vertical_and_slash", + 3500, + 100, + 0.9347256422042847 + ], + "7": [ + "vertical_and_slash", + 3500, + 100, + 0.9515694379806519 + ], + "8": [ + "vertical_and_slash", + 1000, + 6096, + 0.9580468535423279 + ], + "9": [ + "vertical_and_slash", + 500, + 700, + 0.9769356846809387 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.6504788398742676 + ], + "11": [ + "vertical_and_slash", + 3500, + 100, + 0.9479199051856995 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.9721158146858215 + ], + "13": [ + "vertical_and_slash", + 3500, + 100, + 0.9033265113830566 + ], + "14": [ + "vertical_and_slash", + 3500, + 100, + 0.9400415420532227 + ], + "15": [ + "vertical_and_slash", + 100, + 750, + 0.8886585831642151 + ] + } +] \ No newline at end of file diff --git a/minference/dist_ops/__init__.py b/minference/dist_ops/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/minference/dist_ops/dr_striped_attention.py b/minference/dist_ops/dr_striped_attention.py new file mode 100644 index 0000000..242f506 --- /dev/null +++ b/minference/dist_ops/dr_striped_attention.py @@ -0,0 +1,584 @@ +import os +import torch +import torch.distributed as dist + +from typing import List +from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward +from 
.utils import ( + RingComm, update_out_and_lse, get_default_args, + shuffle_striped_input, recover_striped_output +) + +def get_inner_group(): + rank = dist.get_rank() + local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE")) + inner_group = [i for i in range(local_world_size)] + + rank_offset = (rank // local_world_size) * local_world_size + inner_group = [rank_offset + i for i in inner_group] + + return inner_group + +def get_outer_group(): + local_rank = int(os.environ["LOCAL_RANK"]) + local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE")) + world_size = dist.get_world_size() + + outer_group = [] + i = local_rank + while i < world_size: + outer_group.append(i) + i += local_world_size + + return outer_group + +def stripe_fwd_inner( + process_group, + outer_step: int, + outer_rank: int, + inner_ring_list: List[int], + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + softmax_scale, + granularity=1, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + inner_comm = RingComm(process_group, False, inner_ring_list) + inner_rank = int(os.environ["LOCAL_RANK"]) + num_inner_steps = len(inner_ring_list) + + out, lse = None, None + next_k, next_v = None, None + + for inner_step in range(num_inner_steps): + if inner_step + 1 != num_inner_steps: + next_k, next_v = inner_comm.send_recv_kv(k, v) + + def forward(q_, k_, v_, dropout_p_, softmax_scale_, causal_, alibi_slopes_): + params = get_default_args(_flash_attn_forward).copy() + params.update( + { + "q": q_, + "k": k_, + "v": v_, + "dropout_p": dropout_p_, + "softmax_scale": softmax_scale_, + "causal": causal_, + "alibi_slopes": alibi_slopes_, + "return_softmax": True and dropout_p_ > 0, + } + ) + + if "window_size" in params: + params.update({"window_size": window_size}) + else: + params.update( + { + "window_size_left": window_size[0], + "window_size_right": window_size[1], + } + ) + + outputs = _flash_attn_forward(**params) + if len(outputs) == 8: + block_out, _, _, _, _, block_lse, _, _ = outputs + else: + assert len(outputs) == 4 + block_out, block_lse, _, _ = outputs + return block_out, block_lse + + if outer_step == 0 and inner_step > inner_rank: + block_out, block_lse = forward( + q[:, granularity:], k[:, :-granularity], v[:, :-granularity], + dropout_p, softmax_scale, causal, alibi_slopes, + ) + out, lse = update_out_and_lse( + out, lse, block_out, block_lse, slice_=(slice(None), slice(granularity, None)) + ) + else: + block_out, block_lse = forward( + q, k, v, + dropout_p, softmax_scale, causal, alibi_slopes, + ) + out, lse = update_out_and_lse(out, lse, block_out, block_lse) + + if inner_step + 1 != num_inner_steps: + inner_comm.wait() + k, v = next_k, next_v + + out = out.to(q.dtype) + lse = lse.squeeze(dim=-1).transpose(1, 2) + return out, lse + +def stripe_fwd_outer( + process_group, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + softmax_scale, + granularity=1, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + inner_ring_list: List[int]=None, + outer_ring_list: List[int]=None, +): + assert causal, "stripe flash attn only supports causal attention, if not causal, use ring flash attn instead" + outer_comm = RingComm(process_group, False, outer_ring_list) + + global_rank = dist.get_rank() + outer_rank = outer_ring_list.index(global_rank) + num_outer_steps = len(outer_ring_list) + + out = None + lse = None + + next_k, next_v = None, None + for outer_step in range(num_outer_steps): + if outer_step + 1 != 
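get_inner_group and get_outer_group above derive, from the global rank, an intra-node ring (all local ranks on the same node) and an inter-node ring (the same local slot on every node). The standalone sketch below reproduces that arithmetic without torch.distributed so the grouping can be sanity-checked on paper; the function name rings is illustrative, not part of this patch.

# Recompute the two rings from plain integers (no torch.distributed needed).
def rings(rank: int, local_world_size: int, world_size: int):
    node_offset = (rank // local_world_size) * local_world_size
    local_rank = rank % local_world_size
    inner = [node_offset + i for i in range(local_world_size)]      # ranks on this node
    outer = list(range(local_rank, world_size, local_world_size))   # same slot on every node
    return inner, outer

# Example: 2 nodes x 4 GPUs each, global rank 6
assert rings(6, local_world_size=4, world_size=8) == ([4, 5, 6, 7], [2, 6])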
num_outer_steps: + next_k, next_v = outer_comm.send_recv_kv(k, v) + + if outer_step <= outer_rank: + block_out, block_lse = stripe_fwd_inner( + process_group, outer_step, outer_rank, inner_ring_list, + q, k, v, + softmax_scale, + granularity, + dropout_p, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + ) + out, lse = update_out_and_lse(out, lse, block_out, block_lse) + else: + # Before the step index goes beyond the current rank, the received KV indices are not greater than those of the Q in the current rank + # After the step index goes beyond the current rank, only the KV indices before the last granularity are no greater than those of the Q after the first granularity + # this conclusion holds after the step index goes beyond the current rank (not just step index == current rank) + block_out, block_lse = stripe_fwd_inner( + process_group, outer_step, outer_rank, inner_ring_list, + q[:, granularity:], k[:, :-granularity], v[:, :-granularity], + softmax_scale, + granularity, + dropout_p, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + ) + out, lse = update_out_and_lse( + out, lse, block_out, block_lse, slice_=(slice(None), slice(granularity, None)) + ) + + if outer_step + 1 != num_outer_steps: + outer_comm.wait() + k, v = next_k, next_v + + out = out.to(q.dtype) + lse = lse.squeeze(dim=-1).transpose(1, 2) + return out, lse + + +def stripe_backward_inner( + process_group, + outer_step, inner_ring_list, + dout, + q, + k, + v, + out, + softmax_lse, + softmax_scale, + granularity=1, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + assert ( + causal + ), "stripe flash attn only supports causal attention, if not causal, ring flash attn instead" + kv_comm = RingComm(process_group, False, inner_ring_list) + d_kv_comm = RingComm(process_group, False, inner_ring_list) + + inner_rank = int(os.environ["LOCAL_RANK"]) + num_inner_step = len(inner_ring_list) + + dq, dk, dv = None, None, None + next_dk, next_dv = None, None + next_k, next_v = None, None + dk_comm_buffer, dv_comm_buffer = None, None + + block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) + block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) + block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) + for inner_step in range(num_inner_step): + if inner_step + 1 != num_inner_step: + next_k, next_v = kv_comm.send_recv_kv(k, v) + + shift_causal = outer_step == 0 and inner_step > inner_rank + softmax_lse_1 = None + + def backward( + dout_, + q_, k_, v_, out_, softmax_lse_, + block_dq_buffer_, block_dk_buffer_, block_dv_buffer_, + ): + params = get_default_args(_flash_attn_backward).copy() + params.update( + { + "dout": dout_, + "q": q_, + "k": k_, + "v": v_, + "out": out_, + "softmax_lse": softmax_lse_, + "dq": block_dq_buffer_, + "dk": block_dk_buffer_, + "dv": block_dv_buffer_, + "dropout_p": dropout_p, + "softmax_scale": softmax_scale, + "causal": causal, + "alibi_slopes": alibi_slopes, + "deterministic": deterministic, + } + ) + if "window_size" in params: + params.update({"window_size": window_size}) + else: + params.update( + { + "window_size_left": window_size[0], + "window_size_right": window_size[1], + } + ) + _flash_attn_backward(**params) + + if not shift_causal: + backward( + dout, q, k, v, out, softmax_lse, + block_dq_buffer, block_dk_buffer, block_dv_buffer + ) + else: + if softmax_lse_1 is None: + # lazy init, 
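Each ring step above contributes a partial attention output together with its log-sum-exp, and update_out_and_lse (defined in dist_ops/utils.py, not shown in this hunk) folds the pair into the running result. The sketch below is an illustrative torch reimplementation of the standard rescaling, with a brute-force check that merging two key/value halves reproduces full softmax attention; it is not the project's helper and assumes a [B, N, H, 1] lse layout purely for simplicity.

# Illustrative merge of partial attention results; the real update_out_and_lse
# in dist_ops/utils.py may differ in tensor layout and in-place behaviour.
import torch

def merge_partial(out, lse, block_out, block_lse):
    # out/block_out: [B, N, H, D]; lse/block_lse: [B, N, H, 1] (assumed layout)
    if out is None:
        return block_out.float(), block_lse.float()
    new_lse = torch.logaddexp(lse, block_lse)
    out = torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out
    return out, new_lse

def naive_attn(q, k, v):
    s = torch.einsum("bnhd,bmhd->bhnm", q, k)
    o = torch.einsum("bhnm,bmhd->bnhd", s.softmax(-1), v)
    lse = torch.logsumexp(s, dim=-1).transpose(1, 2).unsqueeze(-1)  # -> [B, N, H, 1]
    return o, lse

q = torch.randn(1, 4, 2, 8)
k, v = torch.randn(1, 6, 2, 8), torch.randn(1, 6, 2, 8)
out, lse = None, None
for ks, vs in ((k[:, :3], v[:, :3]), (k[:, 3:], v[:, 3:])):
    out, lse = merge_partial(out, lse, *naive_attn(q, ks, vs))
assert torch.allclose(out, naive_attn(q, k, v)[0], atol=1e-5)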
since the last rank does not need softmax_lse_1 + softmax_lse_1 = softmax_lse[:, :, granularity:].contiguous() + backward( + dout[:, granularity:], + q[:, granularity:], k[:, :-granularity], v[:, :-granularity], + out[:, granularity:], softmax_lse_1, + block_dq_buffer[:, granularity:], block_dk_buffer[:, :-granularity], block_dv_buffer[:, :-granularity] + ) + + if dq is None: + dq = block_dq_buffer.to(torch.float32) + dk = block_dk_buffer.to(torch.float32) + dv = block_dv_buffer.to(torch.float32) + else: + if not shift_causal: + dq += block_dq_buffer + else: + dq[:, granularity:] += block_dq_buffer[:, granularity:] + + d_kv_comm.wait() + dk_comm_buffer, dv_comm_buffer = dk, dv + dk, dv = next_dk, next_dv + + if not shift_causal: + dk = block_dk_buffer + dk + dv = block_dv_buffer + dv + else: + dk[:, :-granularity] += block_dk_buffer[:, :-granularity] + dv[:, :-granularity] += block_dv_buffer[:, :-granularity] + + if inner_step + 1 != num_inner_step: + kv_comm.wait() + k, v = next_k, next_v + + next_dk, next_dv = d_kv_comm.send_recv_kv( + dk, dv, dk_comm_buffer, dv_comm_buffer + ) + + d_kv_comm.wait() + return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) + +def stripe_backward_outer( + process_group, + dout, + q, + k, + v, + out, + softmax_lse, + softmax_scale, + inner_ring_list: List[int], + outer_ring_list: List[int], + granularity=1, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + assert ( + causal + ), "stripe flash attn only supports causal attention, if not causal, ring flash attn instead" + + outer_kv_comm = RingComm(process_group, False, outer_ring_list) + outer_dkv_comm = RingComm(process_group, False, outer_ring_list) + + global_rank = dist.get_rank() + outer_rank = outer_ring_list.index(global_rank) + num_outer_steps = len(outer_ring_list) + + dq, dk, dv = None, None, None + next_dk, next_dv = None, None + next_k, next_v = None, None + dk_comm_buffer, dv_comm_buffer = None, None + + block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) + block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) + block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) + for outer_step in range(num_outer_steps): + if outer_step + 1 != num_outer_steps: + next_k, next_v = outer_kv_comm.send_recv_kv(k, v) + + softmax_lse_1 = None + outer_shift = outer_step > outer_rank + + if not outer_shift: + block_dq_buffer, block_dk_buffer, block_dv_buffer = stripe_backward_inner( + process_group, outer_step, inner_ring_list, + dout, q, k, v, out, + softmax_lse, softmax_scale, granularity, dropout_p, + causal, window_size, alibi_slopes, deterministic, + ) + else: + if softmax_lse_1 is None: + # lazy init, since the last rank does not need softmax_lse_1 + softmax_lse_1 = softmax_lse[:, :, granularity:].contiguous() + block_dq_buffer, block_dk_buffer, block_dv_buffer = stripe_backward_inner( + process_group, outer_step, inner_ring_list, + dout[:, granularity:], + q[:, granularity:], k[:, :-granularity], v[:, :-granularity], out[:, granularity:], + softmax_lse_1, softmax_scale, granularity, dropout_p, + causal, window_size, alibi_slopes, deterministic, + ) + + if dq is None: + dq = block_dq_buffer.to(torch.float32) + dk = block_dk_buffer.to(torch.float32) + dv = block_dv_buffer.to(torch.float32) + else: + if not outer_shift: + dq += block_dq_buffer + else: + dq[:, granularity:] += block_dq_buffer + + outer_dkv_comm.wait() + + dk_comm_buffer, dv_comm_buffer = dk, dv + dk, dv = next_dk, next_dv + + if not 
outer_shift: + dk = block_dk_buffer + dk + dv = block_dv_buffer + dv + else: + dk[:, :-granularity] += block_dk_buffer + dv[:, :-granularity] += block_dv_buffer + + if outer_step + 1 != num_outer_steps: + outer_kv_comm.wait() + k, v = next_k, next_v + + next_dk, next_dv = outer_dkv_comm.send_recv_kv( + dk, dv, dk_comm_buffer, dv_comm_buffer + ) + + outer_dkv_comm.wait() + return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) + +class DRStripeFlashAttnFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + dropout_p, + softmax_scale, + granularity, + causal, + window_size, + alibi_slopes, + deterministic, + return_softmax, + group + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + assert alibi_slopes is None + inner_ring_list = get_inner_group() # ranks in the current node, length = num of cards within this node + outer_ring_list = get_outer_group() # corresponding ranks in other nodes, length = num of nodes + + q = shuffle_striped_input(to_send=q, dim=1, granularity=granularity, process_group=group) + k = shuffle_striped_input(to_send=k, dim=1, granularity=granularity, process_group=group) + v = shuffle_striped_input(to_send=v, dim=1, granularity=granularity, process_group=group) + + out, softmax_lse = stripe_fwd_outer( + group, + q, + k, + v, + softmax_scale=softmax_scale, + granularity=granularity, + dropout_p=dropout_p, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=False, + inner_ring_list=inner_ring_list, + outer_ring_list=outer_ring_list, + ) + # this should be out_padded + ctx.save_for_backward(q, k, v, out, softmax_lse) + ctx.inner_ring_list = inner_ring_list + ctx.outer_ring_list = outer_ring_list + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.granularity = granularity + ctx.causal = causal + ctx.window_size = window_size + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + ctx.group = group + out = recover_striped_output(out, dim=1, granularity=granularity, process_group=group) + if return_softmax: + softmax_lse = recover_striped_output(softmax_lse, dim=2, granularity=granularity, process_group=group) + return (out, softmax_lse, None) + return out + + @staticmethod + def backward(ctx, dout, *args): + dout = shuffle_striped_input(to_send=dout, dim=1, granularity=ctx.granularity, process_group=ctx.group) + q, k, v, out, softmax_lse = ctx.saved_tensors + inner_ring_list, outer_ring_list = ctx.inner_ring_list, ctx.outer_ring_list + dq, dk, dv = stripe_backward_outer( + ctx.group, + dout, + q, + k, + v, + out, + softmax_lse, + ctx.softmax_scale, + inner_ring_list=inner_ring_list, + outer_ring_list=outer_ring_list, + granularity=ctx.granularity, + dropout_p=ctx.dropout_p, + causal=ctx.causal, + window_size=ctx.window_size, + alibi_slopes=ctx.alibi_slopes, + deterministic=ctx.deterministic, + ) + + dq = recover_striped_output(dq, dim=1, granularity=ctx.granularity, process_group=ctx.group) + dk = recover_striped_output(dk, dim=1, granularity=ctx.granularity, process_group=ctx.group) + dv = recover_striped_output(dv, dim=1, granularity=ctx.granularity, process_group=ctx.group) + return dq, dk, dv, None, None, None, None, None, None, None, None, None + + +def dr_stripe_flash_attn_qkvpacked_func( + qkv, # [B, N, 3, H, D] + dropout_p=0.0, + softmax_scale=None, + granularity=1, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + 
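DRStripeFlashAttnFunc.forward above reorders q, k, and v with shuffle_striped_input and undoes it with recover_striped_output (both live in dist_ops/utils.py and are not shown in this hunk). The single-process sketch below illustrates one common striping convention, stripes of `granularity` tokens dealt round-robin across ranks, purely to visualize the layout; the exact convention of the real collective helpers may differ.

# Conceptual, single-process picture of a striped layout; not the repo's helpers.
import torch

def stripe_shards(x: torch.Tensor, world_size: int, granularity: int):
    # x: [seq_len, ...] with seq_len divisible by world_size * granularity
    stripes = x.split(granularity, dim=0)
    return [torch.cat(stripes[r::world_size], dim=0) for r in range(world_size)]

def unstripe(shards, granularity: int):
    world_size = len(shards)
    per_rank = [s.split(granularity, dim=0) for s in shards]
    order = [per_rank[i % world_size][i // world_size]
             for i in range(world_size * len(per_rank[0]))]
    return torch.cat(order, dim=0)

x = torch.arange(16)
shards = stripe_shards(x, world_size=2, granularity=2)
# rank 0 holds tokens [0,1,4,5,8,9,12,13]; rank 1 holds the remaining stripes
assert unstripe(shards, granularity=2).equal(x)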
return DRStripeFlashAttnFunc.apply( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + dropout_p, + softmax_scale, + granularity, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def dr_stripe_flash_attn_kvpacked_func( + q, + kv, + dropout_p=0.0, + softmax_scale=None, + granularity=1, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return DRStripeFlashAttnFunc.apply( + q, + kv[:, :, 0], + kv[:, :, 1], + dropout_p, + softmax_scale, + granularity, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def dr_stripe_flash_attn_func( + q, + k, + v, + dropout_p=0.0, + softmax_scale=None, + granularity=1, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return DRStripeFlashAttnFunc.apply( + q, + k, + v, + dropout_p, + softmax_scale, + granularity, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) diff --git a/minference/dist_ops/minfer_dr_stripe_triton.py b/minference/dist_ops/minfer_dr_stripe_triton.py new file mode 100644 index 0000000..c9b6617 --- /dev/null +++ b/minference/dist_ops/minfer_dr_stripe_triton.py @@ -0,0 +1,404 @@ +import os +import torch +import torch.distributed as dist + +from typing import List, Tuple + +from .utils import ( + RingComm, + shuffle_striped_input, recover_striped_output, + get_inner_ring, get_outer_ring +) +from minference.ops.utils import build_index, convert_blockmask +from minference.ops.minference_attn_triton import block_bar_attn_fwd, block_bar_attn_bwd + +def minfer_dr_stripe_triton_forward_inner( + process_group: dist.ProcessGroup, + outer_step: int, + outer_offset: int, + inner_ring: List[int], + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + out: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + lse: torch.Tensor, # [batch_size, num_qo_heads, num_tokens] + softmax_scale: float, + block_idx: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks, num_blocks] + block_cnt: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks] + bar_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, max_v_size] + bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] + granularity: int = 128, +): + inner_comm = RingComm(process_group, False, inner_ring) + inner_rank = inner_ring.index(inner_comm.rank) + num_inner_steps = len(inner_ring) + + next_k, next_v = None, None + + for inner_step in range(num_inner_steps): + if inner_step + 1 != num_inner_steps: + next_k, next_v = inner_comm.send_recv_kv(k, v) + + block_causal = (outer_step == 0) and (inner_step == 0) + offset = outer_offset * num_inner_steps + (inner_rank - inner_step) % num_inner_steps + + out, lse = block_bar_attn_fwd( + q, k, v, out, lse, softmax_scale, + bar_idx, bar_cnt, block_idx[offset], block_cnt[offset], + granularity=granularity, + step=offset, + causal=block_causal, + ) + + if inner_step + 1 != num_inner_steps: + inner_comm.wait() + k, v = next_k, next_v + + return out, lse + + +def minfer_dr_stripe_triton_forward_outer( + process_group: dist.ProcessGroup, + outer_ring: List[int], + inner_ring: List[int], + q: 
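The packed wrappers above simply slice the fused tensor along dim 2 before dispatching to the autograd function; a quick shape check of that layout, with illustrative sizes:

# Shape check for the qkv-packed layout expected by the *_qkvpacked_func wrappers.
import torch
qkv = torch.randn(2, 16, 3, 4, 8)                    # [batch, tokens, 3, heads, head_dim]
q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
assert q.shape == k.shape == v.shape == (2, 16, 4, 8)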
torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + softmax_scale: float, + block_idx: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks, num_blocks] + block_cnt: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks] + bar_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, max_v_size] + bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] + granularity: int = 128, +): + outer_comm = RingComm(process_group, False, outer_ring) + outer_rank = outer_ring.index(outer_comm.rank) + num_outer_steps = len(outer_ring) + + out = None + lse = None + + next_k, next_v = None, None + for outer_step in range(num_outer_steps): + if outer_step + 1 != num_outer_steps: + next_k, next_v = outer_comm.send_recv_kv(k, v) + + outer_offset = (outer_rank - outer_step) % num_outer_steps + out, lse = minfer_dr_stripe_triton_forward_inner( + process_group, outer_step, outer_offset, inner_ring, + q, k, v, out, lse, softmax_scale, + block_idx, block_cnt, bar_idx, bar_cnt, + granularity, + ) + + if outer_step + 1 != num_outer_steps: + outer_comm.wait() + k, v = next_k, next_v + + # out = out.to(q.dtype) + # lse = lse.squeeze(dim=-1).transpose(1, 2) + return out, lse + + +def minfer_dr_stripe_triton_backward_inner( + process_group: dist.ProcessGroup, + outer_step: int, + outer_offset: int, + inner_ring: List[int], + dout: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + out: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + softmax_lse: torch.Tensor, # [batch_size, num_qo_heads, num_tokens] + softmax_scale: float, + block_idx: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks, num_blocks] + block_cnt: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks] + bar_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, max_v_size] + bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] + granularity: int = 128, +): + inner_kv_comm = RingComm(process_group, False, inner_ring) + inner_d_kv_comm = RingComm(process_group, False, inner_ring) + inner_rank = inner_ring.index(inner_kv_comm.rank) + num_inner_steps = len(inner_ring) + + dq, dk, dv = None, None, None + next_dk, next_dv = None, None + next_k, next_v = None, None + dk_comm_buffer, dv_comm_buffer = None, None + + for inner_step in range(num_inner_steps): + if inner_step + 1 != num_inner_steps: + next_k, next_v = inner_kv_comm.send_recv_kv(k, v) + + block_causal = (outer_step == 0) and (inner_step == 0) + offset = outer_offset * num_inner_steps + (inner_rank - inner_step) % num_inner_steps + + dq, step_dk, step_dv = block_bar_attn_bwd( + dout, q, k, v, out, dq, None, None, + softmax_lse, softmax_scale, + bar_idx, bar_cnt, block_idx[offset], block_cnt[offset], + granularity=granularity, + deterministic=False, + step=offset, + causal=block_causal, + ) + + # Update dQ, dK, dV + if inner_step == 0: + # TODO: check if float32 is necessary + dk = step_dk.to(torch.float32) + dv = step_dv.to(torch.float32) + else: + inner_d_kv_comm.wait() + dk_comm_buffer, dv_comm_buffer = dk, dv + dk, dv = next_dk, next_dv + dk += step_dk + dv += step_dv + + if inner_step + 1 != 
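Both loops above identify the KV shard that is resident at a given step through offset = outer_offset * num_inner_steps + (inner_rank - inner_step) % num_inner_steps, with outer_offset = (outer_rank - outer_step) % num_outer_steps. Assuming each ring hands KV one hop per step (which is what the (rank - step) mod size arithmetic encodes), the enumeration below shows that every rank visits every offset exactly once; the 2x2 layout is purely illustrative.

# Enumerate the `offset` schedule for a hypothetical 2-node x 2-GPU layout.
num_outer, num_inner = 2, 2

for outer_rank in range(num_outer):
    for inner_rank in range(num_inner):
        offsets = []
        for outer_step in range(num_outer):
            outer_offset = (outer_rank - outer_step) % num_outer
            for inner_step in range(num_inner):
                offsets.append(outer_offset * num_inner + (inner_rank - inner_step) % num_inner)
        assert sorted(offsets) == list(range(num_outer * num_inner))
        print(f"rank ({outer_rank},{inner_rank}) sees KV shards in order {offsets}")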
num_inner_steps: + inner_kv_comm.wait() + k, v = next_k, next_v + + next_dk, next_dv = inner_d_kv_comm.send_recv_kv( + dk, dv, dk_comm_buffer, dv_comm_buffer + ) + + inner_d_kv_comm.wait() + return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) + + +def minfer_dr_stripe_triton_backward_outer( + process_group: dist.ProcessGroup, + outer_ring: List[int], + inner_ring: List[int], + dout: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + out: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + softmax_lse: torch.Tensor, # [batch_size, num_qo_heads, num_tokens] + softmax_scale: float, + block_idx: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks, num_blocks] + block_cnt: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks] + bar_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, max_v_size] + bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] + granularity: int = 128, +): + outer_kv_comm = RingComm(process_group, False, outer_ring) + outer_d_kv_comm = RingComm(process_group, False, outer_ring) + outer_rank = outer_ring.index(outer_kv_comm.rank) + num_outer_steps = len(outer_ring) + + dq, dk, dv = None, None, None + next_dk, next_dv = None, None + next_k, next_v = None, None + dk_comm_buffer, dv_comm_buffer = None, None + + for outer_step in range(num_outer_steps): + if outer_step + 1 != num_outer_steps: + next_k, next_v = outer_kv_comm.send_recv_kv(k, v) + + outer_offset = (outer_rank - outer_step) % num_outer_steps + step_dq, step_dk, step_dv = minfer_dr_stripe_triton_backward_inner( + process_group, outer_step, outer_offset, inner_ring, + dout, q, k, v, out, softmax_lse, softmax_scale, + block_idx, block_cnt, bar_idx, bar_cnt, granularity, + ) + + if outer_step == 0: + # TODO: check if float32 is necessary + dq = step_dq.to(torch.float32) + dk = step_dk.to(torch.float32) + dv = step_dv.to(torch.float32) + else: + dq += step_dq + outer_d_kv_comm.wait() + dk_comm_buffer, dv_comm_buffer = dk, dv + dk, dv = next_dk, next_dv + dk += step_dk + dv += step_dv + + if outer_step + 1 != num_outer_steps: + outer_kv_comm.wait() + k, v = next_k, next_v + + next_dk, next_dv = outer_d_kv_comm.send_recv_kv( + dk, dv, dk_comm_buffer, dv_comm_buffer + ) + + outer_d_kv_comm.wait() + return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) + + +class MInferDRStripeTritonFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + v_size, + s_size, + layer_idx, + softmax_scale, + granularity, + return_softmax, + group, + ): + batch_size, num_tokens_local, num_qo_heads, head_dim = q.shape + if softmax_scale is None: + softmax_scale = head_dim ** (-0.5) + + # build index TODO: move convert_indices() into the first step + block_mask, bar_idx, bar_cnt = build_index(q, k, v_size, s_size, num_tokens_local, granularity=granularity, group=group) + block_idx, block_cnt = convert_blockmask(block_mask, block_size_M=granularity, block_size_N=64) + + # TODO: remove shuffle + q = shuffle_striped_input(to_send=q, dim=1, granularity=granularity, process_group=group) + k = shuffle_striped_input(to_send=k, dim=1, granularity=granularity, process_group=group) + v = shuffle_striped_input(to_send=v, dim=1, granularity=granularity, process_group=group) + + inner_ring = 
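build_index above turns the per-head vertical sizes (v_size) and slash sizes (s_size) into block and bar indices that the Triton kernels consume. As a dense, toy-scale picture of what a vertical-and-slash pattern selects, the sketch below keeps a few always-attended key columns plus a few fixed query-key diagonals under a causal constraint; the real index construction in build_index / convert_blockmask works at block granularity and is not reproduced here.

# Dense toy mask for a "vertical_and_slash" pattern; illustration only.
import torch

def vertical_and_slash_mask(seq_len: int, vertical_idx, slash_offsets):
    q = torch.arange(seq_len)[:, None]
    k = torch.arange(seq_len)[None, :]
    mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
    mask[:, list(vertical_idx)] = True        # vertical lines: selected key columns
    for off in slash_offsets:                 # slash lines: keys at a fixed offset behind the query
        mask |= (q - k) == off
    return mask & (k <= q)                    # keep it causal

print(vertical_and_slash_mask(8, vertical_idx=[0, 1], slash_offsets=[0, 1, 2]).int())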
get_inner_ring(group) + outer_ring = get_outer_ring(group) + out, softmax_lse = minfer_dr_stripe_triton_forward_outer( + group, outer_ring, inner_ring, + q, k, v, softmax_scale, + block_idx, block_cnt, bar_idx, bar_cnt, granularity, + ) + + # this should be out_padded + ctx.save_for_backward(q, k, v, out, softmax_lse, block_idx, block_cnt, bar_idx, bar_cnt) + ctx.softmax_scale = softmax_scale + ctx.granularity = granularity + ctx.group = group + ctx.inner_ring = inner_ring + ctx.outer_ring = outer_ring + ctx.layer_idx = layer_idx + + out = recover_striped_output(out, dim=1, granularity=granularity, process_group=group) + if return_softmax: + softmax_lse = recover_striped_output(softmax_lse, dim=2, granularity=granularity, process_group=group) + return (out, softmax_lse, None) + return out + + @staticmethod + def backward(ctx, dout, *args): + dout = shuffle_striped_input(to_send=dout, dim=1, granularity=ctx.granularity, process_group=ctx.group) + q, k, v, out, softmax_lse, block_idx, block_cnt, bar_idx, bar_cnt = ctx.saved_tensors + + dq, dk, dv = minfer_dr_stripe_triton_backward_outer( + ctx.group, ctx.outer_ring, ctx.inner_ring, + dout, q, k, v, out, softmax_lse, ctx.softmax_scale, + block_idx, block_cnt, bar_idx, bar_cnt, ctx.granularity, + ) + dq = recover_striped_output(dq, dim=1, granularity=ctx.granularity, process_group=ctx.group) + dk = recover_striped_output(dk, dim=1, granularity=ctx.granularity, process_group=ctx.group) + dv = recover_striped_output(dv, dim=1, granularity=ctx.granularity, process_group=ctx.group) + return dq, dk, dv, None, None, None, None, None, None, None + + +def minfer_dr_stripe_triton_qkvpacked_func( + qkv: torch.Tensor, # [batch_size, num_tokens, 3, num_heads, head_dim] + v_size: List[int], # [num_heads] + s_size: List[int], # [num_heads] + layer_idx, + dropout_p: float = 0.0, + softmax_scale: float = None, + granularity: int = 128, + causal: bool = False, + window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window + alibi_slopes: Tuple[int, int] = None, + deterministic: bool = False, + return_attn_probs: bool = False, + group: dist.ProcessGroup = None, +): + assert causal + assert dropout_p == 0 + assert window_size == (-1, -1) + assert alibi_slopes is None + assert not deterministic + return MInferDRStripeTritonFunc.apply( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + v_size, + s_size, + layer_idx, + softmax_scale, + granularity, + return_attn_probs, + group, + ) + + +def minfer_dr_stripe_triton_kvpacked_func( + q: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] + kv: torch.Tensor, # [batch_size, num_tokens, 2, num_heads, head_dim] + v_size: List[int], # [num_heads] + s_size: List[int], # [num_heads] + layer_idx, + dropout_p: float = 0.0, + softmax_scale: float = None, + granularity: int = 128, + causal: bool = False, + window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window + alibi_slopes: Tuple[int, int] = None, + deterministic: bool = False, + return_attn_probs: bool = False, + group: dist.ProcessGroup = None, +): + assert causal + assert dropout_p == 0 + assert window_size == (-1, -1) + assert alibi_slopes is None + assert not deterministic + return MInferDRStripeTritonFunc.apply( + q, + kv[:, :, 0], + kv[:, :, 1], + v_size, + s_size, + layer_idx, + softmax_scale, + granularity, + return_attn_probs, + group, + ) + + +def minfer_dr_stripe_triton_func( + q: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] + v: 
torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] + v_size: List[int], # [num_heads] + s_size: List[int], # [num_heads] + layer_idx, + dropout_p: float = 0.0, + softmax_scale: float = None, + granularity: int = 128, + causal: bool = False, + window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window + alibi_slopes: Tuple[int, int] = None, + deterministic: bool = False, + return_attn_probs: bool = False, + group: dist.ProcessGroup = None, +): + assert causal + assert dropout_p == 0 + assert window_size == (-1, -1) + assert alibi_slopes is None + assert not deterministic + return MInferDRStripeTritonFunc.apply( + q, + k, + v, + v_size, + s_size, + layer_idx, + softmax_scale, + granularity, + return_attn_probs, + group, + ) diff --git a/minference/dist_ops/minfer_dr_striped.py b/minference/dist_ops/minfer_dr_striped.py new file mode 100644 index 0000000..c404c88 --- /dev/null +++ b/minference/dist_ops/minfer_dr_striped.py @@ -0,0 +1,471 @@ +import os +import torch +import torch.distributed as dist + +from typing import List, Tuple, Dict +from .utils import ( + RingComm, + shuffle_striped_input, recover_striped_output, + get_inner_ring, get_outer_ring +) + +from minference.ops.minference_attn_triton import block_bar_attn_fwd +from minference.ops.minference_attn import block_attn_bwd, bar_attn_bwd +from minference.ops.utils import build_index, extract_kv, merge_kv, convert_blockmask + + +def minfer_dr_stripe_forward_inner( + process_group: dist.ProcessGroup, + outer_step: int, + outer_offset: int, + outer_rank: int, + inner_ring: List[int], + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + out: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + lse: torch.Tensor, # [batch_size, num_qo_heads, num_tokens] + layer_idx: int, + softmax_scale: float, + block_mask: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks, num_blocks] + bar_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, max_v_size] + bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] + v_idx: torch.Tensor, # [batch_size, num_qo_heads, max_v_size] + v_cnt: torch.Tensor, # [batch_size, num_qo_heads, world_size + 1] + bar_k: torch.Tensor, # [batch_size, max_v_size, num_qo_heads, head_dim] + bar_v: torch.Tensor, # [batch_size, max_v_size, num_qo_heads, head_dim] + granularity: int = 128, +): + inner_comm = RingComm(process_group, False, inner_ring) + inner_rank = inner_ring.index(inner_comm.rank) + num_inner_steps = len(inner_ring) + + next_k, next_v = None, None + block_idx, block_cnt = convert_blockmask(block_mask, block_size_M=granularity, block_size_N=64) + + for inner_step in range(num_inner_steps): + if inner_step + 1 != num_inner_steps: + next_k, next_v = inner_comm.send_recv_kv(k, v) + + block_causal = (outer_step == 0) and (inner_step == 0) + inner_offset = (inner_rank - inner_step) % num_inner_steps + offset = outer_offset * num_inner_steps + inner_offset + + out, lse = block_bar_attn_fwd( + q, k, v, out, lse, softmax_scale, + bar_idx, bar_cnt, block_idx[inner_offset], block_cnt[inner_offset], + granularity=granularity, + step=offset, + causal=block_causal, + ) + extract_kv(k, v, bar_k, bar_v, v_idx, v_cnt, step=offset) + + if inner_step + 1 != num_inner_steps: + inner_comm.wait() + k, v = next_k, next_v + + return out, lse + + +def minfer_dr_stripe_forward_outer( + 
process_group: dist.ProcessGroup, + outer_ring: List[int], + inner_ring: List[int], + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + layer_idx: int, + softmax_scale: float, + block_mask: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks, num_blocks] + bar_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, max_v_size] + bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] + v_idx: torch.Tensor, # [batch_size, num_qo_heads, max_v_size] + v_cnt: torch.Tensor, # [batch_size, num_qo_heads, world_size + 1] + granularity: int = 128, +): + outer_comm = RingComm(process_group, False, outer_ring) + outer_rank = outer_ring.index(outer_comm.rank) + num_outer_steps = len(outer_ring) + + out, lse = None, None + next_k, next_v = None, None + inner_block_masks = block_mask.chunk(num_outer_steps, dim=0) + + batch_size, _, num_qo_heads, head_dim = q.shape + max_v_size = v_idx.shape[-1] + bar_k = torch.empty((batch_size, max_v_size, num_qo_heads, head_dim), dtype=q.dtype, device=q.device) + bar_v = torch.empty((batch_size, max_v_size, num_qo_heads, head_dim), dtype=q.dtype, device=q.device) + + for outer_step in range(num_outer_steps): + if outer_step + 1 != num_outer_steps: + next_k, next_v = outer_comm.send_recv_kv(k, v) + outer_offset = (outer_rank - outer_step) % num_outer_steps + + out, lse = minfer_dr_stripe_forward_inner( + process_group, + outer_step, outer_offset, outer_rank, inner_ring, + q, k, v, out, lse, layer_idx, softmax_scale, + inner_block_masks[outer_offset], bar_idx, bar_cnt, + v_idx, v_cnt, bar_k, bar_v, + granularity, + ) + + if outer_step + 1 != num_outer_steps: + outer_comm.wait() + k, v = next_k, next_v + + out = out.to(q.dtype) + return out, lse, bar_k, bar_v + + +def minfer_dr_stripe_backward_inner( + process_group: dist.ProcessGroup, + outer_step: int, + outer_offset: int, + outer_rank: int, + inner_ring: List[int], + dout: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + out: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + softmax_lse: torch.Tensor, # [batch_size, num_qo_heads, num_tokens] + layer_idx: int, + softmax_scale: float, + block_mask: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks, num_blocks] + v_idx: torch.Tensor, # [batch_size, num_qo_heads, max_v_size] + v_cnt: torch.Tensor, # [batch_size, num_qo_heads, world_size + 1] + dq: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + bar_dk: torch.Tensor, # [batch_size, max_v_size, num_qo_heads, head_dim] + bar_dv: torch.Tensor, # [batch_size, max_v_size, num_qo_heads, head_dim] + granularity: int = 128, +): + inner_kv_comm = RingComm(process_group, False, inner_ring) + inner_d_kv_comm = RingComm(process_group, False, inner_ring) + inner_rank = inner_ring.index(inner_kv_comm.rank) + num_inner_steps = len(inner_ring) + + dk, dv = None, None + next_dk, next_dv = None, None + next_k, next_v = None, None + dk_comm_buffer, dv_comm_buffer = None, None + + for inner_step in range(num_inner_steps): + if inner_step + 1 != num_inner_steps: + next_k, next_v = inner_kv_comm.send_recv_kv(k, v) + block_causal = (outer_step == 0) and (inner_step == 0) + 
offset = outer_offset * num_inner_steps + (inner_rank - inner_step) % num_inner_steps + + # Block Mask + step_dq, step_dk, step_dv = block_attn_bwd( + dout, q, k, v, out, + softmax_lse, softmax_scale, + block_mask[offset], + granularity=granularity, + deterministic=False, + causal=block_causal, + ) + + # Update dQ, dK, dV + if inner_step == 0: + # TODO: check if float32 is necessary + dk = step_dk.to(torch.float32) + dv = step_dv.to(torch.float32) + else: + inner_d_kv_comm.wait() + dk_comm_buffer, dv_comm_buffer = dk, dv + dk, dv = next_dk, next_dv + + dk += step_dk + dv += step_dv + dq += step_dq + merge_kv(dk, dv, bar_dk, bar_dv, v_idx, v_cnt, step=offset) + if inner_step + 1 != num_inner_steps: + inner_kv_comm.wait() + k, v = next_k, next_v + + next_dk, next_dv = inner_d_kv_comm.send_recv_kv( + dk, dv, dk_comm_buffer, dv_comm_buffer + ) + + inner_d_kv_comm.wait() + return dq, next_dk, next_dv + + +def minfer_dr_stripe_backward_outer( + process_group: dist.ProcessGroup, + outer_ring: List[int], + inner_ring: List[int], + dout: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + out: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + softmax_lse: torch.Tensor, # [batch_size, num_qo_heads, num_tokens] + layer_idx: int, + softmax_scale: float, + block_mask: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks, num_blocks] + bar_pos: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, max_v_size] + bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] + v_idx: torch.Tensor, # [batch_size, num_qo_heads, max_v_size] + v_cnt: torch.Tensor, # [batch_size, num_qo_heads, world_size + 1] + bar_k: torch.Tensor, # [batch_size, max_v_size, num_qo_heads, head_dim] + bar_v: torch.Tensor, # [batch_size, max_v_size, num_qo_heads, head_dim] + granularity: int = 128, +): + outer_kv_comm = RingComm(process_group, False, outer_ring) + outer_d_kv_comm = RingComm(process_group, False, outer_ring) + outer_rank = outer_ring.index(outer_kv_comm.rank) + num_outer_steps = len(outer_ring) + + dq, dk, dv = None, None, None + next_k, next_v = None, None + next_dk, next_dv = None, None + dk_comm_buffer, dv_comm_buffer = None, None + + # Bar Mask + full_bar_cnt = torch.stack([bar_cnt[..., 0], bar_cnt[..., -1]], dim=-1) + dq, bar_dk, bar_dv = bar_attn_bwd( + dout, q, bar_k, bar_v, out, None, None, None, + softmax_lse, softmax_scale, + bar_pos, full_bar_cnt, + granularity=granularity, + deterministic=False, + step=0, + ) + + for outer_step in range(num_outer_steps): + if outer_step + 1 != num_outer_steps: + next_k, next_v = outer_kv_comm.send_recv_kv(k, v) + outer_offset = (outer_rank - outer_step) % num_outer_steps + + dq, step_dk, step_dv = minfer_dr_stripe_backward_inner( + process_group, outer_step, outer_offset, outer_rank, inner_ring, + dout, q, k, v, out, softmax_lse, + layer_idx, softmax_scale, + block_mask, v_idx, v_cnt, + dq, bar_dk, bar_dv, + granularity, + ) + + if outer_step == 0: + # TODO: check if float32 is necessary + dk, dv = step_dk, step_dv + else: + outer_d_kv_comm.wait() + dk_comm_buffer, dv_comm_buffer = dk, dv + dk, dv = next_dk, next_dv + + dk += step_dk + dv += step_dv + + if outer_step + 1 != num_outer_steps: + outer_kv_comm.wait() + k, v = next_k, next_v + + next_dk, next_dv = outer_d_kv_comm.send_recv_kv( + dk, dv, 
dk_comm_buffer, dv_comm_buffer + ) + + outer_d_kv_comm.wait() + return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) + + +class MInferDRStripeFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + v_size, + s_size, + layer_idx, + softmax_scale, + granularity, + return_softmax, + group, + ): + batch_size, num_tokens_local, num_qo_heads, head_dim = q.shape + if softmax_scale is None: softmax_scale = head_dim ** (-0.5) + inner_ring = get_inner_ring(group) + outer_ring = get_outer_ring(group) + + # ---------------------------------------------- + # Index Build + block_mask, bar_idx, bar_cnt, bar_pos, v_idx, v_cnt = build_index( + q, k, v_size, s_size, num_tokens_local, + granularity=granularity, group=group + ) + + # ---------------------------------------------- + # Shuffle + q = shuffle_striped_input(to_send=q, dim=1, granularity=granularity, process_group=group) + k = shuffle_striped_input(to_send=k, dim=1, granularity=granularity, process_group=group) + v = shuffle_striped_input(to_send=v, dim=1, granularity=granularity, process_group=group) + + # ---------------------------------------------- + # Compute + out, softmax_lse, bar_k, bar_v = minfer_dr_stripe_forward_outer( + group, outer_ring, inner_ring, + q, k, v, layer_idx, + softmax_scale, + block_mask, bar_idx, bar_cnt, v_idx, v_cnt, + granularity, + ) + + # ---------------------------------------------- + # Recover + recovered_out = recover_striped_output(out, dim=1, granularity=granularity, process_group=group) + if return_softmax: + recovered_softmax_lse = recover_striped_output(softmax_lse, dim=2, granularity=granularity, process_group=group) + + # ---------------------------------------------- + # Saving tensors for backward + ctx.save_for_backward(q, k, v, out, softmax_lse, block_mask, bar_pos, bar_cnt, v_idx, v_cnt, bar_k, bar_v) + ctx.softmax_scale = softmax_scale + ctx.granularity = granularity + ctx.group = group + ctx.inner_ring = inner_ring + ctx.outer_ring = outer_ring + ctx.layer_idx = layer_idx + + # ---------------------------------------------- + # Output and Return + if return_softmax: + return (recovered_out, recovered_softmax_lse, None) + return recovered_out + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse, block_mask, bar_pos, bar_cnt, v_idx, v_cnt, bar_k, bar_v = ctx.saved_tensors + inner_ring = ctx.inner_ring + layer_idx = ctx.layer_idx + group = ctx.group + + # ---------------------------------------------- + # Shuffle + dout = shuffle_striped_input(to_send=dout, dim=1, granularity=ctx.granularity, process_group=ctx.group) + + # ---------------------------------------------- + # Compute + dq, dk, dv = minfer_dr_stripe_backward_outer( + ctx.group, ctx.outer_ring, ctx.inner_ring, + dout, q, k, v, out, softmax_lse, + layer_idx, ctx.softmax_scale, + block_mask, bar_pos, bar_cnt, v_idx, v_cnt, bar_k, bar_v, + granularity=ctx.granularity, + ) + + # ---------------------------------------------- + # Recover + dq = recover_striped_output(dq, dim=1, granularity=ctx.granularity, process_group=ctx.group) + dk = recover_striped_output(dk, dim=1, granularity=ctx.granularity, process_group=ctx.group) + dv = recover_striped_output(dv, dim=1, granularity=ctx.granularity, process_group=ctx.group) + + return dq, dk, dv, None, None, None, None, None, None, None + + +def minfer_dr_stripe_qkvpacked_func( + qkv: torch.Tensor, # [batch_size, num_tokens, 3, num_heads, head_dim] + v_size: List[int], # [num_heads] + s_size: List[int], # [num_heads] + layer_idx: int 
= 0, + dropout_p: float = 0.0, + softmax_scale: float = None, + granularity: int = 128, + causal: bool = True, + window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window + alibi_slopes: Tuple[int, int] = None, + deterministic: bool = False, + return_attn_probs: bool = False, + group: dist.ProcessGroup = None, +): + assert causal + assert dropout_p == 0 + assert window_size == (-1, -1) + assert alibi_slopes is None + assert not deterministic + return MInferDRStripeFunc.apply( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + v_size, + s_size, + layer_idx, + softmax_scale, + granularity, + return_attn_probs, + group, + ) + + +def minfer_dr_stripe_kvpacked_func( + q: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] + kv: torch.Tensor, # [batch_size, num_tokens, 2, num_heads, head_dim] + v_size: List[int], # [num_heads] + s_size: List[int], # [num_heads] + layer_idx: int = 0, + dropout_p: float = 0.0, + softmax_scale: float = None, + granularity: int = 128, + causal: bool = True, + window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window + alibi_slopes: Tuple[int, int] = None, + deterministic: bool = False, + return_attn_probs: bool = False, + group: dist.ProcessGroup = None, +): + assert causal + assert dropout_p == 0 + assert window_size == (-1, -1) + assert alibi_slopes is None + assert not deterministic + return MInferDRStripeFunc.apply( + q, + kv[:, :, 0], + kv[:, :, 1], + v_size, + s_size, + layer_idx, + softmax_scale, + granularity, + return_attn_probs, + group, + ) + + +def minfer_dr_stripe_func( + q: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] + v_size: List[int], # [num_heads] + s_size: List[int], # [num_heads] + layer_idx: int = 0, + dropout_p: float = 0.0, + softmax_scale: float = None, + granularity: int = 128, + causal: bool = True, + window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window + alibi_slopes: Tuple[int, int] = None, + deterministic: bool = False, + return_attn_probs: bool = False, + group: dist.ProcessGroup = None, +): + assert causal + assert dropout_p == 0 + assert window_size == (-1, -1) + assert alibi_slopes is None + assert not deterministic + return MInferDRStripeFunc.apply( + q, + k, + v, + v_size, + s_size, + layer_idx, + softmax_scale, + granularity, + return_attn_probs, + group, + ) diff --git a/minference/dist_ops/minfer_striped.py b/minference/dist_ops/minfer_striped.py new file mode 100644 index 0000000..6fb6fcd --- /dev/null +++ b/minference/dist_ops/minfer_striped.py @@ -0,0 +1,411 @@ +import os +import sys +import torch +import triton +import torch.distributed as dist +from typing import List, Tuple, Dict + +from .utils import ( + RingComm, + shuffle_striped_input, recover_striped_output, +) +from minference.ops.minference_attn_triton import block_bar_attn_fwd +from minference.ops.minference_attn import block_attn_bwd, bar_attn_bwd +from minference.ops.utils import build_index, convert_blockmask, extract_kv, merge_kv + +if torch.version.hip is None: + original_flags = sys.getdlopenflags() + try: + sys.setdlopenflags(os.RTLD_LAZY | os.RTLD_GLOBAL) + import block_sparse_attn_cuda # type: ignore + from block_sparse_attn.block_sparse_attn_interface import convert_blockmask_row_reverse, convert_blockmask_col_reverse # type: ignore + # NOTE: Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_blockmask.h: add head_idx to blockmask_ptr 
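+ # The dlopen-flag juggling above is presumably so the compiled block_sparse_attn_cuda extension is
+ # imported with RTLD_GLOBAL (its symbols stay visible to other CUDA extensions loaded afterwards) and
+ # lazy binding; the original flags are restored in the finally block right below so later imports are unaffected.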
+ finally: + # Restore original flags for future imports + sys.setdlopenflags(original_flags) + # NOTE: Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_blockmask.h: add head_idx to blockmask_ptr + +def compute_sr_flops( + block_mask_offset: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, num_blocks] + bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] + step: int, + granularity: int, + q_len: int, + head_dim: int, + shift: bool, + fwd: bool=True, +): + num_blocks = triton.cdiv(q_len, granularity) + bh = block_mask_offset.shape[0] * block_mask_offset.shape[1] + bar_cnt_step = (bar_cnt[..., step + 1] - bar_cnt[..., step]).sum(dtype=torch.float32).item() + + total_num_blocks = bh * num_blocks * (num_blocks - 1) / 2 + if step == 0: + total_num_blocks += bh * num_blocks / 2 + elif not shift: + total_num_blocks += bh * num_blocks + + if step == 0: + num_active_blocks = block_mask_offset.sum(dim=-1).sum(dtype=torch.float32).item() - bh * num_blocks / 2 + elif not shift: + num_active_blocks = block_mask_offset.sum(dtype=torch.float32).item() + else: + num_active_blocks = block_mask_offset[..., 1:, :-1].sum(dtype=torch.float32).item() + block_ratio = num_active_blocks / total_num_blocks + bar_ratio = bar_cnt_step / (granularity * total_num_blocks) + sparsity_ratio = 1 - block_ratio - bar_ratio + + block_flops = num_active_blocks * (granularity * granularity) * head_dim * 2 * 2 + bar_flops = bar_cnt_step * granularity * head_dim * 2 * 2 + flops = block_flops + bar_flops + + if not fwd: + flops, block_flops, bar_flops = 2.5 * flops, 2.5 * block_flops, 2.5 * bar_flops + # STEP_DATA_FIELDS = ["block_ratio", "bar_ratio", "sparsity_ratio", "blk_flops", "bar_flops", "flops"] + return block_ratio, bar_ratio, sparsity_ratio, block_flops, bar_flops, flops + + +def compute_sr_by_heads( + block_mask_offset: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, num_blocks] + bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] + step: int, + granularity: int, + q_len: int, + head_dim: int, + shift: bool, + fwd: bool=True, +): + batch_size, num_heads = block_mask_offset.shape[0], block_mask_offset.shape[1] + num_blocks = triton.cdiv(q_len, granularity) + bar_cnt_step = (bar_cnt[..., step + 1] - bar_cnt[..., step]).sum(dim=-1).sum(dim=-1).sum(0, dtype=torch.float32) # [num_qo_heads] + + total_num_blocks = num_blocks * (num_blocks - 1) / 2 + if step == 0: + total_num_blocks += num_blocks / 2 + elif not shift: + total_num_blocks += num_blocks + total_num_blocks_by_heads = torch.tensor([total_num_blocks for _ in range(num_heads)], dtype=torch.float32).to(block_mask_offset.device) + + if step == 0: + num_active_blocks = block_mask_offset.sum(dim=-1).sum(dim=-1).sum(0, dtype=torch.float32) - batch_size * num_blocks / 2 + elif not shift: + num_active_blocks = block_mask_offset.sum(dim=-1).sum(dim=-1).sum(0, dtype=torch.float32) + else: + num_active_blocks = block_mask_offset[..., 1:, :-1].sum(dim=-1).sum(dim=-1).sum(0, dtype=torch.float32) + block_ratio_by_heads = num_active_blocks / total_num_blocks_by_heads + bar_ratio_by_heads = bar_cnt_step / total_num_blocks_by_heads / granularity + sparsity_ratio_by_heads = 1 - block_ratio_by_heads - bar_ratio_by_heads + + return sparsity_ratio_by_heads.detach().cpu().numpy().tolist() + + +def sparse_stripe_flash_attn_forward( + process_group: dist.ProcessGroup, + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: 
torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + layer_idx: int, + softmax_scale: float, + block_mask: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks, num_blocks] + bar_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, max_v_size] + bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] + v_idx: torch.Tensor, # [batch_size, num_qo_heads, max_v_size] + v_cnt: torch.Tensor, # [batch_size, num_qo_heads, world_size + 1] + granularity: int = 128, +): + comm = RingComm(process_group, zigzag=False) + + out, lse = None, None + block_idx, block_cnt = convert_blockmask(block_mask, block_size_M=granularity, block_size_N=64) + + batch_size, _, num_qo_heads, head_dim = q.shape + max_v_size = v_idx.shape[-1] + bar_k = torch.empty((batch_size, max_v_size, num_qo_heads, head_dim), dtype=q.dtype, device=q.device) + bar_v = torch.empty((batch_size, max_v_size, num_qo_heads, head_dim), dtype=q.dtype, device=q.device) + + for step in range(comm.world_size): + if step + 1 != comm.world_size: + next_k, next_v = comm.send_recv_kv(k, v) + block_causal = step == 0 + offset = (comm.rank - step) % comm.world_size + + out, lse = block_bar_attn_fwd( + q, k, v, out, lse, softmax_scale, + bar_idx, bar_cnt, block_idx[offset], block_cnt[offset], + granularity=granularity, + step=offset, + causal=block_causal, + ) + extract_kv(k, v, bar_k, bar_v, v_idx, v_cnt, step=offset) + + if step + 1 != comm.world_size: + comm.wait() + k, v = next_k, next_v + + out = out.to(q.dtype) + # lse = lse.squeeze(dim=-1).transpose(1, 2) + return out, lse, bar_k, bar_v + +def sparse_stripe_flash_attn_backward( + process_group: dist.ProcessGroup, + dout: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + out: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + softmax_lse: torch.Tensor, # [batch_size, num_qo_heads, num_tokens] + layer_idx: int, + softmax_scale: float, + block_mask: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks, num_blocks] + bar_pos: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, max_v_size] + bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] + v_idx: torch.Tensor, # [batch_size, num_qo_heads, max_v_size] + v_cnt: torch.Tensor, # [batch_size, num_qo_heads, world_size + 1] + bar_k: torch.Tensor, # [batch_size, max_v_size, num_qo_heads, head_dim] + bar_v: torch.Tensor, # [batch_size, max_v_size, num_qo_heads, head_dim] + granularity: int = 128, +): + kv_comm = RingComm(process_group, zigzag=False) + d_kv_comm = RingComm(process_group, zigzag=False) + + dq, dk, dv = None, None, None + next_k, next_v = None, None + next_dk, next_dv = None, None + dk_comm_buffer, dv_comm_buffer = None, None + block_mask = convert_blockmask_col_reverse(block_mask, causal=True) + + # Bar Mask + full_bar_cnt = torch.stack([bar_cnt[..., 0], bar_cnt[..., -1]], dim=-1) + dq, bar_dk, bar_dv = bar_attn_bwd( + dout, q, bar_k, bar_v, out, None, None, None, + softmax_lse, softmax_scale, + bar_pos, full_bar_cnt, + granularity=granularity, + deterministic=False, + step=0, + ) + + for step in range(kv_comm.world_size): + if step + 1 != kv_comm.world_size: + next_k, next_v = kv_comm.send_recv_kv(k, v) + block_causal = step == 0 + offset = (kv_comm.rank - step) % kv_comm.world_size + + # Block Mask + 
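# Note: the bar (vertical) contribution to dq and to bar_dk / bar_dv was accumulated once before
+ # this ring loop via bar_attn_bwd on the gathered bar_k / bar_v; each step below only adds the
+ # block-sparse contribution for the KV shard currently held (converted=True records that
+ # block_mask was already reindexed by convert_blockmask_col_reverse above), and merge_kv then
+ # presumably scatters this shard's slice of bar_dk / bar_dv back into dk / dv via v_idx / v_cnt. +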
step_dq, step_dk, step_dv = block_attn_bwd( + dout, q, k, v, out, + softmax_lse, softmax_scale, + block_mask[offset], + granularity=granularity, + deterministic=False, + causal=block_causal, + converted=True, + ) + + # Update dQ, dK, dV + if step == 0: + # TODO: check if float32 is necessary + dk = step_dk.to(torch.float32) + dv = step_dv.to(torch.float32) + else: + d_kv_comm.wait() + dk_comm_buffer, dv_comm_buffer = dk, dv + dk, dv = next_dk, next_dv + dk += step_dk + dv += step_dv + dq += step_dq + merge_kv(dk, dv, bar_dk, bar_dv, v_idx, v_cnt, step=offset) + + if step + 1 != kv_comm.world_size: + kv_comm.wait() + k, v = next_k, next_v + next_dk, next_dv = d_kv_comm.send_recv_kv( + dk, dv, dk_comm_buffer, dv_comm_buffer + ) + + d_kv_comm.wait() + return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) + + +class SparseStripeFlashAttnFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + v_size, + s_size, + layer_idx, + softmax_scale, + granularity, + return_softmax, + group, + ): + if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) + batch_size, num_tokens_local, num_qo_heads, head_dim = q.shape + + # Indexing + block_mask, bar_idx, bar_cnt, bar_pos, v_idx, v_cnt = build_index( + q, k, v_size, s_size, num_tokens_local, + granularity=granularity, group=group + ) + + # Shuffle + q = shuffle_striped_input(to_send=q, dim=1, granularity=granularity, process_group=group) + k = shuffle_striped_input(to_send=k, dim=1, granularity=granularity, process_group=group) + v = shuffle_striped_input(to_send=v, dim=1, granularity=granularity, process_group=group) + + # Compute + out, softmax_lse, bar_k, bar_v = sparse_stripe_flash_attn_forward( + group, q, k, v, + layer_idx, softmax_scale, + block_mask, bar_idx, bar_cnt, v_idx, v_cnt, + granularity=granularity, + ) + + # Saving tensors for backward + ctx.save_for_backward(q, k, v, out, softmax_lse, block_mask, bar_pos, bar_cnt, v_idx, v_cnt, bar_k, bar_v) + ctx.softmax_scale = softmax_scale + ctx.granularity = granularity + ctx.group = group + ctx.layer_idx = layer_idx + + # Recover outputs + out = recover_striped_output(out, dim=1, granularity=granularity, process_group=group) + if return_softmax: + softmax_lse = recover_striped_output(softmax_lse, dim=2, granularity=granularity, process_group=group) + + # Output and Return + if return_softmax: + return (out, softmax_lse, None) + return out + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse, block_mask, bar_pos, bar_cnt, v_idx, v_cnt, bar_k, bar_v = ctx.saved_tensors + softmax_scale = ctx.softmax_scale + granularity = ctx.granularity + layer_idx = ctx.layer_idx + group = ctx.group + + # Shuffle + dout = shuffle_striped_input(to_send=dout, dim=1, granularity=granularity, process_group=group) + + # Compute + dq, dk, dv = sparse_stripe_flash_attn_backward( + group, dout, q, k, v, out, softmax_lse, + layer_idx, softmax_scale, + block_mask, bar_pos, bar_cnt, v_idx, v_cnt, bar_k, bar_v, + granularity=granularity, + ) + + # Recover + dq = recover_striped_output(dq, dim=1, granularity=granularity, process_group=group) + dk = recover_striped_output(dk, dim=1, granularity=granularity, process_group=group) + dv = recover_striped_output(dv, dim=1, granularity=granularity, process_group=group) + return dq, dk, dv, None, None, None, None, None, None, None + +def sparse_stripe_flash_attn_qkvpacked_func( + qkv: torch.Tensor, # [batch_size, num_tokens, 3, num_heads, head_dim] + v_size: List[int], # [num_heads] + s_size: List[int], # 
[num_heads] + layer_idx: int = 0, + dropout_p: float = 0.0, + softmax_scale: float = None, + granularity: int = 128, + causal: bool = True, + window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window + alibi_slopes: Tuple[int, int] = None, + deterministic: bool = False, + return_attn_probs: bool = False, + group: dist.ProcessGroup = None, +): + assert causal + assert dropout_p == 0 + assert window_size == (-1, -1) + assert alibi_slopes is None + assert not deterministic + return SparseStripeFlashAttnFunc.apply( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + v_size, + s_size, + layer_idx, + softmax_scale, + granularity, + return_attn_probs, + group, + ) + + +def sparse_stripe_flash_attn_kvpacked_func( + q: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] + kv: torch.Tensor, # [batch_size, num_tokens, 2, num_heads, head_dim] + v_size: List[int], # [num_heads] + s_size: List[int], # [num_heads] + layer_idx: int = 0, + dropout_p: float = 0.0, + softmax_scale: float = None, + granularity: int = 128, + causal: bool = True, + window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window + alibi_slopes: Tuple[int, int] = None, + deterministic: bool = False, + return_attn_probs: bool = False, + group: dist.ProcessGroup = None, +): + assert causal + assert dropout_p == 0 + assert window_size == (-1, -1) + assert alibi_slopes is None + assert not deterministic + return SparseStripeFlashAttnFunc.apply( + q, + kv[:, :, 0], + kv[:, :, 1], + v_size, + s_size, + layer_idx, + softmax_scale, + granularity, + return_attn_probs, + group, + ) + +def sparse_stripe_flash_attn_func( # the one used for nnscaler training + q: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] + v_size: List[int], # [num_heads] + s_size: List[int], # [num_heads] + layer_idx: int = 0, + dropout_p: float = 0.0, + softmax_scale: float = None, + granularity: int = 128, + causal: bool = True, + window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window + alibi_slopes: Tuple[int, int] = None, + deterministic: bool = False, + return_attn_probs: bool = False, + group: dist.ProcessGroup = None, +) -> torch.Tensor: + assert causal + assert dropout_p == 0 + assert window_size == (-1, -1) + assert alibi_slopes is None + assert not deterministic + + return SparseStripeFlashAttnFunc.apply( + q, + k, + v, + v_size, + s_size, + layer_idx, + softmax_scale, + granularity, + return_attn_probs, + group, + ) diff --git a/minference/dist_ops/minfer_striped_triton.py b/minference/dist_ops/minfer_striped_triton.py new file mode 100644 index 0000000..e3b5eb2 --- /dev/null +++ b/minference/dist_ops/minfer_striped_triton.py @@ -0,0 +1,284 @@ +import os +import torch +import torch.distributed as dist +from typing import List, Tuple, Dict + +from .utils import ( + RingComm, + shuffle_striped_input, recover_striped_output, +) +from minference.ops.utils import build_index, convert_blockmask +from minference.ops.minference_attn_triton import block_bar_attn_fwd, block_bar_attn_bwd + + +def sparse_stripe_flash_attn_triton_forward( + process_group: dist.ProcessGroup, + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + layer_idx: int, + softmax_scale: float, + block_idx: torch.Tensor, # [world_size, 
batch_size, num_qo_heads, num_blocks, num_blocks] + block_cnt: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks] + bar_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, max_v_size] + bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] + granularity: int = 128, +): + comm = RingComm(process_group) + out, lse = None, None + next_k, next_v = None, None + + for step in range(comm.world_size): + if step + 1 != comm.world_size: + next_k, next_v = comm.send_recv_kv(k, v) + block_causal = step == 0 + offset = (comm.rank - step) % comm.world_size + + + out, lse = block_bar_attn_fwd( + q, k, v, out, lse, softmax_scale, + bar_idx, bar_cnt, block_idx[offset], block_cnt[offset], + granularity=granularity, + step=offset, + causal=block_causal, + ) + + if step + 1 != comm.world_size: + comm.wait() + k, v = next_k, next_v + + return out, lse + + +def sparse_stripe_flash_attn_triton_backward( + process_group: dist.ProcessGroup, + dout: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + out: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + softmax_lse: torch.Tensor, # [batch_size, num_qo_heads, num_tokens] + layer_idx: int, + softmax_scale: float, + block_idx: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks, num_blocks] + block_cnt: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks] + bar_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, max_v_size] + bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] + granularity: int = 128, +): + kv_comm = RingComm(process_group) + d_kv_comm = RingComm(process_group) + + dq, dk, dv = None, None, None + next_dk, next_dv = None, None + next_k, next_v = None, None + dk_comm_buffer, dv_comm_buffer = None, None + + for step in range(kv_comm.world_size): + if step + 1 != kv_comm.world_size: + next_k, next_v = kv_comm.send_recv_kv(k, v) + block_causal = step == 0 + offset = (kv_comm.rank - step) % kv_comm.world_size + + dq, step_dk, step_dv = block_bar_attn_bwd( + dout, q, k, v, out, dq, None, None, + softmax_lse, softmax_scale, + bar_idx, bar_cnt, block_idx[offset], block_cnt[offset], + granularity=granularity, + deterministic=False, + step=offset, + causal=block_causal, + ) + + # Update dQ, dK, dV + if step == 0: + dk = step_dk + dv = step_dv + else: + d_kv_comm.wait() + + dk_comm_buffer, dv_comm_buffer = dk, dv + dk, dv = next_dk, next_dv + dk += step_dk + dv += step_dv + + + if step + 1 != kv_comm.world_size: + kv_comm.wait() + k, v = next_k, next_v + next_dk, next_dv = d_kv_comm.send_recv_kv( + dk, dv, dk_comm_buffer, dv_comm_buffer + ) + + d_kv_comm.wait() + return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) + + +class SparseStripeFlashAttnTritonFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + v_size, + s_size, + layer_idx, + softmax_scale, + granularity, + return_softmax, + group, + ): + batch_size, num_tokens_local, num_qo_heads, head_dim = q.shape + if softmax_scale is None: + softmax_scale = head_dim ** (-0.5) + + # built block_idx: [world_size, batch_size, num_qo_heads, num_blocks_local, num_blocks_local] + block_mask, bar_idx, bar_cnt, _, _, _ = build_index(q, k, v_size, s_size, num_tokens_local, granularity=granularity, group=group) + block_idx, 
block_cnt = convert_blockmask(block_mask, block_size_M=granularity, block_size_N=64) + + q = shuffle_striped_input(to_send=q, dim=1, granularity=granularity, process_group=group) + k = shuffle_striped_input(to_send=k, dim=1, granularity=granularity, process_group=group) + v = shuffle_striped_input(to_send=v, dim=1, granularity=granularity, process_group=group) + + # slash attn + out, softmax_lse = sparse_stripe_flash_attn_triton_forward( + group, q, k, v, + layer_idx, softmax_scale, + block_idx, block_cnt, bar_idx, bar_cnt, + granularity=granularity, + ) + + # this should be out_padded + ctx.save_for_backward(q, k, v, out, softmax_lse, block_idx, block_cnt, bar_idx, bar_cnt) + ctx.softmax_scale = softmax_scale + ctx.granularity = granularity + ctx.group = group + ctx.layer_idx = layer_idx + + out = recover_striped_output(out, dim=1, granularity=granularity, process_group=group) + if return_softmax: + softmax_lse = recover_striped_output(softmax_lse, dim=2, granularity=granularity, process_group=group) + return (out, softmax_lse, None) + return out + + @staticmethod + def backward(ctx, dout, *args): + layer_idx = ctx.layer_idx + dout = shuffle_striped_input(to_send=dout, dim=1, granularity=ctx.granularity, process_group=ctx.group) + q, k, v, out, softmax_lse, block_idx, block_cnt, bar_idx, bar_cnt = ctx.saved_tensors + + dq, dk, dv = sparse_stripe_flash_attn_triton_backward( + ctx.group, dout, q, k, v, out, softmax_lse, + layer_idx, ctx.softmax_scale, + block_idx, block_cnt, bar_idx, bar_cnt, + granularity=ctx.granularity, + ) + + dq = recover_striped_output(dq, dim=1, granularity=ctx.granularity, process_group=ctx.group) + dk = recover_striped_output(dk, dim=1, granularity=ctx.granularity, process_group=ctx.group) + dv = recover_striped_output(dv, dim=1, granularity=ctx.granularity, process_group=ctx.group) + + return dq, dk, dv, None, None, None, None, None, None, None + + +def sparse_stripe_flash_attn_triton_qkvpacked_func( + qkv: torch.Tensor, # [batch_size, num_tokens, 3, num_heads, head_dim] + v_size: List[int], # [num_heads] + s_size: List[int], # [num_heads] + dropout_p: float = 0.0, + softmax_scale: float = None, + granularity: int = 128, + causal: bool = False, + window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window + alibi_slopes: Tuple[int, int] = None, + deterministic: bool = False, + return_attn_probs: bool = False, + group: dist.ProcessGroup = None, +): + assert causal + assert dropout_p == 0 + assert window_size == (-1, -1) + assert alibi_slopes is None + assert not deterministic + return SparseStripeFlashAttnTritonFunc.apply( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + v_size, + s_size, + softmax_scale, + granularity, + return_attn_probs, + group, + ) + + +def sparse_stripe_flash_attn_triton_kvpacked_func( + q: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] + kv: torch.Tensor, # [batch_size, num_tokens, 2, num_heads, head_dim] + v_size: List[int], # [num_heads] + s_size: List[int], # [num_heads] + dropout_p: float = 0.0, + softmax_scale: float = None, + granularity: int = 128, + causal: bool = False, + window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window + alibi_slopes: Tuple[int, int] = None, + deterministic: bool = False, + return_attn_probs: bool = False, + group: dist.ProcessGroup = None, +): + assert causal + assert dropout_p == 0 + assert window_size == (-1, -1) + assert alibi_slopes is None + assert not deterministic + return SparseStripeFlashAttnTritonFunc.apply( + q, + kv[:, :, 0], + kv[:, :, 1], + 
v_size, + s_size, + softmax_scale, + granularity, + return_attn_probs, + group, + ) + + +def sparse_stripe_flash_attn_triton_func( + q: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] + v_size: List[int], # [num_heads] + s_size: List[int], # [num_heads] + layer_idx: int, + dropout_p: float = 0.0, + softmax_scale: float = None, + granularity: int = 128, + causal: bool = False, + window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window + alibi_slopes: Tuple[int, int] = None, + deterministic: bool = False, + return_attn_probs: bool = False, + group: dist.ProcessGroup = None, +) -> torch.Tensor: + assert causal + assert dropout_p == 0 + assert window_size == (-1, -1) + assert alibi_slopes is None + assert not deterministic + + return SparseStripeFlashAttnTritonFunc.apply( + q, + k, + v, + v_size, + s_size, + layer_idx, + softmax_scale, + granularity, + return_attn_probs, + group, + ) diff --git a/minference/dist_ops/minfer_zigzag.py b/minference/dist_ops/minfer_zigzag.py new file mode 100644 index 0000000..88292aa --- /dev/null +++ b/minference/dist_ops/minfer_zigzag.py @@ -0,0 +1,331 @@ +import os +import torch +import triton +import torch.distributed as dist +from typing import List, Tuple, Dict + +from .utils import ( + RingComm, + shuffle_zigzag_input, recover_zigzag_output, +) +from minference.ops.utils import build_index, convert_blockmask +from minference.ops.minference_attn_triton import block_bar_attn_fwd +from minference.ops.minference_attn import block_attn_bwd, bar_attn_bwd + +def minfer_zigzag_forward( + process_group: dist.ProcessGroup, + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + layer_idx: int, + softmax_scale: float, + block_mask: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks, num_blocks] + bar_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, max_v_size] + bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] + granularity: int = 128, +): + comm = RingComm(process_group, zigzag=True) + ring_list = comm.ring_list + ring_index = ring_list.index(comm.rank) + + out, lse = None, None + block_idx, block_cnt = convert_blockmask(block_mask, block_size_M=granularity, block_size_N=64) + + for step in range(comm.world_size): + if step + 1 != comm.world_size: + next_k, next_v = comm.send_recv_kv(k, v) + block_causal = step == 0 + offset = (ring_index - step) % comm.world_size + + # ---------------------------------------------- + out, lse = block_bar_attn_fwd( + q, k, v, out, lse, softmax_scale, + bar_idx, bar_cnt, block_idx[offset], block_cnt[offset], + granularity=granularity, + step=offset, + causal=block_causal, + ) + + if step + 1 != comm.world_size: + comm.wait() + k, v = next_k, next_v + + out = out.to(q.dtype) + return out, lse + + +def minfer_zigzag_backward( + process_group: dist.ProcessGroup, + dout: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + out: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + softmax_lse: torch.Tensor, # [batch_size, 
num_qo_heads, num_tokens] + layer_idx: int, + softmax_scale: float, + block_mask: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks, num_blocks] + bar_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, max_v_size] + bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] + granularity: int = 128, +): + kv_comm = RingComm(process_group, zigzag=True) + d_kv_comm = RingComm(process_group, zigzag=True) + ring_list = kv_comm.ring_list + ring_index = ring_list.index(kv_comm.rank) + + dq, dk, dv = None, None, None + next_dk, next_dv = None, None + next_k, next_v = None, None + dk_comm_buffer, dv_comm_buffer = None, None + + for step in range(kv_comm.world_size): + if step + 1 != kv_comm.world_size: + next_k, next_v = kv_comm.send_recv_kv(k, v) + block_causal = step == 0 + offset = (ring_index - step) % kv_comm.world_size + + # ---------------------------------------------- + # Block Mask + step_dq, step_dk, step_dv = block_attn_bwd( + dout, q, k, v, out, + softmax_lse, softmax_scale, + block_mask[offset], + granularity=granularity, + deterministic=False, + causal=block_causal, + ) + + # ---------------------------------------------- + # Bar Mask + step_dq, step_dk, step_dv = bar_attn_bwd( + dout, q, k, v, out, step_dq, step_dk, step_dv, + softmax_lse, softmax_scale, + bar_idx, bar_cnt, + granularity=granularity, + deterministic=False, + step=offset, + ) + + # ---------------------------------------------- + # Update dQ, dK, dV + if step == 0: + # TODO: check if float32 is necessary + dq = step_dq.to(torch.float32) + dk = step_dk.to(torch.float32) + dv = step_dv.to(torch.float32) + else: + d_kv_comm.wait() + dk_comm_buffer, dv_comm_buffer = dk, dv + dk, dv = next_dk, next_dv + + dq += step_dq + dk += step_dk + dv += step_dv + + if step + 1 != kv_comm.world_size: + kv_comm.wait() + k, v = next_k, next_v + next_dk, next_dv = d_kv_comm.send_recv_kv( + dk, dv, dk_comm_buffer, dv_comm_buffer + ) + + d_kv_comm.wait() + return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) + + +class MInferZigzagAttnFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + v_size, + s_size, + layer_idx, + softmax_scale, + granularity, + return_softmax, + group, + ): + if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) + batch_size, num_tokens_local, num_qo_heads, head_dim = q.shape + + # ------------------------------------------------------------------ + # Index Build + block_mask, bar_idx, bar_cnt, bar_pos, v_idx, v_cnt = build_index( + q, k, v_size, s_size, num_tokens_local, + stripe_transform=False, + zigzag_transform=True, + granularity=granularity, group=group + ) + + # ---------------------------------------------- + # Shuffle + q = shuffle_zigzag_input(to_send=q, dim=1, process_group=group) + k = shuffle_zigzag_input(to_send=k, dim=1, process_group=group) + v = shuffle_zigzag_input(to_send=v, dim=1, process_group=group) + + # ---------------------------------------------- + # Compute + out, softmax_lse = minfer_zigzag_forward( + group, q, k, v, + layer_idx, softmax_scale, + block_mask, bar_idx, bar_cnt, + granularity=granularity, + ) + + # ---------------------------------------------- + # Recover outputs + recovered_out = recover_zigzag_output(out, dim=1, process_group=group) + if return_softmax: + recovered_softmax_lse = recover_zigzag_output(softmax_lse, dim=2, process_group=group) + + # ---------------------------------------------- + # Saving tensors for backward + ctx.save_for_backward(q, k, v, out, 
softmax_lse, block_mask, bar_idx, bar_cnt) + ctx.softmax_scale = softmax_scale + ctx.granularity = granularity + ctx.group = group + ctx.layer_idx = layer_idx + + # Output and Return + if return_softmax: + return (recovered_out, recovered_softmax_lse, None) + return recovered_out + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse, block_mask, bar_idx, bar_cnt = ctx.saved_tensors + softmax_scale = ctx.softmax_scale + granularity = ctx.granularity + layer_idx = ctx.layer_idx + group = ctx.group + + # ---------------------------------------------- + # Shuffle + dout = shuffle_zigzag_input(to_send=dout, dim=1, process_group=group) + + # ---------------------------------------------- + # Compute + dq, dk, dv = minfer_zigzag_backward( + group, dout, q, k, v, out, softmax_lse, + layer_idx, softmax_scale, + block_mask, bar_idx, bar_cnt, + granularity=granularity, + ) + + # ---------------------------------------------- + # Recover + dq = recover_zigzag_output(dq, dim=1, process_group=group) + dk = recover_zigzag_output(dk, dim=1, process_group=group) + dv = recover_zigzag_output(dv, dim=1, process_group=group) + + return dq, dk, dv, None, None, None, None, None, None, None + + +def minfer_zigzag_qkvpacked_func( + qkv: torch.Tensor, # [batch_size, num_tokens, 3, num_heads, head_dim] + v_size: List[int], # [num_heads] + s_size: List[int], # [num_heads] + layer_idx: int = 0, + dropout_p: float = 0.0, + softmax_scale: float = None, + granularity: int = 128, + causal: bool = True, + window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window + alibi_slopes: Tuple[int, int] = None, + deterministic: bool = False, + return_attn_probs: bool = False, + group: dist.ProcessGroup = None, +): + assert causal + assert dropout_p == 0 + assert window_size == (-1, -1) + assert alibi_slopes is None + assert not deterministic + return MInferZigzagAttnFunc.apply( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + v_size, + s_size, + layer_idx, + softmax_scale, + granularity, + return_attn_probs, + group, + ) + + +def minfer_zigzag_kvpacked_func( + q: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] + kv: torch.Tensor, # [batch_size, num_tokens, 2, num_heads, head_dim] + v_size: List[int], # [num_heads] + s_size: List[int], # [num_heads] + layer_idx: int = 0, + dropout_p: float = 0.0, + softmax_scale: float = None, + granularity: int = 128, + causal: bool = True, + window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window + alibi_slopes: Tuple[int, int] = None, + deterministic: bool = False, + return_attn_probs: bool = False, + group: dist.ProcessGroup = None, +): + assert causal + assert dropout_p == 0 + assert window_size == (-1, -1) + assert alibi_slopes is None + assert not deterministic + return MInferZigzagAttnFunc.apply( + q, + kv[:, :, 0], + kv[:, :, 1], + v_size, + s_size, + layer_idx, + softmax_scale, + granularity, + return_attn_probs, + group, + ) + + +def minfer_zigzag_func( # the one used for nnscaler training + q: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] + v_size: List[int], # [num_heads] + s_size: List[int], # [num_heads] + layer_idx: int = 0, + dropout_p: float = 0.0, + softmax_scale: float = None, + granularity: int = 128, + causal: bool = True, + window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window + alibi_slopes: Tuple[int, int] = None, + 
deterministic: bool = False, + return_attn_probs: bool = False, + group: dist.ProcessGroup = None, +) -> torch.Tensor: + assert causal + assert dropout_p == 0 + assert window_size == (-1, -1) + assert alibi_slopes is None + assert not deterministic + + return MInferZigzagAttnFunc.apply( + q, + k, + v, + v_size, + s_size, + layer_idx, + softmax_scale, + granularity, + return_attn_probs, + group, + ) diff --git a/minference/dist_ops/moba_zigzag.py b/minference/dist_ops/moba_zigzag.py new file mode 100644 index 0000000..e8834b0 --- /dev/null +++ b/minference/dist_ops/moba_zigzag.py @@ -0,0 +1,1066 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# Credits: This logger implementation is inspired by project https://github.com/zhuzilin/ring-flash-attention +import os +import torch +import torch.distributed as dist + +from einops import rearrange +from typing import List, Tuple, Dict +from time import perf_counter + +from flash_attn.flash_attn_interface import ( + _flash_attn_varlen_forward, + _flash_attn_varlen_backward, +) + + +from .utils import ( + RingComm, update_out_and_lse, + recover_zigzag_output, get_default_args, +) +from .op_utils.moba_utils import ( + shuffle_input_all, shuffle_input_only, compute_moba_gate +) + + +def moba_zigzag_attn_fwd_step( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, # [S, H, D] + step: int, + causal: bool, + # q_seq_offsets: torch.Tensor, + num_q_blocks: int, + k_seq_offsets: torch.Tensor, + + gate_mask: torch.Tensor, # [num_filtered_chunk, num_head, seq_len] + cu_chunk: torch.Tensor, + filtered_chunk_indices: torch.Tensor, + num_filtered_chunk: int, + chunk_to_batch: torch.Tensor, + moba_chunk_size: int, + moba_topk: int, + + softmax_scale, + dropout_p=0, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +) -> Tuple[torch.Tensor, torch.Tensor, int]: + _, _, seq_len = gate_mask.shape + q_block_seq_len, num_head, head_dim = q.shape + k_block_seq_len, k_num_head, _ = k.shape + if num_head > k_num_head: + k = torch.repeat_interleave(k, num_head // k_num_head, dim=1) + v = torch.repeat_interleave(v, num_head // k_num_head, dim=1) + + block_seq_len = q_block_seq_len // num_q_blocks + + # assumption: block_seq_len is divisible by moba_chunk_size + assert (block_seq_len % moba_chunk_size == 0), "block_seq_len should be divisible by moba_chunk_size" + + kv = torch.stack((k, v), dim=1) + k_seq_offset_list = [k_seq_offsets[i].detach().cpu().item() for i in range(len(k_seq_offsets))] + filtered_kv_indices = torch.arange( + 0, min(k_seq_offset_list[0] + block_seq_len, num_filtered_chunk * moba_chunk_size) - k_seq_offset_list[0], + device=k.device, dtype=torch.int32 + ) + kv_chunk_indices = torch.arange( + k_seq_offset_list[0], min(k_seq_offset_list[0] + block_seq_len, num_filtered_chunk * moba_chunk_size), + moba_chunk_size, device=k.device, dtype=torch.int32 + ) + if len(k_seq_offset_list) > 1: + filtered_kv_indices = torch.cat([ + filtered_kv_indices, + torch.arange( + block_seq_len, + min(k_seq_offset_list[1] + block_seq_len, num_filtered_chunk * moba_chunk_size) - k_seq_offset_list[1] + block_seq_len, + device=k.device, dtype=torch.int32 + ) + ]) + kv_chunk_indices = torch.cat([ + kv_chunk_indices, + torch.arange( + k_seq_offset_list[1], + min(k_seq_offset_list[1] + block_seq_len, num_filtered_chunk * moba_chunk_size), + moba_chunk_size, + device=k.device, dtype=torch.int32 + ) + ]) + filtered_kv = kv.index_select(0, filtered_kv_indices) + kv_chunk_indices = kv_chunk_indices // moba_chunk_size + 
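# Zigzag sharding leaves this rank with up to two non-contiguous blocks of the global sequence,
+ # so k_seq_offsets carries one global start offset per local block. The selects above gather,
+ # per block, the local KV rows that fall inside the gate-filtered region
+ # (num_filtered_chunk * moba_chunk_size tokens) together with the global MoBA chunks they cover;
+ # dividing the chunk start offsets by moba_chunk_size yields the chunk indices used to slice gate_mask below. +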
num_filtered_kv_chunks = len(kv_chunk_indices) + + q_indices = torch.arange( + 0 if num_q_blocks == 2 else block_seq_len, 2 * block_seq_len, + device=q.device, dtype=torch.int32 + ) + + # varlen trick: combining all q index that needs moba attn + # the result will be like [ C0H0 ][ C0H1 ][ C0H2 ][ ... ][ CnHm ] + gate_mask_q = gate_mask.index_select(0, kv_chunk_indices) + gate_mask_q = gate_mask_q.index_select(2, q_indices) # we need to know which part(s) of the two query blocks should be activated + + moba_q_indices = gate_mask_q.reshape(gate_mask_q.shape[0], -1).nonzero(as_tuple=True)[-1] + moba_seqlen_q = gate_mask_q.sum(dim=-1).flatten() + + # ----------------------------------------------------------- + # select all q that needs moba attn based on the moba_q_indices + moba_q = rearrange(q, "s h d -> ( h s ) d") + moba_q = moba_q.index_select(0, moba_q_indices) # [ selected_HS, D ] + moba_q = moba_q.unsqueeze(1) + + # moba_q_sh_indices represents the position in the origin q tensor of each q token inside moba_q + moba_q_sh_indices = moba_q_indices % q_block_seq_len * num_head + moba_q_indices // q_block_seq_len + + """ prepare moba kv """ + # Since moba_q is organized as HS * N, we need to reorganize kv to adapt to q + + # cut off zero experts + q_zero_mask = moba_seqlen_q == 0 + valid_expert_mask = ~q_zero_mask + zero_expert_count = q_zero_mask.sum() + # only keep the kv that has q select > 0 + if zero_expert_count > 0: + moba_seqlen_q = moba_seqlen_q[valid_expert_mask] + + # moba cu_seqlen for flash attn + moba_cu_seqlen_q = torch.cat( + ( + torch.tensor([0], device=q.device, dtype=moba_seqlen_q.dtype), + moba_seqlen_q.cumsum(dim=0), + ), + dim=0, + ).to(torch.int32) + + # ----------------------------------------------------------------------------------- + # here `x` only stands for a dimension (stack dimension for KV) + moba_kv = rearrange(filtered_kv, "s x h d -> h s x d") # [H, K_S, 2, D ] + + moba_kv = moba_kv.split(moba_chunk_size, dim=1) # tuple of (num_selected_chunks) elements with shape [H, chunk_size, 2, D] + moba_kv = torch.cat(moba_kv, dim=0) # [H x num_selected_chunks, chunk_size, 2, D ] after split + + # The transformation is aimed for masking out by valid_expert_mask where the mask selects elements along (H x num_selected_chunks) dimension + if zero_expert_count > 0: + assert valid_expert_mask.sum() == moba_kv.shape[0] - zero_expert_count + moba_kv = moba_kv[ + valid_expert_mask + ] # cut off zero Q expert from kv , or the grad may be nan + + moba_kv = moba_kv.flatten(start_dim=0, end_dim=1).unsqueeze(2) # [H x num_selected_chunks x chunk_size, 2, 1, D] + moba_cu_seqlen_kv = ( + torch.arange( + 0, num_filtered_kv_chunks * num_head + 1 - zero_expert_count, + dtype=torch.int32, device=q.device, + ) * moba_chunk_size + ) + + # Shape check + assert ( + moba_cu_seqlen_kv.shape == moba_cu_seqlen_q.shape + ), f"moba_cu_seqlen_kv.shape != moba_cu_seqlen_q.shape {moba_cu_seqlen_kv.shape} != {moba_cu_seqlen_q.shape}" + + softmax_scale = softmax_scale = head_dim ** (-0.5) + + self_attn_cu_seqlen = [0] + [moba_chunk_size] * (q_block_seq_len // moba_chunk_size) + if q_block_seq_len % moba_chunk_size != 0: + self_attn_cu_seqlen.append(q_block_seq_len % moba_chunk_size) + self_attn_cu_seqlen = torch.tensor(self_attn_cu_seqlen, device=q.device, dtype=torch.int32) + self_attn_cu_seqlen = self_attn_cu_seqlen.cumsum(dim=0, dtype=torch.int32) + + # ----------------------------------------------------------------------------------- + # self attn + if causal: + # out, softmax_lse, S_dmask, 
rng_state + self_attn_out_sh, self_attn_lse_hs, _, _ = ( + _flash_attn_varlen_forward( + q=q, k=k, v=v, + cu_seqlens_q=self_attn_cu_seqlen, + cu_seqlens_k=self_attn_cu_seqlen, + max_seqlen_q=q_block_seq_len, + max_seqlen_k=k_block_seq_len, + softmax_scale=softmax_scale, + causal=True, + dropout_p=0.0, + ) + ) + else: + # self_attn_out_sh, self_attn_lse_hs = None, None + self_attn_out_sh = torch.zeros( + (q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32 + ) + self_attn_lse_hs = torch.zeros((num_head, q_block_seq_len), device=q.device, dtype=torch.float32) + (-float('inf')) + + + # moba attn + # moba_attn_lse_hs - [1, num_nonzero_elems] + if moba_q.shape[0] > 0: + # out, softmax_lse, S_dmask, rng_state + moba_attn_out, moba_attn_lse_hs, _, _ = _flash_attn_varlen_forward( + q=moba_q, + k=moba_kv[:, 0], + v=moba_kv[:, 1], + cu_seqlens_q=moba_cu_seqlen_q, + cu_seqlens_k=moba_cu_seqlen_kv, + max_seqlen_q=q_block_seq_len, + max_seqlen_k=moba_chunk_size, + softmax_scale=softmax_scale, + causal=False, + dropout_p=0.0, + ) + else: + moba_attn_lse_hs = torch.zeros((1, moba_q.shape[0]), device=q.device, dtype=torch.float32) + (-float('inf')) + + # ----------------------------------------------------------------------------------- + # If no queries need to be computed with the current KV chunk and no causal attention is needed, return None to skip the output update + if not causal and moba_q.shape[0] == 0: + return None, None, 0, torch.zeros((num_head,), device=q.device, dtype=torch.float32) + + # ----------------------------------------------------------------------------------- + # Processing output and lse + # output buffer [S, H, D], same shape as q + output = torch.zeros( + (q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32 + ) + # flatten vS & H for index ops + output_2d = output.view(-1, q.shape[2]) + + # -------------------------------------------------- + moba_attn_lse: torch.Tensor = moba_attn_lse_hs.t().contiguous() # [ num_nonzero_elems, 1 ] + self_attn_lse_sh = self_attn_lse_hs.t().contiguous() # [q_S, H] + + # calc mixed_lse + # minus max lse to avoid exp explosion + max_lse_1d = self_attn_lse_sh.view(-1) # [ vS ] + max_lse_1d = max_lse_1d.index_reduce( + 0, moba_q_sh_indices, moba_attn_lse.view(-1), "amax" + ) + self_attn_lse_sh = self_attn_lse_sh - max_lse_1d.view_as(self_attn_lse_sh) + moba_attn_lse = ( + moba_attn_lse.view(-1) + .sub(max_lse_1d.index_select(0, moba_q_sh_indices)) + .reshape_as(moba_attn_lse) + ) + + # -------------------------------------------------- + # Build mixed attn lse + mixed_attn_se_sh = self_attn_lse_sh.exp() if causal else torch.zeros_like(self_attn_lse_sh) + moba_attn_se = moba_attn_lse.exp() if moba_q.shape[0] > 0 else torch.zeros_like(moba_attn_lse) + + # index_add_: converting elements from 1D tensor (num_nonzero_elems) to matrices (HS) + # Now, mixed_attn_se_sh is the sum of LSE of self attn and LSE of moba attn (including multiple LSEs corresponding to the same q token but in different HS positions) + mixed_attn_se_sh.view(-1).index_add_( + 0, moba_q_sh_indices, moba_attn_se.view(-1) + ) + mixed_attn_lse_sh = mixed_attn_se_sh.log() + + # ---------------------------------------------------- + # Compute factor of self-attention and add to output + if causal: + factor = (self_attn_lse_sh - mixed_attn_lse_sh).exp() # [ vS, H ] + self_attn_out_sh = self_attn_out_sh * factor.unsqueeze(-1) + output_2d += self_attn_out_sh.reshape_as(output_2d) + + # add moba output + # 
---------------------------------------------------- + # Compute factor of moba-attention and add to output + if moba_q.shape[0] > 0: + mixed_attn_lse = ( + mixed_attn_lse_sh.view(-1) + .index_select(0, moba_q_sh_indices) + .view_as(moba_attn_lse) + ) + factor = (moba_attn_lse - mixed_attn_lse).exp() # [ vS, H ] + moba_attn_out = moba_attn_out * factor.unsqueeze(-1) + + raw_attn_out = moba_attn_out.view(-1, moba_attn_out.shape[-1]) + output_2d.index_add_(0, moba_q_sh_indices, raw_attn_out.to(output_2d.dtype)) + + output = output.to(q.dtype) + + # add back max lse + mixed_attn_lse_sh = mixed_attn_lse_sh + max_lse_1d.view_as(mixed_attn_se_sh) + return output, mixed_attn_lse_sh.t() + +def moba_zigzag_attn_fwd( + process_group, + q: torch.Tensor, # [S, H, D] + k: torch.Tensor, + v: torch.Tensor, + seq_offsets: torch.Tensor, # sequence offsets for Q + layer_idx: int, + + gate_mask, + cu_chunk, + filtered_chunk_indices, + num_filtered_chunk, + chunk_to_batch, + moba_chunk_size, + moba_topk, + + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + assert causal == True, "zigzag ring is meaningless for causal=False" + comm = RingComm(process_group) + + block_seq_len = q.shape[0] // 2 + seq_len, num_q_heads, head_dim = q.shape + + out, lse = None, None + next_k, next_v = None, None + + kv_seq_offsets = torch.clone(seq_offsets) + next_kv_seq_offsets = None + + def fwd_step( + q_, k_, v_, step_, causal_, + # q_seq_offsets, + num_q_blocks, + k_seq_offsets + ): + return moba_zigzag_attn_fwd_step( + q_, k_, v_, + step_, + causal_, + num_q_blocks, + k_seq_offsets, + + gate_mask, + cu_chunk, + filtered_chunk_indices, + num_filtered_chunk, + chunk_to_batch, + moba_chunk_size, + moba_topk, + + softmax_scale, + dropout_p, + window_size, + alibi_slopes, + deterministic, + ) + + for step in range(comm.world_size): + if step + 1 != comm.world_size: + # when step < N-1, do the ring-communication to get KV to be used in the next round + next_k, next_v, next_kv_seq_offsets = comm.send_recv_kv_offsets(k, v, kv_seq_offsets) + + if step == 0: + # Do softmax(QK^T / sqrt(d_k))V on the currently hold K and V + # and record the output and the LSE + block_out, block_lse = fwd_step( + q, k, v, step, causal_=True, + num_q_blocks=2, + k_seq_offsets=kv_seq_offsets, + ) + + out, lse = update_out_and_lse(out, lse, block_out, block_lse) + elif step <= comm.revert_rank: + k0 = k[:block_seq_len] + v0 = v[:block_seq_len] + block_out, block_lse = fwd_step( + q, k0, v0, step, causal_=False, + num_q_blocks=2, + k_seq_offsets=kv_seq_offsets[0:1], + ) + + if block_out is not None: + out, lse = update_out_and_lse(out, lse, block_out, block_lse) + else: + q1 = q[block_seq_len:] + block_out, block_lse = fwd_step( + q1, k, v, step, causal_=False, + num_q_blocks=1, + k_seq_offsets=kv_seq_offsets, + ) + + if block_out is not None: + out, lse = update_out_and_lse( + out, lse, + block_out, + block_lse, + slice_=(slice(block_seq_len, None)), + ) + + if step + 1 != comm.world_size: + comm.wait() + k, v, kv_seq_offsets = next_k, next_v, next_kv_seq_offsets + + out = out.to(q.dtype) # [S, H, D] + lse = lse.squeeze(dim=-1).transpose(0, 1) # [H, S] + return out, lse + +def moba_zigzag_attn_bwd_step( + step: int, + + dout, # [blk_S, H, D] + out, # [blk_S, H, D] + causal: bool, + + q: torch.Tensor, # [blk_S, H, D] + k: torch.Tensor, # [blk_S, H, D] + v: torch.Tensor, # [blk_S, H, D] + # dq: torch.Tensor, + # dk: torch.Tensor, + # dv: torch.Tensor, # [blk_S, H, D] + + softmax_lse: 
torch.Tensor, # [H, blk_S] + num_q_blocks: int, + k_seq_offsets: torch.Tensor, + layer_idx: int, + + gate_mask, + cu_chunk, + filtered_chunk_indices, + num_filtered_chunk, + chunk_to_batch: torch.Tensor, + moba_chunk_size: int, + moba_topk: int, + + softmax_scale, + dropout_p=0, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +) -> Tuple[torch.Tensor, torch.Tensor, int]: + _, _, seq_len = gate_mask.shape + q_block_seq_len, num_head, head_dim = q.shape + k_block_seq_len, k_num_head, _ = k.shape + if num_head > k_num_head: + k = torch.repeat_interleave(k, num_head // k_num_head, dim=1) + v = torch.repeat_interleave(v, num_head // k_num_head, dim=1) + + block_seq_len = q_block_seq_len // num_q_blocks + + # assumption: block_seq_len is divisible by moba_chunk_size + assert (block_seq_len % moba_chunk_size == 0), "block_seq_len should be divisible by moba_chunk_size" + + # ----------------------------------------------------------------------------------- + dq = torch.zeros_like(q, dtype=q.dtype) + dk = torch.zeros_like(k, dtype=k.dtype) + dv = torch.zeros_like(v, dtype=v.dtype) + + kv = torch.stack((k, v), dim=1) + dkv = torch.stack((dk, dv), dim=1) + # ----------------------------------------------------------------------------------- + + + k_seq_offset_list = [k_seq_offsets[i].detach().cpu().item() for i in range(len(k_seq_offsets))] + filtered_kv_indices = torch.arange( + 0, min(k_seq_offset_list[0] + block_seq_len, num_filtered_chunk * moba_chunk_size) - k_seq_offset_list[0], + device=k.device, dtype=torch.int32 + ) + kv_chunk_indices = torch.arange( + k_seq_offset_list[0], min(k_seq_offset_list[0] + block_seq_len, num_filtered_chunk * moba_chunk_size), + moba_chunk_size, device=k.device, dtype=torch.int32 + ) + if len(k_seq_offset_list) > 1: + filtered_kv_indices = torch.cat([ + filtered_kv_indices, + torch.arange( + block_seq_len, + min(k_seq_offset_list[1] + block_seq_len, num_filtered_chunk * moba_chunk_size) - k_seq_offset_list[1] + block_seq_len, + device=k.device, dtype=torch.int32 + ) + ]) + kv_chunk_indices = torch.cat([ + kv_chunk_indices, + torch.arange( + k_seq_offset_list[1], + min(k_seq_offset_list[1] + block_seq_len, num_filtered_chunk * moba_chunk_size), + moba_chunk_size, + device=k.device, dtype=torch.int32 + ) + ]) + filtered_kv = kv.index_select(0, filtered_kv_indices) + filtered_dkv = dkv.index_select(0, filtered_kv_indices) + + kv_chunk_indices = kv_chunk_indices // moba_chunk_size + num_filtered_kv_chunks = len(kv_chunk_indices) + + q_indices = torch.arange( + 0 if num_q_blocks == 2 else block_seq_len, 2 * block_seq_len, + device=q.device, dtype=torch.int32 + ) + + # varlen trick: combining all q index that needs moba attn + # the result will be like [ C0H0 ][ C0H1 ][ C0H2 ][ ... ][ CnHm ] + gate_mask_q = gate_mask.index_select(0, kv_chunk_indices) + gate_mask_q = gate_mask_q.index_select(2, q_indices) # we need to know which part(s) of the two query blocks should be activated + + # equivalent to einops.rearrange(q, "n h s -> n (h s)"). ([s] [s] ... [s] for h times) + # [HS indices] * N (total size: all non-zero elements in HS dimension, potentially repeat) + # [num_selected_chunks, HS indices of non-zero elements] + # gate_mask has been filtered by q_indices. 
If we still need use gate_mask_q for indexing, it should be offset by block_seq_len if num_q_blocks == 1 + # + (0 if num_q_blocks == 2 else block_seq_len) + moba_q_indices = gate_mask_q.reshape(gate_mask_q.shape[0], -1).nonzero(as_tuple=True)[-1] + + # moba_seqlen_q indicates that how many q chunks are selected for each kv chunk - head + # moba_seqlen_q has shape (num_selecte_chunks * num_heads, ) => varlen_forward computes attention by (num_selecte_chunks * num_heads) times + moba_seqlen_q = gate_mask_q.sum(dim=-1).flatten() + + # ----------------------------------------------------------- + # select all q that needs moba attn based on the moba_q_indices + moba_q = rearrange(q, "s h d -> ( h s ) d") + moba_dq = rearrange(dq, "s h d -> ( h s ) d") + + moba_q = moba_q.index_select(0, moba_q_indices) # [ selected_HS, D ] + moba_dq = moba_dq.index_select(0, moba_q_indices) # [ selected_HS, D ] + + # [ selected_S, 1, D ] (pseudo head dim for flash attn) + moba_q = moba_q.unsqueeze(1) + moba_dq = moba_dq.unsqueeze(1) + + # moba_q_sh_indices represents the position in the origin q tensor of each q token inside moba_q + # note that original q has shape (S, H, D) while moba_q_indices is based on (H S) + # Ignoring D, q has the flattend form like [H] [H] ... [H] for S times + # => moba_q_sh_indices is the index of each token in the original q tensor + moba_q_sh_indices = moba_q_indices % q_block_seq_len * num_head + moba_q_indices // q_block_seq_len + + + """ prepare moba kv """ + # Since moba_q is organized as HS * N, we need to reorganize kv to adapt to q + # cut off zero experts + q_zero_mask = moba_seqlen_q == 0 + valid_expert_mask = ~q_zero_mask + zero_expert_count = q_zero_mask.sum() + # only keep the kv that has q select > 0 + if zero_expert_count > 0: + moba_seqlen_q = moba_seqlen_q[valid_expert_mask] + + # moba cu_seqlen for flash attn + moba_cu_seqlen_q = torch.cat( + ( + torch.tensor([0], device=q.device, dtype=moba_seqlen_q.dtype), + moba_seqlen_q.cumsum(dim=0), + ), + dim=0, + ).to(torch.int32) + + + # ------------------------------ + # Select dout and output + d_moba_out = ( + # [num_non-zero_elements, D] + dout.view(-1, head_dim).index_select(0, moba_q_sh_indices).unsqueeze(1) + ) + moba_out = ( + # [num_non-zero_elements, D] + out.view(-1, head_dim).index_select(0, moba_q_sh_indices).unsqueeze(1) + ) + + + + # ----------------------------------------------------------------------------------- + # here `x` only stands for a dimension (stack dimension for KV) + moba_kv = rearrange(filtered_kv, "s x h d -> h s x d") # [H, K_S, 2, D ] + moba_dkv = rearrange(filtered_dkv, "s x h d -> h s x d") # [H, K_S, 2, D ] + + moba_kv = moba_kv.split(moba_chunk_size, dim=1) # tuple of (num_selected_chunks) elements with shape [H, chunk_size, 2, D] + moba_kv = torch.cat(moba_kv, dim=0) # [H x num_selected_chunks, chunk_size, 2, D ] after split + moba_dkv = torch.cat(moba_dkv.split(moba_chunk_size, dim=1), dim=0) # [H x num_selected_chunks, chunk_size, 2, D ] after split + + # The transformation is aimed for masking out by valid_expert_mask where the mask selects elements along (H x num_selected_chunks) dimension + if zero_expert_count > 0: + assert valid_expert_mask.sum() == moba_kv.shape[0] - zero_expert_count + + # cut off zero Q expert from kv , or the grad may be nan + moba_kv = moba_kv[valid_expert_mask] + moba_dkv = moba_dkv[valid_expert_mask] + + moba_kv = moba_kv.flatten(start_dim=0, end_dim=1).unsqueeze(2) # [H x num_selected_chunks x chunk_size, 2, 1, D] + moba_dkv = 
moba_dkv.flatten(start_dim=0, end_dim=1).unsqueeze(2) # [H x num_selected_chunks x chunk_size, 2, 1, D] + + moba_cu_seqlen_kv = ( + torch.arange( + 0, num_filtered_kv_chunks * num_head + 1 - zero_expert_count, + dtype=torch.int32, device=q.device, + ) * moba_chunk_size + ) + + # Shape check + assert ( + moba_cu_seqlen_kv.shape == moba_cu_seqlen_q.shape + ), f"moba_cu_seqlen_kv.shape != moba_cu_seqlen_q.shape {moba_cu_seqlen_kv.shape} != {moba_cu_seqlen_q.shape}" + + + self_attn_cu_seqlen = [0] + [moba_chunk_size] * (q_block_seq_len // moba_chunk_size) + if q_block_seq_len % moba_chunk_size != 0: + self_attn_cu_seqlen.append(q_block_seq_len % moba_chunk_size) + self_attn_cu_seqlen = torch.tensor(self_attn_cu_seqlen, device=q.device, dtype=torch.int32) + self_attn_cu_seqlen = self_attn_cu_seqlen.cumsum(dim=0, dtype=torch.int32) + + # ----------------------------------------------------------------------------------- + # self attn + if causal: + dq_, dk_, dv_ = torch.empty_like(dq), torch.empty_like(dkv[:, 0]), torch.empty_like(dkv[:, 1]) + _flash_attn_varlen_backward( + dout=dout, out=out, + q=q, k=k, v=v, + dq=dq_, dk=dk_, dv=dv_, + softmax_lse=softmax_lse.contiguous(), + + cu_seqlens_q=self_attn_cu_seqlen, + cu_seqlens_k=self_attn_cu_seqlen, + max_seqlen_q=q_block_seq_len, + max_seqlen_k=k_block_seq_len, + + softmax_scale=softmax_scale, + causal=True, + dropout_p=0.0, + window_size_left=window_size[0], + window_size_right=window_size[1], + alibi_slopes=alibi_slopes, + deterministic=deterministic, + softcap=0.0, + ) + dq, dkv[:, 0], dkv[:, 1] = dq + dq_, dk_ + dkv[:, 0], dv_ + dkv[:, 1] + + if moba_q.shape[0] > 0: + softmax_lse_sh = rearrange(softmax_lse.contiguous(), "h s -> (s h)") + moba_attn_lse = ( + # [1, num_non-zero_elements] + softmax_lse_sh.index_select(0, moba_q_sh_indices).view(1, -1) + ) + + moba_dq_, moba_dk_, moba_dv_ = torch.empty_like(moba_q), torch.empty_like(moba_kv[:, 0]), torch.empty_like(moba_kv[:, 1]) + _flash_attn_varlen_backward( + dout=d_moba_out, out=moba_out, + q=moba_q, k=moba_kv[:, 0], v=moba_kv[:, 1], + dq=moba_dq_, dk=moba_dk_, dv=moba_dv_, + softmax_lse=moba_attn_lse, + + cu_seqlens_q=moba_cu_seqlen_q, + cu_seqlens_k=moba_cu_seqlen_kv, + + max_seqlen_q=q_block_seq_len, + max_seqlen_k=moba_chunk_size, + softmax_scale=softmax_scale, + + causal=False, + dropout_p=0.0, + window_size_left=window_size[0], + window_size_right=window_size[1], + alibi_slopes=alibi_slopes, + deterministic=deterministic, + softcap=0.0, + ) + + dq.view(-1, q.shape[-1]).index_add_( + 0, moba_q_sh_indices, moba_dq.view(-1, head_dim).to(dq.dtype) + ) + moba_dkv[:, 0] = moba_dkv[:, 0] + moba_dk_ + moba_dkv[:, 1] = moba_dkv[:, 1] + moba_dv_ + + # ------------------------------------------------------------------------------------ + # Backpropagate moba_dkv to dk and dv + moba_dkv = moba_dkv.squeeze(2) # [H x num_selected_chunks x chunk_size, 2, D] + moba_dkv = moba_dkv.unflatten(0, (-1, moba_chunk_size)) # [H x num_selected_chunks, chunk_size, 2, D] + + if zero_expert_count > 0: + full_moba_dkv = torch.zeros( + (moba_dkv.shape[0] + zero_expert_count, moba_chunk_size, 2, head_dim), + dtype=moba_dkv.dtype, device=moba_dkv.device + ) + full_moba_dkv[valid_expert_mask] = moba_dkv + moba_dkv = full_moba_dkv # [H x num_selected_chunks, chunk_size, 2, D] + moba_dkv = moba_dkv.split(num_head, dim=0) # [H, num_selected_chunks, chunk_size, 2, D] + moba_dkv = torch.cat(moba_dkv, dim=1) # [H, num_selected_chunks x chunk_size, 2, D] + + filtered_dkv = rearrange(moba_dkv, "h s x d -> s x h d") + 
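# --- Illustration (not part of the patch): the "(h s)" -> "(s h)" index remap ---
# The index_add_ scatters in this backward step rely on moba_q_sh_indices, which converts
# a position in the head-major "(h s)" flattening of q into the matching row of the
# token-major q.view(-1, D) / dq.view(-1, D). A minimal self-contained check of that
# arithmetic, assuming toy sizes S=4, H=2, D=3 (all names here are local to this sketch):
import torch

S, H, D = 4, 2, 3
q_toy = torch.arange(S * H * D, dtype=torch.float32).reshape(S, H, D)

# Head-major flattening, i.e. rearrange(q, "s h d -> (h s) d") -- the layout that
# moba_q_indices refers to.
q_hs = q_toy.permute(1, 0, 2).reshape(H * S, D)
# Token-major flattening, i.e. q.view(-1, D) -- the layout that index_add_ writes into.
q_sh = q_toy.reshape(S * H, D)

hs_indices = torch.arange(H * S)
sh_indices = hs_indices % S * H + hs_indices // S   # same formula as moba_q_sh_indices

# Rows addressed through the remapped indices match the head-major rows exactly.
assert torch.equal(q_sh[sh_indices], q_hs)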
dkv.index_add_( + 0, filtered_kv_indices, filtered_dkv # [K_S, 2, H, D] + ) + + if num_head > k_num_head: + num_kv_replicas = num_head // k_num_head + dkv_reshaped = dkv.view(-1, 2, k_num_head, num_kv_replicas, head_dim) + dkv = dkv_reshaped.sum(dim=3) + + return dq, dkv[:, 0], dkv[:, 1] + +def moba_zigzag_attn_bwd( + process_group, + dout, # [blk_S, H, D] + q: torch.Tensor, # [blk_S, H, D] + k: torch.Tensor, # [blk_S, H, D] + v: torch.Tensor, # [blk_S, H, D] + out, # [blk_S, H, D] + softmax_lse, # [H, blk_S] + + seq_offsets: torch.Tensor, # sequence offsets for Q + layer_idx: int, + + gate_mask, + cu_chunk, + filtered_chunk_indices, + num_filtered_chunk, + chunk_to_batch, + moba_chunk_size, + moba_topk, + + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + assert causal == True, "zigzag ring is meaningless for causal=False" + + kv_comm = RingComm(process_group) + d_kv_comm = RingComm(process_group) + + kv_seq_offsets = torch.clone(seq_offsets) + seq_len, num_q_heads, head_dim = q.shape + dq, dk, dv = None, None, None + next_dk, next_dv = None, None + next_k, next_v = None, None + dk_comm_buffer, dv_comm_buffer = None, None + + dout1 = dout.chunk(2, dim=0)[1] + q1 = q.chunk(2, dim=0)[1] + out1 = out.chunk(2, dim=0)[1] + softmax_lse1 = softmax_lse.chunk(2, dim=1)[1].contiguous() + block_seq_len = q.shape[0] // 2 + + def backward( + step, + dout_, q_, k_, v_, out_, + k_seq_offsets, + softmax_lse_, + causal + ): + seqlen_q = q_.shape[0] + seqlen_kv = k_.shape[0] + + params = get_default_args(moba_zigzag_attn_bwd_step).copy() + params.update( + { + "step": step, + "causal": causal, + "dout": dout_, + "out": out_, + + "q": q_, + "k": k_, + "v": v_, + "softmax_lse": softmax_lse_, + + "num_q_blocks": 1 if seqlen_q == block_seq_len else 2, + "k_seq_offsets": k_seq_offsets, + "layer_idx": layer_idx, + + "gate_mask": gate_mask, + "cu_chunk": cu_chunk, + "filtered_chunk_indices": filtered_chunk_indices, + "num_filtered_chunk": num_filtered_chunk, + "chunk_to_batch": chunk_to_batch, + "moba_chunk_size": moba_chunk_size, + "moba_topk": moba_topk, + + "dropout_p": dropout_p, + "softmax_scale": softmax_scale, + "alibi_slopes": alibi_slopes, + "deterministic": deterministic, + } + ) + if "window_size" in params: + params.update({"window_size": window_size}) + else: + params.update( + { + "window_size_left": window_size[0], + "window_size_right": window_size[1], + } + ) + return moba_zigzag_attn_bwd_step(**params) + + for step in range(kv_comm.world_size): + if step + 1 != kv_comm.world_size: + # next_k, next_v = kv_comm.send_recv_kv(k, v) + next_k, next_v, next_kv_seq_offsets = kv_comm.send_recv_kv_offsets(k, v, kv_seq_offsets) + + if step == 0: + dq_buffer, dk_buffer, dv_buffer = backward( + step, + dout, q, k, v, out, + kv_seq_offsets, softmax_lse, causal=True + ) + dq = dq_buffer.to(torch.float32) + dk = dk_buffer.to(torch.float32) + dv = dv_buffer.to(torch.float32) + else: + if step <= kv_comm.revert_rank: + k0 = k[:block_seq_len] + v0 = v[:block_seq_len] + dq_buffer, dk_buffer, dv_buffer = backward( + step, + dout, q, k0, v0, out, + kv_seq_offsets[0:1], softmax_lse, causal=False + ) + dq += dq_buffer + else: + dq_buffer, dk_buffer, dv_buffer = backward( + step, + dout1, q1, k, v, out1, + kv_seq_offsets, softmax_lse1, causal=False + ) + + # use the first half in dq_buffer. 
+ dq[block_seq_len:] += dq_buffer + + d_kv_comm.wait() + dk_comm_buffer, dv_comm_buffer = dk, dv + dk, dv = next_dk, next_dv + + if step <= kv_comm.revert_rank: + dk[:block_seq_len] += dk_buffer + dv[:block_seq_len] += dv_buffer + else: + dk += dk_buffer + dv += dv_buffer + + if step + 1 != kv_comm.world_size: + kv_comm.wait() + k, v, kv_seq_offsets = next_k, next_v, next_kv_seq_offsets + + # the finally received dk and dv will be the same as the first dk and dv (corresponding to local Q) + next_dk, next_dv = d_kv_comm.send_recv_kv( + dk, dv, dk_comm_buffer, dv_comm_buffer + ) + d_kv_comm.wait() + return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) + + +''' +In nnscaler, sequence are stored in the initial order, e.g., [0 1 2 3 4 5 6 7]. +However, zigzag ring flash attention requires the sequence to be in the order of [0 7 2 5 3 4 1 6]. +As a result: +- in forward, we need to shuffle q, k, v and recover the out +- in backward, we need to shuffle dout and recover the dq, dk, dv +''' +class MoBAZigzagRingFlashAttnFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q: torch.Tensor, # [batch * seq_block_len, n_heads, head_dim] + k: torch.Tensor, + v: torch.Tensor, + seq_offset: torch.Tensor, + layer_idx, + dropout_p, + softmax_scale, + cu_seqlens, + moba_chunk_size, + moba_topk, + causal, + window_size, + alibi_slopes, + deterministic, + return_softmax, + group, + ): + # print(f"Rank {dist.get_rank()} | forward | q shape: {q.shape}, k shape: {k.shape}, v shape: {v.shape}") + + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + assert alibi_slopes is None + + ( + gate_mask, cu_chunk, + filtered_chunk_indices, + num_filtered_chunk, + chunk_to_batch + ) = compute_moba_gate( + q, k, v, + seq_offset, + cu_seqlens, + moba_chunk_size, + moba_topk, + ) + + # gate_mask needs to be shuffled as it is coupled with q + q, seq_offsets, gate_mask = shuffle_input_all( + to_send=q, gate_mask=gate_mask, seq_offset=seq_offset, + process_group=group + ) + k = shuffle_input_only(to_send=k, process_group=group) + v = shuffle_input_only(to_send=v, process_group=group) + + k = k.contiguous() + v = v.contiguous() + out, softmax_lse = moba_zigzag_attn_fwd( + group, + q, k, v, + seq_offsets, # sequence offsets for Q + layer_idx, + + gate_mask, cu_chunk, + filtered_chunk_indices, + num_filtered_chunk, + chunk_to_batch, + moba_chunk_size, + moba_topk, + + softmax_scale=softmax_scale, + dropout_p=dropout_p, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=False, + ) + + # this should be out_padded + ctx.save_for_backward( + q, k, v, out, softmax_lse, seq_offsets, + gate_mask, cu_chunk, filtered_chunk_indices, + chunk_to_batch + ) + ctx.num_filtered_chunk = num_filtered_chunk + ctx.moba_chunk_size = moba_chunk_size + ctx.moba_topk = moba_topk + + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + ctx.group = group + ctx.layer_idx = layer_idx + + + out = recover_zigzag_output(out, process_group=group) + return out if not return_softmax else (out, softmax_lse, None) + + @staticmethod + def backward(ctx, dout, *args): + dout = shuffle_input_only(to_send=dout, process_group=ctx.group) + ( + q, k, v, out, + softmax_lse, # [n_heads, seq_block_len] + seq_offsets, + gate_mask, cu_chunk, filtered_chunk_indices, + chunk_to_batch + ) = ctx.saved_tensors + + num_filtered_chunk = ctx.num_filtered_chunk + 
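# --- Illustration (not part of the patch): the zigzag layout assumed by this class ---
# The docstring above states that nnscaler stores the sequence in order [0 1 2 3 4 5 6 7]
# while the zigzag ring kernels expect [0 7 2 5 3 4 1 6]. The sketch below reproduces the
# exchange performed by shuffle_input_all / shuffle_input_only on plain Python lists
# (block indices stand in for tensors; two consecutive blocks per rank are assumed):
def zigzag_layout(world_size):
    initial = [[2 * r, 2 * r + 1] for r in range(world_size)]
    shuffled = []
    for rank in range(world_size):
        peer = world_size - rank - 1                   # the src_rank used in the shuffles
        recv = initial[peer][1]                        # every rank sends its 2nd block
        if rank >= world_size // 2:
            shuffled.append([recv, initial[rank][0]])  # received block first, own 1st block last
        else:
            shuffled.append([initial[rank][0], recv])  # own 1st block first, received block last
    return shuffled

# zigzag_layout(4) == [[0, 7], [2, 5], [3, 4], [1, 6]], i.e. the order 0 7 2 5 3 4 1 6.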
moba_chunk_size = ctx.moba_chunk_size + moba_topk = ctx.moba_topk + + dq, dk, dv = moba_zigzag_attn_bwd( + ctx.group, + dout, + q, k, v, + out, + softmax_lse, + + seq_offsets, + ctx.layer_idx, + gate_mask, + cu_chunk, + filtered_chunk_indices, + num_filtered_chunk, + chunk_to_batch, + moba_chunk_size, + moba_topk, + + softmax_scale=ctx.softmax_scale, + dropout_p=ctx.dropout_p, + causal=ctx.causal, + window_size=ctx.window_size, + alibi_slopes=ctx.alibi_slopes, + deterministic=ctx.deterministic, + ) + + dq = recover_zigzag_output(dq, ctx.group) + dk = recover_zigzag_output(dk, ctx.group) + dv = recover_zigzag_output(dv, ctx.group) + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None + + +def zigzag_ring_flash_attn_qkvpacked_func( + qkv, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return MoBAZigzagRingFlashAttnFunc.apply( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def zigzag_ring_flash_attn_kvpacked_func( + q, + kv, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return MoBAZigzagRingFlashAttnFunc.apply( + q, + kv[:, :, 0], + kv[:, :, 1], + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def zigzag_ring_flash_attn_func( + q, + k, + v, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return MoBAZigzagRingFlashAttnFunc.apply( + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) diff --git a/minference/dist_ops/op_utils/__init__.py b/minference/dist_ops/op_utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/minference/dist_ops/op_utils/moba_utils.py b/minference/dist_ops/op_utils/moba_utils.py new file mode 100644 index 0000000..7425646 --- /dev/null +++ b/minference/dist_ops/op_utils/moba_utils.py @@ -0,0 +1,636 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# Credits: This logger implementation is inspired by project https://github.com/zhuzilin/ring-flash-attention +import os +import math +import torch +import inspect +import operator +import contextlib +import pandas as pd +import torch.nn.functional as F +import torch.distributed as dist + +from functools import reduce, cache, lru_cache +from typing import Optional, Tuple, List, Dict + +@cache +def _get_default_args(func): + spec = inspect.getfullargspec(func) + defaults = spec.defaults if spec.defaults is not None else () + padded_defaults = (None,) * (len(spec.args) - len(defaults)) + defaults + args = dict(zip(spec.args, padded_defaults)) + if "softcap" in args: + args["softcap"] = 0.0 + return args + + +def get_default_args(func): + if inspect.isfunction(func): + return _get_default_args(func) + else: + # Use the origin _init_fn in CustomOpDef + return _get_default_args(func._init_fn) + + +# copy from megatron/core/utils.py +class GlobalMemoryBuffer: + """Global buffer to avoid dynamic memory allocations. 
+ Caller should ensure that buffers of the same name + are not used concurrently.""" + + def __init__(self): + self.buffer = {} + + def get_tensor(self, tensor_shape, dtype, name): + required_len = reduce(operator.mul, tensor_shape, 1) + if ( + self.buffer.get((name, dtype), None) is None + or self.buffer[(name, dtype)].numel() < required_len + ): + self.buffer[(name, dtype)] = torch.empty( + required_len, dtype=dtype, device=torch.cuda.current_device(), requires_grad=False + ) + + return self.buffer[(name, dtype)][0:required_len].view(*tensor_shape) + +def check_nan_inf( + out, lse, block_out, block_lse, phase_prefix: str, postfix: str +): + if (not torch.isnan(block_out).any()) and (not torch.isnan(block_lse).any()): + if torch.isnan(out).any(): + print(f"{phase_prefix}nan in out ({postfix})") + if torch.isinf(out).any(): + print(f"{phase_prefix}inf in out ({postfix})") + + if torch.isnan(lse).any(): + print(f"{phase_prefix}nan in lse ({postfix})") + if torch.isinf(lse).any(): + print(f"{phase_prefix}inf in lse ({postfix})") + +@torch.jit.script +def _update_out_and_lse( + out: torch.Tensor, + lse: torch.Tensor, + block_out: torch.Tensor, + block_lse: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + + block_out = block_out.to(torch.float32) + block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) + + # new_lse = lse + torch.log(1 + torch.exp(block_lse - lse)) + # torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out + # For additional context and discussion, please refer to: + # https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795 + out = out - F.sigmoid(block_lse - lse) * (out - block_out) + lse = lse - F.logsigmoid(lse - block_lse) + + return out, lse + + +def update_out_and_lse( + out: Optional[torch.Tensor], + lse: Optional[torch.Tensor], + block_out: torch.Tensor, + block_lse: torch.Tensor, + slice_=None, +) -> Tuple[torch.Tensor, torch.Tensor]: + if out is None: + if slice_ is not None: + raise RuntimeError("first update_out_and_lse should not pass slice_ args") + out = block_out.to(torch.float32) + lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) + elif slice_ is not None: + slice_out, slice_lse = out[slice_], lse[slice_] + + slice_out, slice_lse = _update_out_and_lse( + slice_out, slice_lse, block_out, block_lse + ) + out[slice_], lse[slice_] = slice_out, slice_lse + else: + out, lse = _update_out_and_lse(out, lse, block_out, block_lse) + return out, lse + +class RingComm: + def __init__(self, process_group: dist.ProcessGroup): + self._process_group = process_group + self._ops = [] + self.rank = dist.get_rank(self._process_group) + self.world_size = dist.get_world_size(self._process_group) + self._reqs = None + + parts = self.world_size // 2 + self.ring_list = [] + for i in range(parts): + self.ring_list.extend([i, self.world_size - i - 1]) + + self.revert_rank = self.ring_list.index(self.rank) + + offset = ((dist.get_rank() // self.world_size) * self.world_size) + self.send_rank = self.ring_list[(self.revert_rank + 1) % self.world_size] + offset + self.recv_rank = self.ring_list[(self.revert_rank - 1) % self.world_size] + offset + + def send_recv( + self, to_send: torch.Tensor, recv_tensor: Optional[torch.Tensor] = None + ) -> torch.Tensor: + if recv_tensor is None: + res = torch.empty_like(to_send) + else: + res = recv_tensor + + send_op = dist.P2POp( + dist.isend, to_send, self.send_rank, group=self._process_group + ) + recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group) + 
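# --- Illustration (not part of the patch): the ring order built in RingComm.__init__ ---
# send_rank / recv_rank follow the interleaved ring_list rather than plain rank order, so
# KV blocks travel along the zigzag ring. A standalone sketch (the per-node `offset` used
# by the real class for multi-group runs is omitted here):
def ring_neighbours(world_size):
    ring_list = []
    for i in range(world_size // 2):
        ring_list.extend([i, world_size - i - 1])        # e.g. 4 ranks -> [0, 3, 1, 2]
    neighbours = {}
    for rank in range(world_size):
        revert_rank = ring_list.index(rank)              # position of this rank on the ring
        send_to = ring_list[(revert_rank + 1) % world_size]
        recv_from = ring_list[(revert_rank - 1) % world_size]
        neighbours[rank] = (send_to, recv_from)
    return ring_list, neighbours

# For 4 ranks: KV moves 0 -> 3 -> 1 -> 2 -> 0, so rank 0 sends to 3 and receives from 2.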
self._ops.append(send_op) + self._ops.append(recv_op) + return res + + def commit(self): + if self._reqs is not None: + raise RuntimeError("commit called twice") + self._reqs = dist.batch_isend_irecv(self._ops) + + def wait(self): + if self._reqs is None: + raise RuntimeError("wait called before commit") + + for req in self._reqs: + req.wait() + + self._reqs = None + self._ops = [] + + def send_recv_kv( + self, + k: torch.Tensor, + v: torch.Tensor, + k_buffer: Optional[torch.Tensor] = None, + v_buffer: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + next_k, next_v = self.send_recv(k, k_buffer), self.send_recv(v, v_buffer) + self.commit() + return next_k, next_v + + def send_recv_kv_offsets( + self, + k: torch.Tensor, + v: torch.Tensor, + kv_seq_offsets: torch.Tensor, + k_buffer: Optional[torch.Tensor] = None, + v_buffer: Optional[torch.Tensor] = None, + kv_seq_offsets_buffer: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + next_k, next_v = self.send_recv(k, k_buffer), self.send_recv(v, v_buffer) + next_kv_seq_offsets = self.send_recv(kv_seq_offsets, kv_seq_offsets_buffer) + + self.commit() + return next_k, next_v, next_kv_seq_offsets + + +def shuffle_input_all( + to_send: torch.Tensor, # [S, H, D] + seq_offset: torch.Tensor, # [2] + gate_mask: torch.Tensor = None, # [num_chunks, H, S] + process_group: dist.ProcessGroup = None + ): + orig_ndim = to_send.ndim + if orig_ndim == 3: to_send = to_send.unsqueeze(0) + + rank = dist.get_rank(process_group) + world_size = dist.get_world_size(process_group) + + if not to_send.is_contiguous(): + to_send = to_send.contiguous() + block_seq_len = to_send.shape[1] // 2 + + seq_offset_val = seq_offset.detach().cpu().item() + seq_offsets = torch.Tensor([seq_offset_val, seq_offset_val + block_seq_len]).to(to_send.device) + + # We must use outplace, otherwise it will raise error at backward due to inplace operations. + # We can not change to_send directly and create a new tensor to store the result. 
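# --- Illustration (not part of the patch): why the shuffle writes into a fresh tensor ---
# A minimal reproduction of the autograd failure that the out-of-place copy below avoids:
# writing in place into a tensor that autograd saved for the backward pass only fails
# later, when backward() runs.
import torch

x = torch.randn(4, requires_grad=True)
y = x.sigmoid()          # sigmoid saves its output for the backward pass
y[1] = 0.0               # in-place write into that saved tensor
try:
    y.sum().backward()   # RuntimeError: "... modified by an inplace operation"
except RuntimeError as err:
    print(err)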
+ to_send_f = torch.zeros_like(to_send) + to_send_gate_mask = torch.zeros_like(gate_mask) + to_send_offset = seq_offsets[1] + + # assume the input sequence length is 8, and computation runs on 4 GPUs + # the seq is represented as [0 1 2 3 4 5 6 7], world size is 4 + # the input status before `shuffle_input` is + # - gpu A: [0 1] + # - gpu B: [2 3] + # - gpu C: [4 5] + # - gpu D: [6 7] + # the value of `to_send_slice` is + # - gpu A: [1] + # - gpu B: [3] + # - gpu C: [5] + # - gpu D: [7] + to_send_slice = to_send[:, block_seq_len:].contiguous() + to_send_gate_mask_slice = gate_mask[..., block_seq_len:].contiguous() + + res = torch.zeros_like(to_send_slice) + res_gate_mask = torch.zeros_like(to_send_gate_mask_slice) + res_offset= torch.zeros_like(to_send_offset) + + _ops = [] + offset = ((dist.get_rank() // world_size) * world_size) + # rank src_rank + # 0 3 + # 1 2 + # 2 1 + # 3 0 + src_rank = (world_size - rank - 1) % world_size + offset + send_op = dist.P2POp( + dist.isend, to_send_slice, src_rank, group=process_group + ) + send_gate_mask_op = dist.P2POp( + dist.isend, to_send_gate_mask_slice, src_rank, group=process_group + ) + send_offset_op = dist.P2POp( + dist.isend, to_send_offset, src_rank, group=process_group + ) + _ops.append(send_op) + _ops.append(send_gate_mask_op) + _ops.append(send_offset_op) + + recv_op = dist.P2POp( + dist.irecv, res, src_rank, group=process_group) + recv_gate_mask_op = dist.P2POp( + dist.irecv, res_gate_mask, src_rank, group=process_group + ) + recv_offset_op = dist.P2POp( + dist.irecv, res_offset, src_rank, group=process_group + ) + _ops.append(recv_op) + _ops.append(recv_gate_mask_op) + _ops.append(recv_offset_op) + + # response = dist.dist.batch_isend_irecv(_ops) + response = dist.batch_isend_irecv(_ops) + for resp in response: + resp.wait() + + if rank >= world_size // 2: # D: 6 7, -> 1 6 + to_send_f[:, block_seq_len:] = to_send[:, :block_seq_len] + to_send_f[:, :block_seq_len, ...] = res + + to_send_gate_mask[..., block_seq_len:] = gate_mask[..., :block_seq_len] + to_send_gate_mask[..., :block_seq_len] = res_gate_mask + + seq_offsets[1] = seq_offsets[0] + seq_offsets[0] = res_offset + else: # A: 0 1, -> 0 7 + to_send_f[:, :block_seq_len] = to_send[:, :block_seq_len] + to_send_f[:, block_seq_len:, ...] = res + + to_send_gate_mask[..., :block_seq_len] = gate_mask[..., :block_seq_len] + to_send_gate_mask[..., block_seq_len:] = res_gate_mask + + seq_offsets[1] = res_offset + + # after shuffle, the status of `to_send_f` + # GPU A: [0 7] + # GPU B: [2 5] + # GPU C: [3 4] + # GPU D: [1 6] + return ( + to_send_f if orig_ndim != 3 else to_send_f.squeeze(0), + seq_offsets, + to_send_gate_mask, + ) + + +def shuffle_input_only( + to_send: torch.Tensor, # [S, H, D] + process_group: dist.ProcessGroup = None + ) -> torch.Tensor: + orig_ndim = to_send.ndim + if orig_ndim == 3: to_send = to_send.unsqueeze(0) + + rank = dist.get_rank(process_group) + world_size = dist.get_world_size(process_group) + + if not to_send.is_contiguous(): + to_send = to_send.contiguous() + block_seq_len = to_send.shape[1] // 2 + + # We must use outplace, otherwise it will raise error at backward due to inplace operations. + # We can not change to_send directly and create a new tensor to store the result. 
+ to_send_f = torch.zeros_like(to_send) + + to_send_slice = to_send[:, block_seq_len:].contiguous() + res = torch.zeros_like(to_send_slice) + + _ops = [] + offset = ((dist.get_rank() // world_size) * world_size) + + src_rank = (world_size - rank - 1) % world_size + offset + send_op = dist.P2POp( + dist.isend, to_send_slice, src_rank, group=process_group + ) + _ops.append(send_op) + + recv_op = dist.P2POp( + dist.irecv, res, src_rank, group=process_group) + _ops.append(recv_op) + + # response = dist.dist.batch_isend_irecv(_ops) + response = dist.batch_isend_irecv(_ops) + for resp in response: + resp.wait() + + if rank >= world_size // 2: # D: 6 7, -> 1 6 + to_send_f[:, block_seq_len:] = to_send[:, :block_seq_len] + to_send_f[:, :block_seq_len, ...] = res + else: # A: 0 1, -> 0 7 + to_send_f[:, :block_seq_len] = to_send[:, :block_seq_len] + to_send_f[:, block_seq_len:, ...] = res + return to_send_f if orig_ndim != 3 else to_send_f.squeeze(0) + + +def recover_output( + to_send: torch.Tensor, # [S, H, D] + process_group: dist.ProcessGroup = None + ): + orig_ndim = to_send.ndim + if orig_ndim == 3: to_send = to_send.unsqueeze(0) + + if not to_send.is_contiguous(): + to_send = to_send.contiguous() + + to_send_f = torch.zeros_like(to_send) + + block_seq_len = to_send.shape[1] // 2 + + rank = dist.get_rank(process_group) + world_size = dist.get_world_size(process_group) + + if rank >= world_size // 2: + to_send_slice = to_send[:, :block_seq_len, ...].contiguous() + else: + to_send_slice = to_send[:, block_seq_len:, ...].contiguous() + res = torch.zeros_like(to_send_slice) + + assert to_send_slice.is_contiguous() + assert res.is_contiguous() + + _ops = [] + offset = ((dist.get_rank() // world_size) * world_size) + src_rank = (world_size - rank - 1) % world_size + offset + send_op = dist.P2POp( + dist.isend, to_send_slice, src_rank, group=process_group + ) + recv_op = dist.P2POp( + dist.irecv, res, src_rank, group=process_group) + + _ops.append(send_op) + _ops.append(recv_op) + + response = dist.batch_isend_irecv(_ops) + for resp in response: + resp.wait() + + if rank >= world_size // 2: + to_send_f[:, :block_seq_len] = to_send[:, block_seq_len:, ...] + to_send_f[:, block_seq_len:] = res + else: + to_send_f[:, :block_seq_len] = to_send[:, :block_seq_len, ...] 
+ to_send_f[:, block_seq_len:] = res + + return to_send_f.contiguous() if orig_ndim != 3 else to_send_f.squeeze(0).contiguous() + + + +def recover_lse( + to_send_lse: torch.Tensor, # [H, S] + process_group: dist.ProcessGroup = None + ): + + if not to_send_lse.is_contiguous(): + to_send_lse = to_send_lse.contiguous() + + to_send_f = torch.zeros_like(to_send_lse) + + block_seq_len = to_send_lse.shape[1] // 2 + + rank = dist.get_rank(process_group) + world_size = dist.get_world_size(process_group) + + if rank >= world_size // 2: + to_send_slice = to_send_lse[:, :block_seq_len].contiguous() + else: + to_send_slice = to_send_lse[:, block_seq_len:].contiguous() + res = torch.zeros_like(to_send_slice) + + assert to_send_slice.is_contiguous() + assert res.is_contiguous() + + _ops = [] + offset = ((dist.get_rank() // world_size) * world_size) + src_rank = (world_size - rank - 1) % world_size + offset + send_op = dist.P2POp( + dist.isend, to_send_slice, src_rank, group=process_group + ) + recv_op = dist.P2POp( + dist.irecv, res, src_rank, group=process_group) + + _ops.append(send_op) + _ops.append(recv_op) + + response = dist.batch_isend_irecv(_ops) + for resp in response: + resp.wait() + + if rank >= world_size // 2: + to_send_f[:, :block_seq_len] = to_send_lse[:, block_seq_len:] + to_send_f[:, block_seq_len:] = res + else: + to_send_f[:, :block_seq_len] = to_send_lse[:, :block_seq_len] + to_send_f[:, block_seq_len:] = res + + return to_send_f.contiguous() + + +@lru_cache(maxsize=16) +def calc_chunks(cu_seqlen, moba_chunk_size): + """calc chunks that needs moba attention""" + + # batch_sizes[batch_idx] = batch size ( seqlen ) of batch idx + # example: [seq_len] + batch_sizes = cu_seqlen[1:] - cu_seqlen[:-1] + + # batch_num_chunk[batch_idx] = how many chunk in batch idx + # example: [number of all chunks with chunk size equal to moba_chunk_size + 1 (the one with smaller size)] + batch_num_chunk = (batch_sizes + (moba_chunk_size - 1)) // moba_chunk_size + + # cu_num_chunk[batch_idx] = first chunk id of this batch + # example: [1, 1] + cu_num_chunk = torch.ones( + batch_num_chunk.numel() + 1, + device=cu_seqlen.device, + dtype=batch_num_chunk.dtype, + ) + # example: [1, 1 + num of chunks] + cu_num_chunk[1:] = batch_num_chunk.cumsum(dim=0) + + # total chunk ( for all batch ) + # example: 1 + num of chunks + num_chunk = cu_num_chunk[-1] + + # chunk_sizes[chunk_idx] = chunk_size of chunk idx + chunk_sizes = torch.full( + (num_chunk + 1,), moba_chunk_size, dtype=torch.int32, device=cu_seqlen.device + ) + chunk_sizes[0] = 0 # for calc cu chunk + batch_last_chunk_size = batch_sizes - (batch_num_chunk - 1) * moba_chunk_size + chunk_sizes[cu_num_chunk[1:]] = batch_last_chunk_size + # example chunk_sizes: [0, moba_chunk_size, ..., moba_chunk_size, batch_last_chunk_size] + + + # cu_chunk[chunk_idx] = the start chunk offset of chunk idx + # example: [0, moba_chunk_size, ..., seq_len] + cu_chunk = chunk_sizes.cumsum(dim=-1, dtype=torch.int32) + + + # chunk_to_batch[chunk_idx] = batch idx of the chunk idx + # example: [0, 0, 0, ...., 0] + chunk_to_batch = torch.zeros( + (num_chunk,), dtype=torch.int32, device=cu_seqlen.device + ) + + # example: [0, 0, 0, ... 
, 0] (if there are multiple samples in the batch, the index of the starting chunk of each batch from the 1st sample will be 1) + # but if there is only one batch, cu_num_chunk[1:-1] will be empty and no element will be assigned with 1 (all correspond to 0-th sample) + chunk_to_batch[cu_num_chunk[1:-1]] = 1 + + # example: [0, 0, 0, ..., 0] + chunk_to_batch = chunk_to_batch.cumsum(dim=0, dtype=torch.int32) + + """ filter chunks that need moba attn """ + # filter chunks ( remove last chunk of each batch ) + # filtered_chunk_indices: chunk index list that excludes the last chunk of each batch + chunk_to_remove = cu_num_chunk[1:] - 1 # example: number of chunks (num_chunk - 1) + # print(f"calc_chunks | chunk_to_remove: {chunk_to_remove}") + + chunk_to_remain = torch.ones( + (num_chunk, ), dtype=torch.bool, device=cu_seqlen.device + ) + chunk_to_remain[chunk_to_remove] = False # example: + filtered_chunk_indices = chunk_to_remain.nonzero(as_tuple=True)[0] + num_filtered_chunk = len(filtered_chunk_indices) + + return ( + cu_chunk, + filtered_chunk_indices, + num_filtered_chunk, + chunk_to_batch, + ) + + +def compute_moba_gate( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_offset: torch.Tensor, + cu_seqlens: torch.Tensor, + moba_chunk_size: int, + moba_topk: int, +): + seq_offset: int = seq_offset.detach().cpu().item() + seqlen_block, num_head, head_dim = q.shape + _, k_num_head, _ = k.shape + if num_head > k_num_head: + k = torch.repeat_interleave(k, num_head // k_num_head, dim=1) + v = torch.repeat_interleave(v, num_head // k_num_head, dim=1) + + # --------------------------------------------------------------------------------------------- + kv = torch.stack((k, v), dim=1) # [ blk_S, 2, H, D ] + + world_size = dist.get_world_size() + kv_list = [torch.zeros_like(kv, dtype=q.dtype, device=q.device) for _ in range(world_size)] + dist.all_gather(kv_list, kv) + kv_gathered = torch.cat(kv_list, dim=0) # [ S, 2, H, D ] + + + """ some basic variables """ + # qkv shape = [ S, H, D ] + block_size = q.shape[0] + seqlen, _, num_head, head_dim = kv_gathered.shape + + """ prepare chunk meta """ + ( + cu_chunk, # example: [0, moba_chunk_size, ..., seq_len] + filtered_chunk_indices, # example: [0, 1, 2, ..., num_filtered_chunk-1] (i.e. except the last chunk) + num_filtered_chunk, # example: num_filtered_chunk + chunk_to_batch, # example: [0, 0, ... 
,0] (for batch_size=1) with size 1 + real num of chunks + ) = calc_chunks(cu_seqlens, moba_chunk_size) + + # we will adjust selective topk to moba_topk - 1, as the last chunk is always chosen + moba_topk = min(moba_topk - 1, num_filtered_chunk) + assert moba_topk > 0, "moba_topk should be greater than 0" + + # filtered_kv is a dense matrix that only contains filtered chunk of kv + filtered_kv_indices = torch.arange( + 0, moba_chunk_size, dtype=torch.int32, device=q.device + )[None, :].repeat(num_filtered_chunk, 1) + filtered_kv_indices += cu_chunk[filtered_chunk_indices][:, None] + + # select the elements of KV corresponding to all chunks that are not filtered out + filtered_kv = kv_gathered.index_select(0, filtered_kv_indices.view(-1)) + + """ calc key_gate_weight and gate """ + # key_gate_weight [ F_N_CHUNK, HEAD, HEAD_DIM ] + key_gate_weight = ( + filtered_kv[:, 0] # K + .view(num_filtered_chunk, moba_chunk_size, num_head, head_dim) + .mean(dim=1) # mean pooling along chunk size + .float() + ) + # print(f"Rank {dist.get_rank()} | compute_moba_gate | key_gate_weight shape: {key_gate_weight.shape}") + + q = q.type(torch.float32) # float logit on the fly for better gate logit perception + key_gate_weight = key_gate_weight.type( + torch.float32 + ) # float logit for better gate logit perception + gate = torch.einsum( + "nhd,shd->nhs", key_gate_weight, q + ) # gate [ F_N_CHUNK, HEAD, SEQ_BLOCK] + key_gate_weight = key_gate_weight.type_as(k) + q = q.type_as(k) + + # pose process gate, masking unchosen batch and apply causal mask to current chunk + gate_seq_idx = torch.arange( + seq_offset, min(seq_offset + block_size, seqlen), device=q.device, dtype=torch.int32 + )[None, :].repeat(num_filtered_chunk, 1) + chunk_end = cu_chunk[filtered_chunk_indices + 1] + batch_end = cu_seqlens[chunk_to_batch[filtered_chunk_indices] + 1] + gate_chunk_end_mask = gate_seq_idx < chunk_end[:, None] + gate_batch_end_mask = gate_seq_idx >= batch_end[:, None] + gate_inf_mask = gate_chunk_end_mask | gate_batch_end_mask + gate.masked_fill_(gate_inf_mask.unsqueeze(1), -float("inf")) + # print(f"Rank {dist.get_rank()} | compute_moba_gate | gate shape before topK: {gate.shape}") + + """ find moba q that needs moba attn """ + # find topk chunks + # gate_top_k_idx with shape [TOP_K, HEAD, SEQ_BLOCK] + _, gate_top_k_idx = torch.topk(gate, k=moba_topk, dim=0, largest=True, sorted=False) + # apply causal mask + gate_mask = torch.logical_not(gate.isinf()) + + # select topk chunks + gate_idx_mask = torch.zeros(gate_mask.shape, dtype=torch.bool, device=q.device) + gate_idx_mask = gate_idx_mask.scatter_(dim=0, index=gate_top_k_idx, value=True) + + # [ F_N_CHUNK, HEAD, SEQ_BLOCK] + gate_mask = torch.logical_and(gate_mask, gate_idx_mask).contiguous() + + return ( + # gate_mask does not need to be gathered because + # each device only needs the gate_mask corresponding to the current query block + gate_mask, + cu_chunk, + filtered_chunk_indices, + num_filtered_chunk, + chunk_to_batch + ) diff --git a/minference/dist_ops/op_utils/xattn_utils.py b/minference/dist_ops/op_utils/xattn_utils.py new file mode 100644 index 0000000..444c36f --- /dev/null +++ b/minference/dist_ops/op_utils/xattn_utils.py @@ -0,0 +1,521 @@ +import math +import torch +import torch.nn.functional as F +import torch.distributed as dist + +from minference.dist_ops.utils import RingComm +from minference.ops.xattention_fa import ( + softmax_fuse_block_sum, + flat_group_gemm_fuse_reshape, +) + + +LN2 = 1 / 1.4426950408889634 +def create_causal_mask(batch_size, 
head_num, block_size, block_num, divide_block_num): + """ + Creates a causal attention mask used in transformer-based models. + + Parameters: + - batch_size (int): The number of sequences in the batch. + - head_num (int): The number of attention heads. + - block_size (int): The size of each block in the sequence. + - block_num (int): The total number of blocks in the sequence. + - divide_block_num (int): The block index at which causality is applied. + + Returns: + - torch.Tensor: A mask tensor of shape (batch_size, head_num, block_size, total_size) + where total_size = block_size * block_num. The mask enforces causal attention by + setting certain positions to `-inf` to prevent information leakage from future tokens. + """ + divide_block_num += 1 + if divide_block_num < 1 or divide_block_num > block_num: + raise ValueError( + f"divide_block_num ({divide_block_num}) must be between 1 and block_num ({block_num})." + ) + + total_size = block_size * block_num + device = "cuda" + mask = torch.zeros(block_size, total_size, device=device) + if divide_block_num < block_num: + mask[:, divide_block_num * block_size :] = float("-inf") + + if divide_block_num - 1 < block_num: + start_col = (divide_block_num - 1) * block_size + end_col = start_col + block_size + upper_tri_mask = torch.triu( + torch.full((block_size, block_size), float("-inf"), device=device), + diagonal=1, + ) + mask[:, start_col:end_col] = upper_tri_mask + + mask = mask.unsqueeze(0).unsqueeze(0) + mask = mask.expand(batch_size, head_num, block_size, total_size) + return mask + +def find_blocks_chunked( + input_tensor: torch.Tensor, # (batch_size, num_heads, num_block_q, num_block_k) + current_index, # + threshold, num_to_choose, decoding: bool, mode: str = "both", causal=True +): + """ + Finds and selects relevant blocks of attention for transformer-based models based on a + threshold or a predefined number of blocks. + + Parameters: + - input_tensor (torch.Tensor): The input tensor of shape (batch_size, head_num, num_block_q, num_block_k). + - current_index (int): The current index in the sequence processing. + - threshold (float or None): A threshold value used to determine the minimum attention weight sum. + - num_to_choose (int or None): The number of blocks to be selected, ensuring sufficient information retrieval. + - decoding (bool): If True, operates in decoding mode; otherwise, it's in encoding mode. + - mode (str): Defines the processing mode, either 'both', 'prefill', or 'decode'. + - causal (bool): If True, applies causal masking to prevent future information leakage. + + Returns: + - torch.Tensor: A boolean mask of shape (batch_size, head_num, num_block_q, num_block_k), + indicating which blocks should be attended to. 
+ """ + assert threshold is None or num_to_choose is None + batch_size, head_num, num_block_q, num_block_k = input_tensor.shape + input_tensor = input_tensor.to(float) + + total_sum = input_tensor.sum(dim=-1, keepdim=True) + if isinstance(threshold, torch.Tensor): + threshold = threshold.to(float) + required_sum = total_sum * threshold.unsqueeze(0).unsqueeze(-1).unsqueeze( + -1 + ).expand((batch_size, head_num, num_block_q, 1)).to(input_tensor.device) + else: + required_sum = total_sum * threshold + + + mask = torch.zeros_like(input_tensor, dtype=torch.bool) + mask[:, :, :, 0] = 1 + mask[:, :, :, current_index : current_index + num_block_q] = ( + torch.eye(num_block_q, device=mask.device) + .unsqueeze(0) + .unsqueeze(0) + .expand(1, head_num, num_block_q, num_block_q) + ) + # Note that other_values only contains the values of the current block + # (the sink blocks and diagonal are filled with 0) + other_values = input_tensor.masked_fill(mask, 0) + + + # Get sorted values + sorted_values, _ = torch.sort(other_values, dim=-1, descending=True) + sorted_values = sorted_values.to(input_tensor.device) + sorted_values = torch.cat( + [ + torch.zeros( + (batch_size, head_num, num_block_q, 1), device=input_tensor.device + ), + torch.where(mask, input_tensor, 0).sum(dim=-1, keepdim=True), # shape: (batch_size, head_num, num_block_q, 1) + sorted_values[:, :, :, :-2], # :-2 excludes the first and diagonal (which are marked 0 in other_values) + ], + dim=-1, + ) + + # Get sorted indices + # index will select the already-masked (sink and diagonal) at the beginning + _, index = torch.sort( + torch.where(mask, 100000 * (1 + input_tensor), input_tensor), + dim=-1, + descending=True, + ) + + # [batch_size, head_num, num_block_q, num_block_k] + cumulative_sum_without_self = torch.cat( + [ + torch.zeros( + (batch_size, head_num, num_block_q, 1), device=input_tensor.device + ), + sorted_values[:, :, :, 0:-1], + ], + dim=-1, + ).cumsum(dim=-1) + + # Mask for indices where cumulative sum is below the required threshold. 
+ index_mask = cumulative_sum_without_self < required_sum + index = torch.where(index_mask, index, 0) + + mask = mask.view(batch_size, head_num * num_block_q, num_block_k) + index = index.view(batch_size, head_num * num_block_q, num_block_k) + mask[:, torch.arange(mask.shape[1], device=mask.device).unsqueeze(dim=-1), index] = True + mask = mask.view(batch_size, head_num, num_block_q, num_block_k) + + + assert bool((torch.where(mask,input_tensor,0).sum(dim=-1, keepdim=True) >= required_sum * 0.99).all()), \ + f"mask sum {torch.where(mask,input_tensor,0).sum(dim=-1, keepdim=True)} < required_sum {required_sum}" + + try: + if causal: + assert (~mask[:, :, :, current_index + num_block_q :]).all() + except: + mask[:, :, :, current_index + num_block_q :] = False + + if causal: + if decoding: + assert mask[:, :, :, 0].all() and mask[:, :, :, -1].all() + else: + lambda_mask = torch.zeros_like(input_tensor,dtype=bool,device=input_tensor.device) + lambda_mask[:,:,:,0] = 1 + lambda_mask[:,:,:,current_index:current_index+num_block_q] = torch.eye(num_block_q, device=lambda_mask.device).unsqueeze(0).unsqueeze(0).expand(1,head_num,num_block_q,num_block_q) + assert(torch.where(lambda_mask,mask,True).all()) + + return mask + + +def xattn_estimate( + query_states: torch.Tensor, # (batch_size, num_q_head, q_len, head_dim) + key_states: torch.Tensor, # (batch_size, num_kv_head, k_len, head_dim) + block_size, + stride, + norm=1, + softmax=True, + threshold=0.9, + chunk_size=16384, + select_mode="inverse", + use_triton=True, + causal=True, + kdb: int = 1, + keep_sink=False, + keep_recent=False, +) -> torch.Tensor: + batch_size, num_kv_head, k_len, head_dim = key_states.shape + batch_size, num_q_head, q_len, head_dim = query_states.shape + if num_q_head > num_kv_head: + key_states = torch.repeat_interleave(key_states.contiguous(), num_q_head // num_kv_head, dim=1) + + assert q_len % chunk_size == 0 + assert k_len % chunk_size == 0 + + q_chunk_num = q_len // chunk_size + q_block_num = q_len // block_size + + # assert num_kv_head == num_q_head + attn_sum_list = [] + simple_mask_list = [] + + if use_triton and ( + "100" not in torch.cuda.get_device_properties(torch.cuda.current_device()).name + ): + use_triton = False + print( + "setting use triton to false. Triton kernel not surpported on this device" + ) + + num_strides_in_k = k_len // stride + + num_strides_per_chunk = chunk_size // stride + num_strides_per_block = block_size // stride + num_blocks_per_chunk = num_strides_per_chunk // num_strides_per_block + + for chunk_idx in range(q_chunk_num): + if kdb != 1: + raise ValueError("use_triton and kdb cannot be used together") + + q_chunk_start = chunk_idx * num_strides_per_chunk * stride + q_chunk_end = (chunk_idx + 1) * num_strides_per_chunk * stride + + q_chunk_start_stride = chunk_idx * num_strides_per_chunk + q_chunk_end_stride = (chunk_idx + 1) * num_strides_per_chunk + + # attn_weights_slice: (batch_size, num_heads, chunk_size // stride, kv_len // stride) + # (i.e. 
the attention sum of each SxS stride block) + # This step is agnostic to block size and just computes the attention sum in each stride block + attn_weights_slice = flat_group_gemm_fuse_reshape( + # query_states, key_states, stride, chunk_start, chunk_end, is_causal=True + query_states[:, :, q_chunk_start : q_chunk_end, :,], + key_states, + stride, + q_chunk_start_stride, + q_chunk_end_stride, + is_causal=causal, + ) + + # (batch_size, num_heads, q_block_num, k_block_num), + attn_sum = softmax_fuse_block_sum( + attn_weights_slice, # (batch_size, num_heads, chunk_size // stride, kv_len // stride) + num_strides_per_block, + min(4096, num_strides_per_block), + q_chunk_start_stride, q_chunk_end_stride, + num_strides_in_k, + 1 / LN2 / math.sqrt(head_dim) / stride / norm, + is_causal=causal, + ) + + + # (batch_size, head_num, num_blocks_per_chunk, block_num) + simple_mask = find_blocks_chunked( + attn_sum, + chunk_idx * num_blocks_per_chunk, + threshold, + None, + decoding=False, + mode="prefill", + causal=causal, + ) + + attn_sum_list.append(attn_sum) + simple_mask_list.append(simple_mask) + + del attn_weights_slice + + attn_sums = torch.cat(attn_sum_list, dim=-2) + + # (batch_size, head_num, num_blocks_per_chunk * q_chunk_num, block_num) + # i.e. (batch_size, head_num, q_block_num, q_block_num) + simple_masks = torch.cat(simple_mask_list, dim=-2) + + if causal: + simple_masks[:, :, -q_block_num:, -q_block_num:] = torch.where( + torch.tril( + torch.ones( + q_block_num, q_block_num, dtype=bool, device=key_states.device + ), + diagonal=0, + ), + simple_masks[:, :, -q_block_num:, -q_block_num:], + False, + ) + # print(f"{__name__} | simple_masks[:, :, -q_block_num:, -q_block_num:].shape {simple_masks[:, :, -q_block_num:, -q_block_num:].shape} after torch.where") + + + if keep_sink: + simple_masks[:, :, 0, :] = True + if keep_recent: + eye_matrix = torch.eye(q_block_num, device=simple_masks.device, dtype=bool) + eye_matrix_expanded = ( + eye_matrix.unsqueeze(0) + .unsqueeze(0) + .expand(1, num_kv_head, q_block_num, q_block_num) + ) + simple_masks[:, :, -q_block_num:, -q_block_num:] = torch.where( + eye_matrix_expanded, True, simple_masks[:, :, -q_block_num:, -q_block_num:] + ) + + # simple_masks -> (batch_size, head_num, q_block_num, q_block_num) + return attn_sums, simple_masks + +def check_device(use_triton: bool): + avail = use_triton and ( + "100" not in torch.cuda.get_device_properties(torch.cuda.current_device()).name + ) + if not avail: + print("Setting use triton to false. 
Triton kernel not surpported on this device") + return avail + + + +def xattn_zigzag_estimate( + query_states: torch.Tensor, # (batch_size, num_q_head, q_len, head_dim) + key_states: torch.Tensor, # (batch_size, num_kv_head, k_len, head_dim) + block_size, + stride, + norm=1, + softmax=True, + threshold=0.9, + select_mode="inverse", + use_triton=True, + causal=True, + kdb: int = 1, + keep_sink=False, + keep_recent=False, + group: dist.group = None, +) -> torch.Tensor: + batch_size, num_kv_head, k_len_local, head_dim = key_states.shape + batch_size, num_q_head, q_len_local, head_dim = query_states.shape + + rank = dist.get_rank(group) + world_size = dist.get_world_size(group) + + k_gather_list = [torch.empty_like(key_states) for _ in range(world_size)] + dist.all_gather(k_gather_list, key_states.contiguous(), group=group) + k_gathered = torch.cat(k_gather_list, dim=2) + k_len = k_gathered.shape[2] + + if num_q_head > num_kv_head: + k_gathered = torch.repeat_interleave(k_gathered.contiguous(), num_q_head // num_kv_head, dim=1) + + chunk_size = q_len_local // 2 + q_chunk_num = 2 + q_block_num = q_len_local // block_size + q_block_num_per_chunk = chunk_size // block_size + + # assert num_kv_head == num_q_head + attn_sum_list = [] + simple_mask_list = [] + + num_strides_in_k = k_len // stride + num_strides_per_chunk = chunk_size // stride + num_strides_per_block = block_size // stride + num_blocks_per_chunk = num_strides_per_chunk // num_strides_per_block + + attn_weight_slices = [None, None] + for chunk_idx in range(q_chunk_num): + global_chunk_idx = rank * 2 + chunk_idx + + # Local start index + q_chunk_start = chunk_idx * chunk_size + q_chunk_end = (chunk_idx + 1) * chunk_size + + # Global start index (stride-level) + q_chunk_start_stride_global = global_chunk_idx * num_strides_per_chunk + q_chunk_end_stride_global = (global_chunk_idx + 1) * num_strides_per_chunk + + # attn_weights_slice: (batch_size, num_heads, chunk_size // stride, kv_len // stride) + # (i.e. 
the attention sum of each SxS stride block) + # This step is agnostic to block size and just computes the attention sum in each stride block + attn_weight_slice = flat_group_gemm_fuse_reshape( + # query_states, key_states, stride, chunk_start, chunk_end, is_causal=True + query_states[:, :, q_chunk_start : q_chunk_end, :,], + k_gathered, + stride, + q_chunk_start_stride_global, q_chunk_end_stride_global, + is_causal=causal, + ) + attn_weight_slices[chunk_idx] = attn_weight_slice + del k_gathered, k_gather_list + + for chunk_idx in range(q_chunk_num): + global_chunk_idx = rank * 2 + chunk_idx + + # Local start index + q_chunk_start = chunk_idx * chunk_size + q_chunk_end = (chunk_idx + 1) * chunk_size + + # Global start index (block-level) + q_block_start = global_chunk_idx * q_block_num_per_chunk + q_block_end = (global_chunk_idx + 1) * q_block_num_per_chunk + + # Global start index (stride-level) + q_chunk_start_stride_global = global_chunk_idx * num_strides_per_chunk + q_chunk_end_stride_global = (global_chunk_idx + 1) * num_strides_per_chunk + + attn_weight_slice = attn_weight_slices[chunk_idx] + + # (batch_size, num_heads, q_block_num, k_block_num), + attn_sum = softmax_fuse_block_sum( + attn_weight_slice, # (batch_size, num_heads, chunk_size // stride, kv_len // stride) + num_strides_per_block, + min(4096, num_strides_per_block), + q_chunk_start_stride_global, q_chunk_end_stride_global, + num_strides_in_k, + 1 / LN2 / math.sqrt(head_dim) / stride / norm, + is_causal=causal, + ) + + # (batch_size, head_num, num_blocks_per_chunk, block_num) + simple_mask = find_blocks_chunked( + attn_sum, + global_chunk_idx * num_blocks_per_chunk, + threshold, + None, + decoding=False, + mode="prefill", + causal=causal, + ) + + del attn_weight_slice + if causal: + simple_mask[:, :, :, q_block_start:q_block_end] = torch.where( + torch.tril( + torch.ones( + q_block_num_per_chunk, q_block_num_per_chunk, + dtype=bool, device=key_states.device + ), + diagonal=0, + ), + simple_mask[:, :, :, q_block_start:q_block_end], + False, + ) + simple_mask[:, :, :, q_block_end:] = 0 + if keep_sink: + simple_mask[:, :, 0, :] = True + if keep_recent: + eye_matrix = torch.eye(q_block_num_per_chunk, device=simple_mask.device, dtype=bool) + eye_matrix_expanded = ( + eye_matrix.unsqueeze(0) + .unsqueeze(0) + .expand(1, num_kv_head, q_block_num_per_chunk, q_block_num_per_chunk) + ) + simple_mask[:, :, :, q_block_start:q_block_end] = torch.where( + eye_matrix_expanded, True, simple_mask[:, :, :, q_block_start:q_block_end] + ) + + attn_sum_list.append(attn_sum) + simple_mask_list.append(simple_mask) + + attn_sums = torch.cat(attn_sum_list, dim=-2) + simple_masks = torch.cat(simple_mask_list, dim=-2) # (batch_size, head_num, q_local_block_num, k_global_block_num) + return attn_sums, simple_masks + + +def shuffle_zigzag_masks( + block_masks: torch.Tensor, # [batch_size, num_qo_heads, num_blocks_local, num_blocks] + process_group: dist.ProcessGroup = None + ): + dim = len(block_masks.shape) - 1 + if not block_masks.is_contiguous(): + block_masks = block_masks.contiguous() + + # We must use outplace, otherwise it will raise error at backward due to inplace operations. + # We can not change to_send directly and create a new tensor to store the result. 
+ to_send_f = torch.zeros_like(block_masks) + + # assume the input sequence length is 8, and computation runs on 4 GPUs + # the seq is represented as [0 1 2 3 4 5 6 7], world size is 4 + # the input status before `shuffle_zigzag_input` is + # - gpu A: [0 1] + # - gpu B: [2 3] + # - gpu C: [4 5] + # - gpu D: [6 7] + # the value of `to_send_slice` is + # - gpu A: [1] + # - gpu B: [3] + # - gpu C: [5] + # - gpu D: [7] + block_seq_len = block_masks.shape[dim] // 2 + left_slicer = [slice(None)] * dim + [slice(None, block_seq_len)] + right_slicer = [slice(None)] * dim + [slice(block_seq_len, None)] + to_send_slice = block_masks[right_slicer].contiguous() + + rank = dist.get_rank(process_group) + world_size = dist.get_world_size(process_group) + + res = torch.zeros_like(to_send_slice) + + _ops = [] + offset = ((dist.get_rank() // world_size) * world_size) + # rank src_rank + # 0 3 + # 1 2 + # 2 1 + # 3 0 + src_rank = (world_size - rank - 1) % world_size + offset + send_op = dist.P2POp( + dist.isend, to_send_slice, src_rank, group=process_group + ) + recv_op = dist.P2POp( + dist.irecv, res, src_rank, group=process_group) + + _ops.append(send_op) + _ops.append(recv_op) + + response = dist.batch_isend_irecv(_ops) + for resp in response: + resp.wait() + + if rank >= world_size // 2: # D: 6 7, -> 1 6 + to_send_f[right_slicer] = block_masks[left_slicer] + to_send_f[left_slicer] = res + else: # A: 0 1, -> 0 7 + to_send_f[left_slicer] = block_masks[left_slicer] + to_send_f[right_slicer] = res + # after shuffle, the status of `to_send_f` + # GPU A: [0 7] + # GPU B: [2 5] + # GPU C: [3 4] + # GPU D: [1 6] + + return to_send_f diff --git a/minference/dist_ops/ring_attention.py b/minference/dist_ops/ring_attention.py new file mode 100644 index 0000000..7da68f9 --- /dev/null +++ b/minference/dist_ops/ring_attention.py @@ -0,0 +1,267 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
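# --- Illustration (not part of the patch): block bookkeeping in ring_flash_attn_forward ---
# The forward below gives every rank two zigzag blocks (matching the [0 7][2 5][3 4][1 6]
# layout sketched earlier) and, under causal masking, lets each block attend only to a
# prefix of the all-gathered K/V. A standalone sketch of that index arithmetic, assuming
# world_size = 4:
def causal_extents(world_size):
    extents = []
    for rank in range(world_size):
        keep_idx = 2 * rank                        # global index of the rank's first block
        dual_rank = world_size - rank - 1
        dual_send_idx = 2 * dual_rank + 1          # global index of the rank's second block
        up_rank = min(keep_idx, dual_send_idx)     # earlier block -> shorter causal K prefix
        down_rank = max(keep_idx, dual_send_idx)   # later block -> longer causal K prefix
        extents.append((up_rank, down_rank, up_rank + 1, down_rank + 1))
    return extents

# For 4 ranks: rank r holds blocks (up, down) and its two halves attend to the first
# (up + 1) and (down + 1) blocks of the gathered K/V respectively, e.g. rank 2 holds
# blocks (3, 4) and attends to 4 and 5 blocks of K/V.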
+ +# TODO: replace with zhuzilin's implementation + +import torch +import torch.distributed as dist +from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward + +from .utils import shuffle_zigzag_input, recover_zigzag_output, GlobalMemoryBuffer + + +_GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer() +def ring_flash_attn_forward( + process_group, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + block_len = q.size(1) // 2 + curr_rank = dist.get_rank(process_group) + world_size = dist.get_world_size(process_group) + keep_idx = 2 * curr_rank + dual_rank = world_size - curr_rank - 1 + dual_send_idx = 2 * dual_rank + 1 + up_rank = min(keep_idx, dual_send_idx) + down_rank = max(keep_idx, dual_send_idx) + + up_q = q[:, :block_len] + if causal: + up_k = k[:, :(up_rank + 1) * block_len] + up_v = v[:, :(up_rank + 1) * block_len] + else: + up_k, up_v = k, v + up_out, _, _, _, _, up_lse, _, _ = _flash_attn_forward( + up_q, + up_k, + up_v, + dropout_p, + softmax_scale, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + return_softmax=True and dropout_p > 0, + ) + + down_q = q[:, block_len:] + if causal: + down_k = k[:, :(down_rank + 1) * block_len] + down_v = v[:, :(down_rank + 1) * block_len] + else: + down_k, down_v = k, v + down_out, _, _, _, _, down_lse, _, _ = _flash_attn_forward( + down_q, + down_k, + down_v, + dropout_p, + softmax_scale, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + return_softmax=True and dropout_p > 0, + ) + + out = torch.cat([up_out, down_out], dim=1) + return out, up_lse, down_lse + + +def ring_flash_attn_backward( + process_group, + dout, + q, + k, + v, + out, + up_lse, + down_lse, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + block_len = q.size(1) // 2 + curr_rank = dist.get_rank(process_group) + world_size = dist.get_world_size(process_group) + keep_idx = 2 * curr_rank + dual_rank = world_size - curr_rank - 1 + dual_send_idx = 2 * dual_rank + 1 + up_rank = min(keep_idx, dual_send_idx) + down_rank = max(keep_idx, dual_send_idx) + + dq = torch.zeros_like(q) + dk_buffer = _GLOBAL_MEMORY_BUFFER.get_tensor(k.size(), k.dtype, "bwd_dk") + dk_buffer.zero_() + dv_buffer = _GLOBAL_MEMORY_BUFFER.get_tensor(v.size(), v.dtype, "bwd_dv") + dv_buffer.zero_() + + up_q = q[:, :block_len] + up_out = out[:, :block_len] + up_dout = dout[:, :block_len] + if causal: + up_k = k[:, :(up_rank + 1) * block_len] + up_v = v[:, :(up_rank + 1) * block_len] + else: + up_k, up_v = k, v + _flash_attn_backward( + up_dout, + up_q, + up_k, + up_v, + up_out, + up_lse, + dq[:, :block_len], + dk_buffer[:, :(up_rank + 1) * block_len], + dv_buffer[:, :(up_rank + 1) * block_len], + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + rng_state=None, + ) + + down_q = q[:, block_len:] + down_out = out[:, block_len:] + down_dout = dout[:, block_len:] + # TODO: optimize the buffer allocation + down_dk_buffer = _GLOBAL_MEMORY_BUFFER.get_tensor(k.size(), k.dtype, "bwd_down_dk") + down_dk_buffer.zero_() + down_dv_buffer = _GLOBAL_MEMORY_BUFFER.get_tensor(v.size(), v.dtype, "bwd_down_dv") + down_dv_buffer.zero_() + if causal: + down_k = k[:, :(down_rank + 1) * block_len] + down_v = v[:, :(down_rank + 1) * block_len] + else: + down_k, down_v = k, v + _flash_attn_backward( + down_dout, + down_q, + down_k, + down_v, 
+ down_out, + down_lse, + dq[:, block_len:], + down_dk_buffer[:, :(down_rank + 1) * block_len], + down_dv_buffer[:, :(down_rank + 1) * block_len], + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + rng_state=None, + ) + dk_buffer.add_(down_dk_buffer) + dv_buffer.add_(down_dv_buffer) + + dim_size = list(k.size()) + dim_size[1] = dim_size[1] // world_size + dk = torch.empty(dim_size, dtype=k.dtype, device=k.device) + dv = torch.empty(dim_size, dtype=v.dtype, device=v.device) + dist._reduce_scatter_base(dk, dk_buffer, group=process_group) + dist._reduce_scatter_base(dv, dv_buffer, group=process_group) + + return dq, dk, dv + + +''' +In nnscaler, sequence are stored in the initial order, e.g., [0 1 2 3 4 5 6 7]. +However, ring flash attention requires the sequence to be in the order of [0 7 2 5 3 4 1 6]. +As a result: +- in forward, we need to shuffle q, all gather k, v and recover the out +- in backward, we need to shuffle dout and recover the dq, reduce scatter dk, dv +''' +class RingFlashAttnFunc(torch.autograd.Function): + + @staticmethod + def forward( + ctx, + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_softmax, + group, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + assert alibi_slopes is None + + q = shuffle_zigzag_input(to_send=q, process_group=group) + world_size = dist.get_world_size(group) + dim_size = list(k.size()) + dim_size[1] = dim_size[1] * world_size + k_buffer = _GLOBAL_MEMORY_BUFFER.get_tensor(dim_size, k.dtype, "fwd_k") + v_buffer = _GLOBAL_MEMORY_BUFFER.get_tensor(dim_size, v.dtype, "fwd_v") + torch.distributed._all_gather_base(k_buffer, k, group=group) + torch.distributed._all_gather_base(v_buffer, v, group=group) + + out, up_lse, down_lse = ring_flash_attn_forward( + group, + q, + k_buffer, + v_buffer, + softmax_scale=softmax_scale, + dropout_p=dropout_p, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=False, + ) + # this should be out_padded + ctx.save_for_backward(q, k, v, out, up_lse, down_lse) + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + ctx.group = group + out = recover_zigzag_output(out, process_group=group) + return out if not return_softmax else (out, softmax_lse, None) + + @staticmethod + def backward(ctx, dout, *args): + dout = shuffle_zigzag_input(to_send=dout, process_group=ctx.group) + q, k, v, out, up_lse, down_lse = ctx.saved_tensors + world_size = dist.get_world_size(ctx.group) + dim_size = list(k.size()) + dim_size[1] = dim_size[1] * world_size + k_buffer = _GLOBAL_MEMORY_BUFFER.get_tensor(dim_size, k.dtype, "fwd_k") + v_buffer = _GLOBAL_MEMORY_BUFFER.get_tensor(dim_size, v.dtype, "fwd_v") + torch.distributed._all_gather_base(k_buffer, k, group=ctx.group) + torch.distributed._all_gather_base(v_buffer, v, group=ctx.group) + + dq, dk, dv = ring_flash_attn_backward( + ctx.group, + dout, + q, + k_buffer, + v_buffer, + out, + up_lse, + down_lse, + softmax_scale=ctx.softmax_scale, + dropout_p=ctx.dropout_p, + causal=ctx.causal, + window_size=ctx.window_size, + alibi_slopes=ctx.alibi_slopes, + deterministic=ctx.deterministic, + ) + dq = recover_zigzag_output(dq, ctx.group) + return dq, dk, dv, None, None, None, None, None, None, None, None diff --git a/minference/dist_ops/striped_attention.py b/minference/dist_ops/striped_attention.py new file 
mode 100644 index 0000000..bc0915f --- /dev/null +++ b/minference/dist_ops/striped_attention.py @@ -0,0 +1,404 @@ +import torch +import torch.distributed as dist +from typing import List, Tuple, Dict +from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward + +from .utils import ( + RingComm, + update_out_and_lse, get_default_args, + shuffle_striped_input, recover_striped_output +) + +def stripe_flash_attn_forward( + process_group, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer_idx: int, + softmax_scale, + granularity=1, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + assert causal, "stripe flash attn only supports causal attention, if not causal, use ring flash attn instead" + comm = RingComm(process_group) + bsz, seq_len, num_heads, head_dim = q.shape + + out, lse = None, None + next_k, next_v = None, None + + def forward(q_, k_, v_, causal_): + params = get_default_args(_flash_attn_forward).copy() + params.update( + { + "q": q_, + "k": k_, + "v": v_, + "dropout_p": dropout_p, + "softmax_scale": softmax_scale, + "causal": causal_, + "alibi_slopes": alibi_slopes, + "return_softmax": True and dropout_p > 0, + } + ) + if "window_size" in params: + params.update({"window_size": window_size}) + else: + params.update( + { + "window_size_left": window_size[0], + "window_size_right": window_size[1], + } + ) + outputs = _flash_attn_forward(**params) + if len(outputs) == 8: + block_out, _, _, _, _, block_lse, _, _ = outputs + else: + assert len(outputs) == 4 + block_out, block_lse, _, _ = outputs + return block_out, block_lse + + + for step in range(comm.world_size): + if step + 1 != comm.world_size: + next_k, next_v = comm.send_recv_kv(k, v) + + shift = 1 if step > comm.rank else 0 + if shift == 0: + step_out, step_lse = forward(q, k, v, causal) + out, lse = update_out_and_lse(out, lse, step_out, step_lse) + else: + # Before the step index goes beyond the current rank, the received KV indices are not greater than those of the Q in the current rank + # After the step index goes beyond the current rank, only the KV indices before the last granularity are no greater than those of the Q after the first granularity + # this conclusion holds after the step index goes beyond the current rank (not just step index == current rank) + step_out, step_lse = forward( + q[:, granularity:], k[:, :-granularity], v[:, :-granularity], causal + ) + out, lse = update_out_and_lse( + out, lse, step_out, step_lse, slice_=(slice(None), slice(granularity, None)) + ) + + if step + 1 != comm.world_size: + comm.wait() + k, v = next_k, next_v + + out = out.to(q.dtype) + lse = lse.squeeze(dim=-1).transpose(1, 2) + return out, lse + + +def stripe_flash_attn_backward( + process_group, + dout, + q, + k, + v, + out, + softmax_lse, + layer_idx: int, + softmax_scale, + granularity=1, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + assert ( + causal + ), "stripe flash attn only supports causal attention, if not causal, ring flash attn instead" + bsz, seq_len, num_heads, head_dim = q.shape + + kv_comm = RingComm(process_group) + d_kv_comm = RingComm(process_group) + dq, dk, dv = None, None, None + next_dk, next_dv = None, None + next_k, next_v = None, None + dk_comm_buffer, dv_comm_buffer = None, None + + block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) + block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) + block_dv_buffer = 
torch.empty(v.shape, dtype=v.dtype, device=v.device) + + + def backward( + granularity_, + ): + if granularity_ == 0: + k_, v_ = k, v + dk_, dv_ = block_dk_buffer, block_dv_buffer + else: + k_, v_ = k[:, :-granularity_], v[:, :-granularity_] + dk_, dv_ = block_dk_buffer[:, :-granularity_], block_dv_buffer[:, :-granularity_] + params = get_default_args(_flash_attn_backward).copy() + params.update( + { + "dout": dout[:, granularity_:], + "q": q[:, granularity_:], + "k": k_, + "v": v_, + "out": out[:, granularity_:], + "softmax_lse": softmax_lse[:, :, granularity_:].contiguous(), + "dq": block_dq_buffer[:, granularity_:], + "dk": dk_, + "dv": dv_, + "dropout_p": dropout_p, + "softmax_scale": softmax_scale, + "causal": causal, + "alibi_slopes": alibi_slopes, + "deterministic": deterministic, + } + ) + if "window_size" in params: + params.update({"window_size": window_size}) + else: + params.update( + { + "window_size_left": window_size[0], + "window_size_right": window_size[1], + } + ) + params.update({"rng_state": torch.zeros((2, ), dtype=torch.int64, device=q.device)}) + _flash_attn_backward(**params) + + for step in range(kv_comm.world_size): + if step + 1 != kv_comm.world_size: + next_k, next_v = kv_comm.send_recv_kv(k, v) + + shift_causal = 1 if step > kv_comm.rank else 0 + if shift_causal == 0: + backward(granularity_=0) + else: + backward(granularity_=granularity) + + if dq is None: + dq = block_dq_buffer.to(torch.float32) + dk = block_dk_buffer.to(torch.float32) + dv = block_dv_buffer.to(torch.float32) + else: + d_kv_comm.wait() + dk_comm_buffer, dv_comm_buffer = dk, dv + dk, dv = next_dk, next_dv + + if shift_causal == 0: + dq += block_dq_buffer + dk = block_dk_buffer + dk + dv = block_dv_buffer + dv + else: + dq[:, granularity:] += block_dq_buffer[:, granularity:] + dk[:, :-granularity] += block_dk_buffer[:, :-granularity] + dv[:, :-granularity] += block_dv_buffer[:, :-granularity] + + if step + 1 != kv_comm.world_size: + kv_comm.wait() + k, v = next_k, next_v + + next_dk, next_dv = d_kv_comm.send_recv_kv( + dk, dv, dk_comm_buffer, dv_comm_buffer + ) + + d_kv_comm.wait() + return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) + + +class StripeFlashAttnFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, k, v, + layer_idx, + dropout_p, + softmax_scale, + granularity, + causal, + window_size, + alibi_slopes, + deterministic, + return_softmax, + group, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + assert alibi_slopes is None + + # ----------------------------------------- + # Shuffle + q = shuffle_striped_input(to_send=q, dim=1, granularity=granularity, process_group=group) + k = shuffle_striped_input(to_send=k, dim=1, granularity=granularity, process_group=group) + v = shuffle_striped_input(to_send=v, dim=1, granularity=granularity, process_group=group) + k, v = k.contiguous(), v.contiguous() + + # ---------------------------------------------- + # Compute + out, softmax_lse = stripe_flash_attn_forward( + group, + q, k, v, + layer_idx, + softmax_scale=softmax_scale, + granularity=granularity, + dropout_p=dropout_p, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=False, + ) + + # ---------------------------------------------- + # Recover outputs + recovered_out = recover_striped_output(out, dim=1, granularity=granularity, process_group=group) + if return_softmax: + recovered_softmax_lse = recover_striped_output(softmax_lse, dim=2, granularity=granularity, process_group=group) + + # 
---------------------------------------------- + # this should be out_padded + ctx.save_for_backward(q, k, v, out, softmax_lse) + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.granularity = granularity + ctx.causal = causal + ctx.window_size = window_size + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + ctx.layer_idx = layer_idx + ctx.return_softmax = return_softmax + ctx.group = group + + # ---------------------------------------------- + # Output and return + if return_softmax: + return (recovered_out, recovered_softmax_lse, None) + return recovered_out + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse = ctx.saved_tensors + layer_idx = ctx.layer_idx + dropout_p, softmax_scale, causal, window_size, alibi_slopes, deterministic, return_softmax, group = ( + ctx.dropout_p, ctx.softmax_scale, ctx.causal, ctx.window_size, + ctx.alibi_slopes, ctx.deterministic, ctx.return_softmax, + ctx.group + ) + + + # ---------------------------------------------- + # Shuffle + dout = shuffle_striped_input( + to_send=dout, dim=1, granularity=ctx.granularity, + process_group=ctx.group + ) + + # ---------------------------------------------- + # Compute + dq, dk, dv = stripe_flash_attn_backward( + ctx.group, + dout, + q, k, v, out, softmax_lse, + layer_idx=layer_idx, + softmax_scale=ctx.softmax_scale, + granularity=ctx.granularity, + dropout_p=ctx.dropout_p, + causal=ctx.causal, + window_size=ctx.window_size, + alibi_slopes=ctx.alibi_slopes, + deterministic=ctx.deterministic, + ) + + + # ---------------------------------------------- + # Recover + dq = recover_striped_output(dq, dim=1, granularity=ctx.granularity, process_group=ctx.group) + dk = recover_striped_output(dk, dim=1, granularity=ctx.granularity, process_group=ctx.group) + dv = recover_striped_output(dv, dim=1, granularity=ctx.granularity, process_group=ctx.group) + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None + + +def stripe_flash_attn_qkvpacked_func( + qkv, # [B, N, 3, H, D] + dropout_p=0.0, + softmax_scale=None, + granularity=1, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return StripeFlashAttnFunc.apply( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + dropout_p, + softmax_scale, + granularity, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def stripe_flash_attn_kvpacked_func( + q, + kv, + dropout_p=0.0, + softmax_scale=None, + granularity=1, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return StripeFlashAttnFunc.apply( + q, + kv[:, :, 0], + kv[:, :, 1], + dropout_p, + softmax_scale, + granularity, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def stripe_flash_attn_func( + q, + k, + v, + layer_idx: int, + dropout_p=0.0, + softmax_scale=None, + granularity=1, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return StripeFlashAttnFunc.apply( + q, + k, + v, + layer_idx, + dropout_p, + softmax_scale, + granularity, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) diff --git a/minference/dist_ops/utils.py b/minference/dist_ops/utils.py new file 
mode 100644 index 0000000..79282da --- /dev/null +++ b/minference/dist_ops/utils.py @@ -0,0 +1,523 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# Credits: This logger implementation is inspired by project https://github.com/zhuzilin/ring-flash-attention +import os +import math +import torch +import inspect +import operator +import torch.nn.functional as F +import torch.distributed as dist + +import triton +import triton.language as tl + +from functools import reduce, cache +from typing import Optional, Tuple, List, Dict +from torch.distributed.distributed_c10d import P2POp + +PROCESS_GROUPS: Dict[str, dist.ProcessGroup] = {} + +@cache +def _get_default_args(func): + spec = inspect.getfullargspec(func) + defaults = spec.defaults if spec.defaults is not None else () + padded_defaults = (None,) * (len(spec.args) - len(defaults)) + defaults + args = dict(zip(spec.args, padded_defaults)) + if "softcap" in args: + args["softcap"] = 0.0 + return args + + +def get_default_args(func): + if inspect.isfunction(func): + return _get_default_args(func) + else: + # Use the origin _init_fn in CustomOpDef + return _get_default_args(func._init_fn) + + +# copy from megatron/core/utils.py +class GlobalMemoryBuffer: + """Global buffer to avoid dynamic memory allocations. + Caller should ensure that buffers of the same name + are not used concurrently.""" + + def __init__(self): + self.buffer = {} + + def get_tensor(self, tensor_shape, dtype, name): + required_len = reduce(operator.mul, tensor_shape, 1) + if ( + self.buffer.get((name, dtype), None) is None + or self.buffer[(name, dtype)].numel() < required_len + ): + self.buffer[(name, dtype)] = torch.empty( + required_len, dtype=dtype, device=torch.cuda.current_device(), requires_grad=False + ) + + return self.buffer[(name, dtype)][0:required_len].view(*tensor_shape) + +@triton.jit +def _update_out_and_lse_kernel( + Out0, Lse0, Out1, Lse1, + stride_oz0, stride_om0, stride_oh0, stride_od0, + stride_lz0, stride_lm0, stride_lh0, + stride_oz1, stride_om1, stride_oh1, stride_od1, + stride_lz1, stride_lm1, stride_lh1, + num_tokens, + BLOCK_M: tl.constexpr, BLOCK_D: tl.constexpr, +): + start_m = tl.program_id(0) + head_idx = tl.program_id(1) + batch_idx = tl.program_id(2) + + if start_m * BLOCK_M >= num_tokens: + return + + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, BLOCK_D) + m_mask = offs_m < num_tokens + + o0_ptrs = Out0 + batch_idx * stride_oz0 + head_idx * stride_oh0 + offs_m[:, None] * stride_om0 + offs_d[None, :] * stride_od0 + o1_ptrs = Out1 + batch_idx * stride_oz1 + head_idx * stride_oh1 + offs_m[:, None] * stride_om1 + offs_d[None, :] * stride_od1 + lse0_ptrs = Lse0 + batch_idx * stride_lz0 + head_idx * stride_lh0 + offs_m * stride_lm0 + lse1_ptrs = Lse1 + batch_idx * stride_lz1 + head_idx * stride_lh1 + offs_m * stride_lm1 + + lse0 = tl.load(lse0_ptrs, mask=m_mask, other=float("-inf")) + lse1 = tl.load(lse1_ptrs, mask=m_mask, other=float("-inf")) + o0 = tl.load(o0_ptrs, mask=m_mask[:, None], other=0.).to(tl.float32) + o1 = tl.load(o1_ptrs, mask=m_mask[:, None], other=0.).to(tl.float32) + + m_mask &= (lse0 - lse1) < 88.0 + + theta = tl.math.exp(lse0 - lse1) + alpha0 = 1 / (1 + 1 / theta) + alpha1 = 1 / (1 + theta) + o = alpha0[:, None] * o0 + alpha1[:, None] * o1 + lse = lse1 - tl.math.log(alpha1) + + tl.store(o0_ptrs, o.to(Out0.type.element_ty), mask=m_mask[:, None]) + tl.store(lse0_ptrs, lse, mask=m_mask) + + +def _update_out_and_lse_triton( + out: torch.Tensor, # [batch_size, 
num_tokens, num_heads, head_dim] + lse: torch.Tensor, # [batch_size, num_tokens, num_heads, 1] + block_out: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] + block_lse: torch.Tensor, # [batch_size, num_heads, num_tokens] => [batch_size, num_tokens, num_heads, 1] + step_idx: Optional[int] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) + batch_size, num_tokens, num_heads, head_dim = out.shape + block_M = 128 + block_D = head_dim + _update_out_and_lse_kernel[(triton.cdiv(num_tokens, block_M), num_heads, batch_size)]( + out, lse, block_out, block_lse, + out.stride(0), out.stride(1), out.stride(2), out.stride(3), + lse.stride(0), lse.stride(1), lse.stride(2), + block_out.stride(0), block_out.stride(1), block_out.stride(2), block_out.stride(3), + block_lse.stride(0), block_lse.stride(1), block_lse.stride(2), + num_tokens, BLOCK_M=block_M, BLOCK_D=block_D, + num_warps=4, num_stages=1, + ) + return out, lse + + +@torch.jit.script +def _update_out_and_lse_torch( + out: torch.Tensor, + lse: torch.Tensor, + block_out: torch.Tensor, + block_lse: torch.Tensor, + step_idx: Optional[int] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + + block_out = block_out.to(torch.float32) + block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) + + # new_lse = lse + torch.log(1 + torch.exp(block_lse - lse)) + # torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out + # For additional context and discussion, please refer to: + # https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795 + out = out - F.sigmoid(block_lse - lse) * (out - block_out) + lse = lse - F.logsigmoid(lse - block_lse) + + return out, lse + + +def update_out_and_lse( + out: Optional[torch.Tensor], + lse: Optional[torch.Tensor], + block_out: torch.Tensor, + block_lse: torch.Tensor, + slice_=None, + step_idx: Optional[int] = None, + use_triton_kernel: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + if use_triton_kernel: + _update_out_and_lse = _update_out_and_lse_triton + else: + _update_out_and_lse = _update_out_and_lse_torch + if out is None: + if slice_ is not None: + raise RuntimeError("first update_out_and_lse should not pass slice_ args") + out = block_out.to(torch.float32) + lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) + elif slice_ is not None: + slice_out, slice_lse = out[slice_], lse[slice_] + slice_out, slice_lse = _update_out_and_lse( + slice_out, slice_lse, block_out, block_lse, + step_idx=step_idx + ) + out[slice_], lse[slice_] = slice_out, slice_lse + else: + out, lse = _update_out_and_lse( + out, lse, block_out, block_lse, + step_idx=step_idx + ) + + return out, lse + + +class RingComm: + def __init__( + self, + process_group: dist.ProcessGroup, + zigzag: bool = False, + ring_list: Optional[list] = None, + ): + self._ops: List[P2POp] = [] + self.rank = dist.get_rank(process_group) + self.world_size = dist.get_world_size(process_group) + self._reqs = None + self.process_group = process_group + + if ring_list is not None: + curr_idx = ring_list.index(self.rank) + self.send_rank = ring_list[(curr_idx + 1) % len(ring_list)] + self.recv_rank = ring_list[(curr_idx - 1 + len(ring_list)) % len(ring_list)] + self.send_first = curr_idx % 2 == 0 + elif zigzag: + parts = self.world_size // 2 + self.ring_list = [] + for i in range(parts): + self.ring_list.extend([i, self.world_size - i - 1]) + self.revert_rank = self.ring_list.index(self.rank) + offset = ((dist.get_rank() // self.world_size) * 
self.world_size) + self.send_rank = self.ring_list[(self.revert_rank + 1) % self.world_size] + offset + self.recv_rank = self.ring_list[(self.revert_rank - 1) % self.world_size] + offset + self.send_first = self.revert_rank % 2 == 0 + else: + self.send_rank = (self.rank + 1) % self.world_size + self.recv_rank = (self.rank - 1) % self.world_size + self.send_first = self.rank % 2 == 0 + + if len(PROCESS_GROUPS) == 0: + self.init_process_groups() + + + if self.send_rank in get_inner_ring(process_group): + outer_rank = get_outer_ring(process_group).index(self.rank) + self._send_group = PROCESS_GROUPS[f'inner-{outer_rank}-{int(self.send_first)}'] + else: + self._send_group = PROCESS_GROUPS[f'outer-{int(self.send_first)}'] + + if self.recv_rank in get_inner_ring(process_group): + outer_rank = get_outer_ring(process_group).index(self.rank) + self._recv_group = PROCESS_GROUPS[f'inner-{outer_rank}-{int(1 - self.send_first)}'] + else: + self._recv_group = PROCESS_GROUPS[f'outer-{int(1 - self.send_first)}'] + + self._send_group = PROCESS_GROUPS[f'inner-0-0'] + self._recv_group = PROCESS_GROUPS[f'inner-0-0'] + + def init_process_groups(self): + global PROCESS_GROUPS + num_nodes = int(os.environ.get("NUM_NODES", 1)) + fast_nccl_options = dist.ProcessGroupNCCL.Options(is_high_priority_stream=True) + # fast_nccl_options.config.max_ctas = 2147483647 + # fast_nccl_options.config.min_ctas = 128 + for node_idx in range(num_nodes): + PROCESS_GROUPS[f'inner-{node_idx}-0'] = dist.new_group(pg_options=fast_nccl_options, use_local_synchronization=True) + PROCESS_GROUPS[f'inner-{node_idx}-1'] = dist.new_group(pg_options=fast_nccl_options, use_local_synchronization=True) + slow_nccl_options = dist.ProcessGroupNCCL.Options(is_high_priority_stream=True) + slow_nccl_options.config.max_ctas = 1 + slow_nccl_options.config.min_ctas = 1 + PROCESS_GROUPS['outer-0'] = dist.new_group(pg_options=slow_nccl_options, use_local_synchronization=True) + PROCESS_GROUPS['outer-1'] = dist.new_group(pg_options=slow_nccl_options, use_local_synchronization=True) + + def send_recv( + self, + to_send: torch.Tensor, + recv_tensor: Optional[torch.Tensor] = None, + step_idx: int = 0, + fwd: int = 1, + ) -> torch.Tensor: + if recv_tensor is None: + res = torch.empty_like(to_send) + else: + res = recv_tensor + + if self.send_first: + self._reqs.append(dist.isend(to_send, self.send_rank, group=self.process_group)) + self._reqs.append(dist.irecv(res, self.recv_rank, group=self.process_group)) + else: + self._reqs.append(dist.irecv(res, self.recv_rank, group=self.process_group)) + self._reqs.append(dist.isend(to_send, self.send_rank, group=self.process_group)) + + return res + + def wait(self): + if self._reqs is None: + raise RuntimeError("wait called before commit") + for req in self._reqs: + req.wait() + self._reqs = None + self._ops = [] + + def send_recv_kv( + self, + k: torch.Tensor, + v: torch.Tensor, + k_buffer: Optional[torch.Tensor] = None, + v_buffer: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + self._reqs = [] + next_k, next_v = self.send_recv(k, k_buffer), self.send_recv(v, v_buffer) + return next_k, next_v + + def send_recv_kv_offsets( + self, + k: torch.Tensor, + v: torch.Tensor, + kv_seq_offsets: torch.Tensor, + k_buffer: Optional[torch.Tensor] = None, + v_buffer: Optional[torch.Tensor] = None, + kv_seq_offsets_buffer: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + self._reqs = [] + next_k, next_v = self.send_recv(k, k_buffer), self.send_recv(v, v_buffer) + 
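# Note: the per-rank sequence offsets are exchanged alongside K/V in the same round of isend/irecv calls, so the receiver also learns which global token positions the incoming key/value blocks correspond to. +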
next_kv_seq_offsets = self.send_recv(kv_seq_offsets, kv_seq_offsets_buffer) + return next_k, next_v, next_kv_seq_offsets + + +def shuffle_zigzag_input(to_send: torch.Tensor, + dim: int = 1, + process_group: dist.ProcessGroup = None): + dim %= len(to_send.shape) + + if not to_send.is_contiguous(): + to_send = to_send.contiguous() + + # We must use outplace, otherwise it will raise error at backward due to inplace operations. + # We can not change to_send directly and create a new tensor to store the result. + to_send_f = torch.zeros_like(to_send) + + # assume the input sequence length is 8, and computation runs on 4 GPUs + # the seq is represented as [0 1 2 3 4 5 6 7], world size is 4 + # the input status before `shuffle_zigzag_input` is + # - gpu A: [0 1] + # - gpu B: [2 3] + # - gpu C: [4 5] + # - gpu D: [6 7] + # the value of `to_send_slice` is + # - gpu A: [1] + # - gpu B: [3] + # - gpu C: [5] + # - gpu D: [7] + block_seq_len = to_send.shape[dim] // 2 + left_slicer = [slice(None)] * dim + [slice(None, block_seq_len)] + right_slicer = [slice(None)] * dim + [slice(block_seq_len, None)] + to_send_slice = to_send[right_slicer].contiguous() + + rank = dist.get_rank(process_group) + world_size = dist.get_world_size(process_group) + + res = torch.zeros_like(to_send_slice) + + _ops = [] + offset = ((dist.get_rank() // world_size) * world_size) + # rank src_rank + # 0 3 + # 1 2 + # 2 1 + # 3 0 + src_rank = (world_size - rank - 1) % world_size + offset + send_op = dist.P2POp( + dist.isend, to_send_slice, src_rank, group=process_group + ) + recv_op = dist.P2POp( + dist.irecv, res, src_rank, group=process_group) + + _ops.append(send_op) + _ops.append(recv_op) + + response = dist.batch_isend_irecv(_ops) + for resp in response: + resp.wait() + + if rank >= world_size // 2: # D: 6 7, -> 1 6 + to_send_f[right_slicer] = to_send[left_slicer] + to_send_f[left_slicer] = res + else: # A: 0 1, -> 0 7 + to_send_f[left_slicer] = to_send[left_slicer] + to_send_f[right_slicer] = res + # after shuffle, the status of `to_send_f` + # GPU A: [0 7] + # GPU B: [2 5] + # GPU C: [3 4] + # GPU D: [1 6] + + return to_send_f + + +def recover_zigzag_output(to_send: torch.Tensor, + dim: int = 1, + process_group: dist.ProcessGroup = None): + dim %= len(to_send.shape) + + if not to_send.is_contiguous(): + to_send = to_send.contiguous() + + to_send_f = torch.zeros_like(to_send) + + block_seq_len = to_send.shape[dim] // 2 + left_slicer = [slice(None)] * dim + [slice(None, block_seq_len)] + right_slicer = [slice(None)] * dim + [slice(block_seq_len, None)] + + rank = dist.get_rank(process_group) + world_size = dist.get_world_size(process_group) + + if rank >= world_size // 2: + to_send_slice = to_send[left_slicer].contiguous() + else: + to_send_slice = to_send[right_slicer].contiguous() + res = torch.zeros_like(to_send_slice) + + assert to_send_slice.is_contiguous() + assert res.is_contiguous() + + _ops = [] + offset = ((dist.get_rank() // world_size) * world_size) + src_rank = (world_size - rank - 1) % world_size + offset + send_op = dist.P2POp( + dist.isend, to_send_slice, src_rank, group=process_group + ) + recv_op = dist.P2POp( + dist.irecv, res, src_rank, group=process_group) + + _ops.append(send_op) + _ops.append(recv_op) + + response = dist.batch_isend_irecv(_ops) + for resp in response: + resp.wait() + + if rank >= world_size // 2: + to_send_f[left_slicer] = to_send[right_slicer] + to_send_f[right_slicer] = res + else: + to_send_f[left_slicer] = to_send[left_slicer] + to_send_f[right_slicer] = res + + return 
to_send_f.contiguous() + + +def shuffle_block_mask_zigzag( + block_mask: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks, num_blocks] + num_blocks_per_chunk: int, + group: dist.ProcessGroup, +): + rank = dist.get_rank(group) + world_size = dist.get_world_size(group) + + # --------------------------------------- + # Shuffle Query chunks + block_mask = shuffle_zigzag_input(to_send=block_mask, dim=-2, process_group=group) # [batch_size, num_qo_heads, num_blocks_local, num_blocks] + + # --------------------------------------- + # Shuffle Key chunks + ring_list = RingComm(group, zigzag=True).ring_list + ring_index = ring_list.index(rank) + + shuffled_block_mask_list = [] + for i in range(world_size): + rank_src = ring_list[(ring_index - i) % world_size] + + curr_chunk_index = 2 * rank_src + rev_chunk_index = (2 * world_size - 1 - curr_chunk_index) + if curr_chunk_index > rev_chunk_index: + curr_chunk_index, rev_chunk_index = rev_chunk_index, curr_chunk_index + + shuffled_block_mask_list.append( + torch.cat( + [ + block_mask[..., curr_chunk_index * num_blocks_per_chunk : (curr_chunk_index + 1) * num_blocks_per_chunk], + block_mask[..., rev_chunk_index * num_blocks_per_chunk : (rev_chunk_index + 1) * num_blocks_per_chunk] + ], dim=-1 + ) + ) + block_mask = torch.stack(shuffled_block_mask_list, dim=0).contiguous() # [world_size, batch_size, num_qo_heads, num_blocks_local, num_blocks_local] + return block_mask + + +def shuffle_striped_input(to_send: torch.Tensor, # [B, N / W, H, D] + granularity: int = 1, + dim: int = 1, + process_group: dist.ProcessGroup = None): + # 00, 01, 02, 03, 04, 05, 06, 07 => 00, 04, 08, 12, 16, 20, 24, 28 + # 08, 09, 10, 11, 12, 13, 14, 15 => 01, 05, 09, 13, 17, 21, 25, 29 + # 16, 17, 18, 19, 20, 21, 22, 23 => 02, 06, 10, 14, 18, 22, 26, 30 + # 24, 25, 26, 27, 28, 39, 30, 31 => 03, 07, 11, 15, 19, 23, 27, 31 + shape = to_send.shape + dim %= len(shape) + world_size = dist.get_world_size(process_group) + input_reshape = to_send.reshape((*shape[:dim], -1, world_size * granularity, *shape[dim+1:])) + input_list = [x.contiguous() for x in input_reshape.split(granularity, dim=dim+1)] # [N / W / (W * G), W*, G] + output_list = [torch.empty_like(x) for x in input_list] # [W*, N / W / (W * G), G] + + + dist.all_to_all(output_list, input_list, group=process_group) + return torch.stack(output_list, dim=dim).reshape(shape).contiguous() + + +def recover_striped_output(to_send: torch.Tensor, # [B, N / W, H, D] + granularity: int = 1, + dim: int = 1, + process_group: dist.ProcessGroup = None): + # 00, 04, 08, 12, 16, 20, 24, 28 => 00, 01, 02, 03, 04, 05, 06, 07 + # 01, 05, 09, 13, 17, 21, 25, 29 => 08, 09, 10, 11, 12, 13, 14, 15 + # 02, 06, 10, 14, 18, 22, 26, 30 => 16, 17, 18, 19, 20, 21, 22, 23 + # 03, 07, 11, 15, 19, 23, 27, 31 => 24, 25, 26, 27, 28, 39, 30, 31 + shape = to_send.shape + dim %= len(shape) + world_size = dist.get_world_size(process_group) + + input_reshape = to_send.reshape((*shape[:dim], world_size, -1, granularity, *shape[dim+1:])) + input_list = [x.squeeze(dim).contiguous() for x in input_reshape.split(1, dim=dim)] # [W*, N / W / (W * G), G] + output_list = [torch.empty_like(x) for x in input_list] # [N / W / (W * G), W*, G] + + dist.all_to_all(output_list, input_list, group=process_group) + return torch.stack(output_list, dim=dim+1).reshape(shape).contiguous() + +# -------------------------------------------------------------------- +# Double-Ring Related +def get_inner_ring(group: dist.ProcessGroup): + rank = dist.get_rank(group) + local_rank = 
int(os.environ["LOCAL_RANK"]) + local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE")) + assert rank % local_world_size == local_rank + return [i + (rank - local_rank) for i in range(local_world_size)] + + +def get_outer_ring(group: dist.ProcessGroup): + rank = dist.get_rank(group) + world_size = dist.get_world_size(group) + local_rank = int(os.environ["LOCAL_RANK"]) + local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE")) + assert rank % local_world_size == local_rank + return [i * local_world_size + local_rank for i in range(world_size // local_world_size)] + + diff --git a/minference/dist_ops/xattn_zigzag.py b/minference/dist_ops/xattn_zigzag.py new file mode 100644 index 0000000..7fe7929 --- /dev/null +++ b/minference/dist_ops/xattn_zigzag.py @@ -0,0 +1,562 @@ +import os +import math +import torch +import triton +import torch.distributed as dist +from typing import List, Tuple, Dict, Any, Optional + +from .utils import ( + RingComm, update_out_and_lse, + shuffle_zigzag_input, recover_zigzag_output, + shuffle_block_mask_zigzag, +) +from .op_utils.xattn_utils import LN2, find_blocks_chunked + +from minference.ops.utils import convert_blockmask +from minference.ops.minference_attn import block_attn_fwd, block_attn_bwd +from minference.ops.minference_attn_triton import triton_block_attn_fwd, triton_block_attn_bwd +from minference.ops.xattention_fa import flat_group_gemm_fuse_reshape, softmax_fuse_block_sum + + +def xattn_zigzag_estimate( + query_states: torch.Tensor, # (batch_size, num_q_head, q_len, head_dim) + key_states: torch.Tensor, # (batch_size, num_kv_head, k_len, head_dim) + block_size, + stride, + norm=1, + softmax=True, + threshold=0.9, + select_mode="inverse", + use_triton=True, + causal=True, + kdb: int = 1, + keep_sink=False, + keep_recent=False, + group: dist.group = None, +) -> torch.Tensor: + batch_size, num_kv_head, k_len_local, head_dim = key_states.shape + batch_size, num_q_head, q_len_local, head_dim = query_states.shape + + rank = dist.get_rank(group) + world_size = dist.get_world_size(group) + + k_gather_list = [torch.empty_like(key_states) for _ in range(world_size)] + dist.all_gather(k_gather_list, key_states.contiguous(), group=group) + k_gathered = torch.cat(k_gather_list, dim=2) + k_len = k_gathered.shape[2] + + if num_q_head > num_kv_head: + k_gathered = torch.repeat_interleave(k_gathered.contiguous(), num_q_head // num_kv_head, dim=1) + + chunk_size = q_len_local // 2 + q_chunk_num = 2 + q_block_num = q_len_local // block_size + q_block_num_per_chunk = chunk_size // block_size + + # assert num_kv_head == num_q_head + attn_sum_list = [] + simple_mask_list = [] + + num_strides_in_k = k_len // stride + num_strides_per_chunk = chunk_size // stride + num_strides_per_block = block_size // stride + num_blocks_per_chunk = num_strides_per_chunk // num_strides_per_block + + attn_weight_slices = [None, None] + for chunk_idx in range(q_chunk_num): + global_chunk_idx = rank * 2 + chunk_idx + + # Local start index + q_chunk_start = chunk_idx * chunk_size + q_chunk_end = (chunk_idx + 1) * chunk_size + + # Global start index (stride-level) + q_chunk_start_stride_global = global_chunk_idx * num_strides_per_chunk + q_chunk_end_stride_global = (global_chunk_idx + 1) * num_strides_per_chunk + + # attn_weights_slice: (batch_size, num_heads, chunk_size // stride, kv_len // stride) + # (i.e. 
the attention sum of each SxS stride block) + # This step is agnostic to block size and just computes the attention sum in each stride block + attn_weight_slice = flat_group_gemm_fuse_reshape( + # query_states, key_states, stride, chunk_start, chunk_end, is_causal=True + query_states[:, :, q_chunk_start : q_chunk_end, :,], + k_gathered, + stride, + q_chunk_start_stride_global, q_chunk_end_stride_global, + is_causal=causal, + ) + attn_weight_slices[chunk_idx] = attn_weight_slice + del k_gathered, k_gather_list + + for chunk_idx in range(q_chunk_num): + global_chunk_idx = rank * 2 + chunk_idx + + # Local start index + q_chunk_start = chunk_idx * chunk_size + q_chunk_end = (chunk_idx + 1) * chunk_size + + # Global start index (block-level) + q_block_start = global_chunk_idx * q_block_num_per_chunk + q_block_end = (global_chunk_idx + 1) * q_block_num_per_chunk + + # Global start index (stride-level) + q_chunk_start_stride_global = global_chunk_idx * num_strides_per_chunk + q_chunk_end_stride_global = (global_chunk_idx + 1) * num_strides_per_chunk + + attn_weight_slice = attn_weight_slices[chunk_idx] + + # (batch_size, num_heads, q_block_num, k_block_num), + attn_sum = softmax_fuse_block_sum( + attn_weight_slice, # (batch_size, num_heads, chunk_size // stride, kv_len // stride) + num_strides_per_block, + min(4096, num_strides_per_block), + q_chunk_start_stride_global, q_chunk_end_stride_global, + num_strides_in_k, + 1 / LN2 / math.sqrt(head_dim) / stride / norm, + is_causal=causal, + ) + + # (batch_size, head_num, num_blocks_per_chunk, block_num) + simple_mask = find_blocks_chunked( + attn_sum, + global_chunk_idx * num_blocks_per_chunk, + threshold, + None, + decoding=False, + mode="prefill", + causal=causal, + ) + + del attn_weight_slice + if causal: + simple_mask[:, :, :, q_block_start:q_block_end] = torch.where( + torch.tril( + torch.ones( + q_block_num_per_chunk, q_block_num_per_chunk, + dtype=bool, device=key_states.device + ), + diagonal=0, + ), + simple_mask[:, :, :, q_block_start:q_block_end], + False, + ) + simple_mask[:, :, :, q_block_end:] = 0 + if keep_sink: + simple_mask[:, :, 0, :] = True + if keep_recent: + eye_matrix = torch.eye(q_block_num_per_chunk, device=simple_mask.device, dtype=bool) + eye_matrix_expanded = ( + eye_matrix.unsqueeze(0) + .unsqueeze(0) + .expand(1, num_kv_head, q_block_num_per_chunk, q_block_num_per_chunk) + ) + simple_mask[:, :, :, q_block_start:q_block_end] = torch.where( + eye_matrix_expanded, True, simple_mask[:, :, :, q_block_start:q_block_end] + ) + + attn_sum_list.append(attn_sum) + simple_mask_list.append(simple_mask) + + attn_sums = torch.cat(attn_sum_list, dim=-2) + simple_masks = torch.cat(simple_mask_list, dim=-2) # (batch_size, head_num, q_local_block_num, k_global_block_num) + return attn_sums, simple_masks + + +def compute_sr_flops( + block_mask_offset: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, num_blocks] + step: int, + granularity: int, + q_len: int, + head_dim: int, + fwd: bool=True, +): + num_blocks = triton.cdiv(q_len, granularity) + bh = block_mask_offset.shape[0] * block_mask_offset.shape[1] + + total_num_blocks = bh * num_blocks * num_blocks / 2 + + num_active_blocks = block_mask_offset.sum(dtype=torch.float32).item() + if step == 0: + num_active_blocks -= bh * num_blocks / 2 + + block_ratio = num_active_blocks / total_num_blocks + sparsity_ratio = 1 - block_ratio + + block_flops = num_active_blocks * (granularity * granularity) * head_dim * 2 * 2 + + if not fwd: block_flops *= 2.5 + return sparsity_ratio, block_flops + + 
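+# A hypothetical usage sketch for compute_sr_flops (the shapes and numbers below are
+# made up purely for illustration, not taken from a real run):
+#
+#   mask = torch.tril(torch.ones(1, 32, 16, 16, dtype=torch.bool))  # dense causal block mask
+#   sr, flops = compute_sr_flops(mask, step=0, granularity=128, q_len=16 * 128, head_dim=128)
+#   # sr == 0.0 and flops == 2 * 32 * (16 * 128) ** 2 * 128
+#
+# For a fully lower-triangular mask at step 0, num_active_blocks reduces to
+# bh * num_blocks**2 / 2, so sparsity_ratio is 0 and block_flops matches the dense
+# causal-attention cost of 2 * q_len**2 * head_dim per (batch, head); sparser masks
+# scale both numbers down proportionally.
+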
+def compute_sr_by_heads( + block_mask_offset: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, num_blocks] + step: int, + granularity: int, + q_len: int, +): + batch_size, num_heads = block_mask_offset.shape[0], block_mask_offset.shape[1] + num_blocks = triton.cdiv(q_len, granularity) + + total_num_blocks = batch_size * num_blocks * num_blocks / 2 + total_num_blocks_by_heads = torch.tensor([total_num_blocks for _ in range(num_heads)], dtype=torch.float32).to(block_mask_offset.device) + + + num_active_blocks = block_mask_offset.sum(-1).sum(-1).sum(0, dtype=torch.float32) # [num_qo_heads] + if step == 0: + num_active_blocks -= batch_size * num_blocks / 2 + + block_ratio_by_heads = num_active_blocks / total_num_blocks_by_heads + sparsity_ratio_by_heads = 1 - block_ratio_by_heads + + return sparsity_ratio_by_heads.detach().cpu().numpy().tolist() + +def use_triton(): + return torch.version.hip is not None or os.getenv("FORCE_TRITON", "0") == "1" + +def xattn_zigzag_forward( + process_group: dist.ProcessGroup, + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + block_mask: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks, num_blocks] + layer_idx: int, + softmax_scale: float, + granularity: int = 128, + block_idx: Optional[torch.Tensor] = None, + block_cnt: Optional[torch.Tensor] = None, +): + comm = RingComm(process_group, zigzag=True) + out, lse = None, None + next_k, next_v = None, None + + for step in range(comm.world_size): + if step + 1 != comm.world_size: + next_k, next_v = comm.send_recv_kv(k, v) + + # [batch_size, num_qo_heads, num_blocks_local, num_blocks_local] + block_mask_step = block_mask[step] + block_causal = step == 0 + + if use_triton(): + # TODO: block_mask here needs to be converted to block_idx before passing to triton + block_out, block_lse = triton_block_attn_fwd( + q, k, v, + block_idx=block_idx[step], block_cnt=block_cnt[step], + softmax_scale=softmax_scale, + granularity=granularity, + causal=block_causal, + step=step, + ) + else: + block_out, block_lse = block_attn_fwd( + q, k, v, + block_mask=block_mask_step, + softmax_scale=softmax_scale, + granularity=granularity, + causal=block_causal, + step_idx=step, + ) + + out, lse = update_out_and_lse(out, lse, block_out, block_lse) + if step + 1 != comm.world_size: + comm.wait() + k, v = next_k, next_v + + out = out.to(q.dtype) + lse = lse.squeeze(dim=-1).transpose(1, 2) + return out, lse + +def xattn_zigzag_backward( + process_group: dist.ProcessGroup, + dout: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + out: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + softmax_lse: torch.Tensor, # [batch_size, num_qo_heads, num_tokens] + layer_idx: int, + softmax_scale: float, + block_mask: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks, num_blocks] + granularity: int = 128, + block_idx: Optional[torch.Tensor] = None, # [world_size, batch_size, num_qo_heads, num_blocks_local, num_blocks] + block_cnt: Optional[torch.Tensor] = None, # [world_size, batch_size, num_qo_heads, num_blocks_local] +): + kv_comm = RingComm(process_group, zigzag=True) + d_kv_comm = RingComm(process_group, zigzag=True) + 
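# Two communicators run over the same zigzag ring: kv_comm rotates K/V one hop per step, while d_kv_comm circulates the partially accumulated dK/dV one step behind, so each rank's key/value gradients arrive back at their owning rank by the final step. +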
+ dq, dk, dv = None, None, None + next_dk, next_dv = None, None + next_k, next_v = None, None + dk_comm_buffer, dv_comm_buffer = None, None + + for step in range(kv_comm.world_size): + if step + 1 != kv_comm.world_size: + next_k, next_v = kv_comm.send_recv_kv(k, v) + + block_causal = step == 0 + block_mask_step = block_mask[step] + + # -------------------------------- + # Block Mask + if use_triton(): + step_dq, step_dk, step_dv = triton_block_attn_bwd( + dout, q, k, v, out, + softmax_lse, softmax_scale, + block_idx[step], block_cnt[step], + granularity=granularity, + deterministic=False, + causal=block_causal, + step=step, + ) + else: + step_dq, step_dk, step_dv = block_attn_bwd( + dout, q, k, v, out, + softmax_lse, softmax_scale, + block_mask_step, + granularity=granularity, + deterministic=False, + causal=block_causal, + ) + + # Update dQ, dK, dV + if step == 0: + # TODO: check if float32 is necessary + dq = step_dq.to(torch.float32) + dk = step_dk.to(torch.float32) + dv = step_dv.to(torch.float32) + else: + d_kv_comm.wait() + dk_comm_buffer, dv_comm_buffer = dk, dv + dk, dv = next_dk, next_dv + + dq += step_dq + dk += step_dk + dv += step_dv + + if step + 1 != kv_comm.world_size: + kv_comm.wait() + k, v = next_k, next_v + next_dk, next_dv = d_kv_comm.send_recv_kv( + dk, dv, dk_comm_buffer, dv_comm_buffer + ) + + d_kv_comm.wait() + return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) + +class XAttnZigzagFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer_idx, + xattn_params, # Dict[str, Any] + granularity, + causal, + softmax_scale, + return_softmax, + deterministic, + group, + ): + if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) + + # ---------------------------------------------- + # Index Building + # block_mask [batch_size, num_qo_heads, num_blocks_local, num_blocks] + _, block_mask = xattn_zigzag_estimate( + q.transpose(1, 2), k.transpose(1, 2), + block_size=granularity, + **xattn_params + ) + + # ------------------------------------------------------------------ + # QKV Shuffling + q = shuffle_zigzag_input(to_send=q, dim=1, process_group=group) + k = shuffle_zigzag_input(to_send=k, dim=1, process_group=group) + v = shuffle_zigzag_input(to_send=v, dim=1, process_group=group) + + # ------------------------------------------------------------------ + # Index Shuffling + block_mask = shuffle_block_mask_zigzag( + block_mask, num_blocks_per_chunk=q.shape[1] // 2 // granularity, + group=group + ).to(q.device) + if use_triton(): + block_idx, block_cnt = convert_blockmask(block_mask, block_size_M=granularity, block_size_N=64) + else: + block_idx, block_cnt = None, None + block_mask = block_mask.contiguous() + + # ---------------------------------------------- + # Compute + out, softmax_lse = xattn_zigzag_forward( + group, + q, k, v, + block_mask, + layer_idx, + softmax_scale, + granularity=granularity, + block_idx=block_idx, block_cnt=block_cnt, + ) + + # ---------------------------------------------- + # Recover outputs + recovered_out = recover_zigzag_output(out, dim=1, process_group=group) + if return_softmax: + recovered_softmax_lse = recover_zigzag_output(softmax_lse, dim=2, process_group=group) + + # ------------------------------- + # Variale Saving + if use_triton(): + ctx.save_for_backward(q, k, v, out, softmax_lse, block_mask, block_idx, block_cnt) + else: + ctx.save_for_backward(q, k, v, out, softmax_lse, block_mask) + ctx.softmax_scale = softmax_scale + ctx.granularity = 
granularity + ctx.group = group + ctx.layer_idx = layer_idx + + # ------------------------------- + # Recover outputs + if return_softmax: + return (recovered_out, recovered_softmax_lse, None) + return recovered_out + + @staticmethod + def backward(ctx, dout, *args): + if use_triton(): + q, k, v, out, softmax_lse, block_mask, block_idx, block_cnt = ctx.saved_tensors + else: + q, k, v, out, softmax_lse, block_mask = ctx.saved_tensors + block_idx, block_cnt = None, None + softmax_scale = ctx.softmax_scale + granularity = ctx.granularity + layer_idx = ctx.layer_idx + group = ctx.group + + + dout = shuffle_zigzag_input(to_send=dout, dim=1, process_group=group) + + # ---------------------------------------------- + # Compute + dq, dk, dv = xattn_zigzag_backward( + group, + dout, q, k, v, + out, softmax_lse, + layer_idx, + softmax_scale, + block_mask, + granularity, + block_idx=block_idx, block_cnt=block_cnt, + ) + + dq = recover_zigzag_output(dq, dim=1, process_group=group) + dk = recover_zigzag_output(dk, dim=1, process_group=group) + dv = recover_zigzag_output(dv, dim=1, process_group=group) + return dq, dk, dv, None, None, None, None, None, None, None, None, None + + +def xattn_zigzag_qkvpacked_func( + qkv: torch.Tensor, # [batch_size, num_tokens, 3, num_heads, head_dim] + layer_idx: int, + xattn_params: Dict[str, Any], + granularity: int = 128, + dropout_p: int = 0.0, + softmax_scale: float = None, + causal: bool = True, + window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window + alibi_slopes: Tuple[float, float] = None, + deterministic: bool = False, + return_attn_probs: bool = False, + group: dist.ProcessGroup = None, +): + assert causal + assert dropout_p == 0 + assert window_size == (-1, -1) + assert alibi_slopes is None + assert not deterministic + return XAttnZigzagFunc.apply( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + layer_idx, + xattn_params, + granularity, + causal, + softmax_scale, + return_attn_probs, + deterministic, + group, + ) + + +def xattn_zigzag_kvpacked_func( + q: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] + kv: torch.Tensor, # [batch_size, num_tokens, 2, num_heads, head_dim] + layer_idx: int, + xattn_params: Dict[str, Any], + granularity: int = 128, + dropout_p: int = 0.0, + softmax_scale: float = None, + causal: bool = True, + window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window + alibi_slopes: Tuple[float, float] = None, + deterministic: bool = False, + return_attn_probs: bool = False, + group: dist.ProcessGroup = None, +): + assert causal + assert dropout_p == 0 + assert window_size == (-1, -1) + assert alibi_slopes is None + assert not deterministic + + return XAttnZigzagFunc.apply( + q, + kv[:, :, 0], + kv[:, :, 1], + layer_idx, + xattn_params, + granularity, + causal, + softmax_scale, + return_attn_probs, + deterministic, + group, + ) + + +def xattn_zigzag_func( # the one used for nnscaler training + q: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] + layer_idx: int, + xattn_params: Dict[str, Any], + granularity: int = 128, + dropout_p: int = 0.0, + softmax_scale: float = None, + causal: bool = True, + window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window + alibi_slopes: Tuple[float, float] = None, + deterministic: bool = False, + return_attn_probs: bool = False, + group: dist.ProcessGroup = None, +) -> torch.Tensor: + assert 
causal + assert dropout_p == 0 + assert window_size == (-1, -1) + assert alibi_slopes is None + assert not deterministic + + return XAttnZigzagFunc.apply( + q, k, v, + layer_idx, + xattn_params, + granularity, + causal, + softmax_scale, + return_attn_probs, + deterministic, + group, + ) diff --git a/minference/dist_ops/zigzag_attention.py b/minference/dist_ops/zigzag_attention.py new file mode 100644 index 0000000..a5e4ab2 --- /dev/null +++ b/minference/dist_ops/zigzag_attention.py @@ -0,0 +1,412 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# Credits: This logger implementation is inspired by project https://github.com/zhuzilin/ring-flash-attention +import os +import copy +import torch +import torch.distributed as dist + +from time import perf_counter +from typing import List, Tuple, Dict +from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward + +from .utils import ( + RingComm, update_out_and_lse, shuffle_zigzag_input, + recover_zigzag_output, get_default_args +) + +def zigzag_ring_flash_attn_forward( + process_group, + q: torch.Tensor, # [B, S, H, D] + k: torch.Tensor, + v: torch.Tensor, + layer_idx: int, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + assert causal == True, "zigzag ring is meaningless for causal=False" + comm = RingComm(process_group, zigzag=True) + + bsz, seq_len, num_heads, head_dim = q.shape + block_seq_len = q.shape[1] // 2 + q1 = q[:, block_seq_len:] + + out = None + lse = None + next_k, next_v = None, None + + def forward(q, k, v, causal): + params = get_default_args(_flash_attn_forward).copy() + params.update( + { + "q": q, + "k": k, + "v": v, + "dropout_p": dropout_p, + "softmax_scale": softmax_scale, + "causal": causal, + "alibi_slopes": alibi_slopes, + "return_softmax": True and dropout_p > 0, + } + ) + if "window_size" in params: + params.update({"window_size": window_size}) + else: + params.update( + { + "window_size_left": window_size[0], + "window_size_right": window_size[1], + } + ) + outputs = _flash_attn_forward(**params) + if len(outputs) == 8: + block_out, _, _, _, _, block_lse, _, _ = outputs + else: + assert len(outputs) == 4 + block_out, block_lse, _, _ = outputs + return block_out, block_lse + + + for step in range(comm.world_size): + if step + 1 != comm.world_size: + next_k, next_v = comm.send_recv_kv(k, v) + + if step == 0: + # Do softmax(QK^T / sqrt(d_k))V on the currently hold K and V + # and record the output and the LSE + block_out, block_lse = forward(q, k, v, causal=True) + out, lse = update_out_and_lse(out, lse, block_out, block_lse) + elif step <= comm.revert_rank: + k0 = k[:, :block_seq_len] + v0 = v[:, :block_seq_len] + block_out, block_lse = forward(q, k0, v0, causal=False) + out, lse = update_out_and_lse(out, lse, block_out, block_lse) + else: + block_out, block_lse = forward(q1, k, v, causal=False) + out, lse = update_out_and_lse( + out, lse, + block_out, + block_lse, + slice_=(slice(None), slice(block_seq_len, None)), + ) + + if step + 1 != comm.world_size: + comm.wait() + k, v = next_k, next_v + + out = out.to(q.dtype) + lse = lse.squeeze(dim=-1).transpose(1, 2) + return out, lse + +def zigzag_ring_flash_attn_backward( + process_group, + dout, + q, k, v, out, + layer_idx: int, + softmax_lse, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + assert causal == True, "zigzag ring is meaningless for causal=False" + bsz, seq_len, 
num_heads, head_dim = q.shape + + kv_comm = RingComm(process_group, zigzag=True) + d_kv_comm = RingComm(process_group, zigzag=True) + + dq, dk, dv = None, None, None + next_dk, next_dv = None, None + next_k, next_v = None, None + dk_comm_buffer, dv_comm_buffer = None, None + + dout1 = dout.chunk(2, dim=1)[1] + q1 = q.chunk(2, dim=1)[1] + out1 = out.chunk(2, dim=1)[1] + softmax_lse1 = softmax_lse.chunk(2, dim=2)[1].contiguous() + block_seq_len = q.shape[1] // 2 + + # repeatly allocating buffer may be slow... + dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) + dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) + dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) + + def backward(dout_, q_, k_, v_, out_, softmax_lse_, causal_): + seqlen_q = q_.shape[1] + seqlen_kv = k_.shape[1] + params = get_default_args(_flash_attn_backward).copy() + params.update( + { + "dout": dout_, + "q": q_, + "k": k_, + "v": v_, + "out": out_, + "softmax_lse": softmax_lse_, + "dq": dq_buffer[:, :seqlen_q], + "dk": dk_buffer[:, :seqlen_kv], + "dv": dv_buffer[:, :seqlen_kv], + "dropout_p": dropout_p, + "softmax_scale": softmax_scale, + "causal": causal_, + "alibi_slopes": alibi_slopes, + "deterministic": deterministic, + } + ) + if "window_size" in params: + params.update({"window_size": window_size}) + else: + params.update( + { + "window_size_left": window_size[0], + "window_size_right": window_size[1], + } + ) + params.update({"rng_state": torch.zeros((2, ), dtype=torch.int64, device=q.device)}) + + _flash_attn_backward(**params) + + for step in range(kv_comm.world_size): + if step + 1 != kv_comm.world_size: + next_k, next_v = kv_comm.send_recv_kv(k, v) + + # ----------------------------------------------------------- + if step == 0: + backward(dout, q, k, v, out, softmax_lse, causal_=True) + dq = dq_buffer.to(torch.float32) + dk = dk_buffer.to(torch.float32) + dv = dv_buffer.to(torch.float32) + else: + if step <= kv_comm.revert_rank: + k0 = k[:, :block_seq_len] + v0 = v[:, :block_seq_len] + backward(dout, q, k0, v0, out, softmax_lse, causal_=False) + else: + backward(dout1, q1, k, v, out1, softmax_lse1, causal_=False) + + d_kv_comm.wait() + dk_comm_buffer, dv_comm_buffer = dk, dv + dk, dv = next_dk, next_dv + + if step <= kv_comm.revert_rank: + dq += dq_buffer + dk[:, :block_seq_len] += dk_buffer[:, :block_seq_len] + dv[:, :block_seq_len] += dv_buffer[:, :block_seq_len] + else: + dq[:, block_seq_len:] += dq_buffer[:, :block_seq_len] + dk += dk_buffer + dv += dv_buffer + + # ----------------------------------------------------------- + if step + 1 != kv_comm.world_size: + kv_comm.wait() + k, v = next_k, next_v + + next_dk, next_dv = d_kv_comm.send_recv_kv( + dk, dv, dk_comm_buffer, dv_comm_buffer, + ) + + d_kv_comm.wait() + return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) + +''' +In nnscaler, sequence are stored in the initial order, e.g., [0 1 2 3 4 5 6 7]. +However, zigzag ring flash attention requires the sequence to be in the order of [0 7 2 5 3 4 1 6]. 
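+With 4 ranks, for instance, the shuffle leaves rank 0 holding chunks [0 7], rank 1 [2 5], rank 2 [3 4] and rank 3 [1 6], so every rank pairs an early chunk with a late one and the causal workload stays balanced around the ring.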
+As a result: +- in forward, we need to shuffle q, k, v and recover the out +- in backward, we need to shuffle dout and recover the dq, dk, dv +''' +class ZigZagRingFlashAttnFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, k, v, + layer_idx, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_softmax, + group, + ): + assert alibi_slopes is None + if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) + + # ---------------------------------------------- + # Shuffle + q = shuffle_zigzag_input(to_send=q, dim=1, process_group=group) + k = shuffle_zigzag_input(to_send=k, dim=1, process_group=group) + v = shuffle_zigzag_input(to_send=v, dim=1, process_group=group) + k, v = k.contiguous(), v.contiguous() + + # ---------------------------------------------- + # Compute + out, softmax_lse = zigzag_ring_flash_attn_forward( + group, + q, k, v, + layer_idx, + softmax_scale=softmax_scale, + dropout_p=dropout_p, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=False, + ) + + # ---------------------------------------------- + # Recover outputs + recovered_out = recover_zigzag_output(out, dim=1, process_group=group) + if return_softmax: + recovered_softmax_lse = recover_zigzag_output(softmax_lse, dim=2, process_group=group) + + # ------------------------------ + # Saving tensors + ctx.save_for_backward(q, k, v, out, softmax_lse) + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + ctx.group = group + ctx.layer_idx = layer_idx + ctx.return_softmax = return_softmax + + # ---------------------------------------------- + # Output and return + if return_softmax: + return (recovered_out, recovered_softmax_lse, None) + return recovered_out + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse = ctx.saved_tensors + layer_idx = ctx.layer_idx + dropout_p, softmax_scale, causal, window_size, alibi_slopes, deterministic, return_softmax, group = ( + ctx.dropout_p, ctx.softmax_scale, ctx.causal, ctx.window_size, + ctx.alibi_slopes, ctx.deterministic, ctx.return_softmax, + ctx.group + ) + + # ---------------------------------------------- + # Shuffle + dout = shuffle_zigzag_input(to_send=dout, dim=1, process_group=group) + + # ---------------------------------------------- + # Compute + dq, dk, dv = zigzag_ring_flash_attn_backward( + group, + dout, + q, k, v, out, + layer_idx, + softmax_lse, + softmax_scale=softmax_scale, + dropout_p=dropout_p, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + ) + + + # ---------------------------------------------- + # Recover + dq = recover_zigzag_output(dq, dim=1, process_group=group) + dk = recover_zigzag_output(dk, dim=1, process_group=group) + dv = recover_zigzag_output(dv, dim=1, process_group=group) + + return dq, dk, dv, None, None, None, None, None, None, None, None, None + + +def zigzag_ring_flash_attn_qkvpacked_func( + qkv, + layer_idx, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return ZigZagRingFlashAttnFunc.apply( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + layer_idx, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def 
zigzag_ring_flash_attn_kvpacked_func( + q, + kv, + layer_idx, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return ZigZagRingFlashAttnFunc.apply( + q, + kv[:, :, 0], + kv[:, :, 1], + layer_idx, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def zigzag_ring_flash_attn_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer_idx, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return ZigZagRingFlashAttnFunc.apply( + q, k, v, + layer_idx, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) diff --git a/minference/ops/minference_attn.py b/minference/ops/minference_attn.py new file mode 100644 index 0000000..8396921 --- /dev/null +++ b/minference/ops/minference_attn.py @@ -0,0 +1,881 @@ +import os +import sys +import math + +import torch +import torch.nn.functional as F +import torch.distributed as dist + +import triton +import triton.language as tl + +from typing import List, Tuple + +# Save current flags +if torch.version.hip is None: + original_flags = sys.getdlopenflags() + try: + sys.setdlopenflags(os.RTLD_LAZY | os.RTLD_GLOBAL) + import block_sparse_attn_cuda + from block_sparse_attn.block_sparse_attn_interface import convert_blockmask_row_reverse, convert_blockmask_col_reverse + # NOTE: Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_blockmask.h: add head_idx to blockmask_ptr + finally: + # Restore original flags for future imports + sys.setdlopenflags(original_flags) + # NOTE: Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_blockmask.h: add head_idx to blockmask_ptr + +from .utils import build_index_local + + +def block_attn_fwd( + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + softmax_scale: float, + block_mask: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, num_blocks] + granularity: int, + causal: bool, + step_idx: int=-1, +): + batch_size, num_tokens, num_qo_heads, head_dim = q.shape + num_kv_heads = k.shape[2] + cu_seqlens = torch.arange(0, (batch_size + 1) * num_tokens, step=num_tokens, dtype=torch.int32, device=q.device) + head_mask_type = torch.ones((num_qo_heads, ), dtype=torch.int32, device=q.device) # Block-Sparse + streaming_info = torch.zeros((num_qo_heads * 2), dtype=torch.int32, device=q.device) + row_blockmask = convert_blockmask_row_reverse(block_mask, causal=True) + + p_dropout = 0.0 + out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = block_sparse_attn_cuda.fwd_block( + q.reshape((-1, num_qo_heads, head_dim)), + k.reshape((-1, num_kv_heads, head_dim)), + v.reshape((-1, num_kv_heads, head_dim)), + cu_seqlens, cu_seqlens, + granularity, granularity, + head_mask_type, + streaming_info, + row_blockmask, + num_tokens, num_tokens, + p_dropout, + softmax_scale, + causal, # is_causal + False, # exact_streaming + False, # return_softmax + -1, # window_size_left + -1, # window_size_right + None + ) + out = out.reshape((batch_size, num_tokens, num_qo_heads, head_dim)) + return out, softmax_lse + + +def block_attn_bwd( + grad: torch.Tensor, + q: torch.Tensor, # [batch_size, num_tokens, 
num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + o: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + softmax_lse: torch.Tensor, # [batch_size, num_qo_heads, num_tokens] + softmax_scale: float, + block_mask: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, num_blocks] + granularity: int, + deterministic: bool, + causal: bool, + converted: bool = False, +): + batch_size, num_tokens, num_qo_heads, head_dim = q.shape + num_kv_heads = k.shape[2] + cu_seqlens = torch.arange(0, (batch_size + 1) * num_tokens, step=num_tokens, dtype=torch.int32, device=q.device) + head_mask_type = torch.ones((num_qo_heads, ), dtype=torch.int32, device=q.device) # Block-Sparse + streaming_info = torch.zeros((num_qo_heads * 2), dtype=torch.int32, device=q.device) + if converted: + col_blockmask = block_mask + else: + col_blockmask = convert_blockmask_col_reverse(block_mask, causal=True) + p_dropout = 0.0 + dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) + dq, dk, dv, softmax_d = block_sparse_attn_cuda.bwd_block( + grad.reshape((-1, num_qo_heads, head_dim)), + q.reshape((-1, num_qo_heads, head_dim)), + k.reshape((-1, num_kv_heads, head_dim)), + v.reshape((-1, num_kv_heads, head_dim)), + o.reshape((-1, num_qo_heads, head_dim)), + softmax_lse, + dq.reshape((-1, num_qo_heads, head_dim)), + dk.reshape((-1, num_kv_heads, head_dim)), + dv.reshape((-1, num_kv_heads, head_dim)), + cu_seqlens, cu_seqlens, + granularity, granularity, + head_mask_type, + streaming_info, + col_blockmask, + num_tokens, num_tokens, + p_dropout, + softmax_scale, + True, # zero_tensors + causal, # is_causal + -1, # window_size_left + -1, # window_size_right + deterministic, + None, None + ) + dq = dq.reshape((batch_size, num_tokens, num_qo_heads, head_dim)) + dk = dk.reshape((batch_size, num_tokens, num_kv_heads, head_dim)) + dv = dv.reshape((batch_size, num_tokens, num_kv_heads, head_dim)) + return dq, dk, dv + + +@triton.jit +def _triton_bar_attn_fwd_kernel( + Q, K, V, sm_scale, + bar_cnt, # [BATCH, N_Q_HEADS, NUM_ROWS, WORLD_SIZE + 1] + bar_idx, # [BATCH, N_Q_HEADS, NUM_ROWS, NNZ_V] + Out, # [BATCH, N_Q_HEADS, N_CTX, D_HEAD] + softmax_lse, # [BATCH, N_Q_HEADS, N_CTX] + stride_qz, stride_qh, stride_qm, stride_qd, + stride_kz, stride_kh, stride_kn, stride_kd, + stride_vz, stride_vh, stride_vn, stride_vd, + stride_oz, stride_oh, stride_om, stride_od, + stride_cz, stride_ch, stride_cm, stride_cr, + stride_iz, stride_ih, stride_im, stride_in, + stride_sz, stride_sh, stride_sm, + step, num_qo_heads, num_kv_heads, num_tokens, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + start_m = tl.program_id(0) + qo_head_idx = tl.program_id(1) + batch_idx = tl.program_id(2) + kv_head_idx = qo_head_idx // (num_qo_heads // num_kv_heads) + + if start_m * BLOCK_M >= num_tokens: + return + + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + m_mask = offs_m < num_tokens + + qo_offset = batch_idx * stride_qz + qo_head_idx * stride_qh + kv_offset = batch_idx * stride_kz + kv_head_idx * stride_kh + + q_ptrs = Q + qo_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd + k_ptrs = K + kv_offset + offs_d[:, None] * stride_kd + v_ptrs = V + kv_offset + offs_d[None, :] * stride_vd + o_ptrs = Out + qo_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od + 
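+ # bar_cnt[b, h, row, :] holds per-ring-step start/end offsets into bar_idx, so the
+ # slice bar_idx[b, h, row, bar_l:bar_r] loaded below lists exactly the key/value
+ # columns this query block attends to at the current `step`.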
+ lse_ptrs = softmax_lse + batch_idx * stride_sz + qo_head_idx * stride_sh + offs_m * stride_sm + + bar_l = tl.load(bar_cnt + batch_idx * stride_cz + qo_head_idx * stride_ch + start_m * stride_cm + step * stride_cr) + bar_r = tl.load(bar_cnt + batch_idx * stride_cz + qo_head_idx * stride_ch + start_m * stride_cm + (step + 1) * stride_cr) + bar_idx_ptr = bar_idx + batch_idx * stride_iz + qo_head_idx * stride_ih + start_m * stride_im + + if bar_l >= bar_r: + return + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + # 1/ln2 = lne/ln2 = log2(e) => 2^(x / ln2) = 2^(x * log2(e)) = (2^(log2(e)))^x = e^x + qk_scale = sm_scale * 1.44269504 + + # load q: it will stay in SRAM throughout + q = tl.load(q_ptrs, mask=m_mask[:, None], other=0) + q = (q * qk_scale).to(Q.type.element_ty) + + # loop over k, v and update accumulator + for start_n in range(bar_l, bar_r, BLOCK_N): + n_mask = start_n + offs_n < bar_r + cols = tl.load(bar_idx_ptr + (start_n + offs_n) * stride_in, mask=n_mask, other=0) + + # -- load k, v -- + k = tl.load(k_ptrs + cols[None, :] * stride_kn) + v = tl.load(v_ptrs + cols[:, None] * stride_vn) + + # -- compute qk -- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.where(m_mask[:, None] & n_mask[None, :], qk, float("-inf")) + qk = qk + tl.dot(q, k) + + # -- compute scaling constant -- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.math.exp2(m_i - m_i_new) + p = tl.math.exp2(qk - m_i_new[:, None]) + + # -- scale and update acc -- + acc_scale = l_i * 0 + alpha # workaround some compiler bug + acc = acc * acc_scale[:, None] + acc = acc + tl.dot(p.to(Q.type.element_ty), v) + + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + + # write back O and LSE + acc_1 = acc / l_i[:, None] + s_1 = m_i * 0.69314718 + tl.math.log(l_i) + acc_0 = tl.load(o_ptrs, mask=m_mask[:, None], other=0.).to(tl.float32) + s_0 = tl.load(lse_ptrs, mask=m_mask, other=float("-inf")) + + overflow_mask = (s_0 - s_1) < 88.0 + + theta = tl.math.exp(s_0 - s_1) + alpha_0 = 1 / (1 + 1 / theta) + alpha_1 = 1 / (1 + theta) + acc = alpha_0[:, None] * acc_0 + alpha_1[:, None] * acc_1 + s = s_1 - tl.math.log(alpha_1) + + tl.store(o_ptrs, acc.to(Out.type.element_ty), mask=m_mask[:, None]) + tl.store(lse_ptrs, s, mask=(m_mask & overflow_mask)) + +def stable_sigmoid(x: torch.Tensor): + return torch.where( + x >= 0, + 1 / (1 + torch.exp(-x)), + torch.exp(x) / (1 + torch.exp(x)) + ) + +def naive_bar_attn_fwd( + q, k, v, + sm_scale, + bar_cnt, bar_idx, + out, softmax_lse, + step, granularity, BLOCK_N=64 + ): + """ + Naive PyTorch implementation of the Triton bar attention forward kernel. + + Args: + q: Query tensor of shape [B, num_qo_heads, num_tokens, head_dim] + k: Key tensor of shape [B, num_kv_heads, num_tokens, head_dim] + v: Value tensor of shape [B, num_kv_heads, num_tokens, head_dim] + sm_scale: A scalar (float) softmax scale. + bar_cnt: Tensor of shape [B, num_qo_heads, num_blocks, world_size+1] + where each block (row) holds bar boundary indices. + bar_idx: Tensor of shape [B, num_qo_heads, num_blocks, nnz_v] + containing indices of keys (columns) for each block. + out: Output tensor of shape [B, num_qo_heads, num_tokens, head_dim]. 
+ This is assumed to have a previous value to merge with. + softmax_lse: Tensor of shape [B, num_qo_heads, num_tokens] containing + the previous log-sum-exp values. + step: integer step indicating which pair of boundaries to use in bar_cnt. + granularity: BLOCK_M, i.e. the number of query tokens processed per block. + BLOCK_N: Block size for the key dimension (default: 64) + + This function updates `out` and `softmax_lse` in-place. + """ + # Get dimensions from q. + B, num_tokens, num_qo_heads, head_dim = q.shape + + # Determine number of query blocks (each corresponding to a row in bar_cnt/bar_idx). + num_blocks = math.ceil(num_tokens / granularity) + + # Compute the ratio for mapping query-head to key/value head. + head_ratio = num_qo_heads // k.shape[2] # since k.shape[1] is num_kv_heads + + # Precompute scale for q: note that 1.44269504 = log2(e) + qk_scale = sm_scale * 1.44269504 + + ln2 = 0.69314718 # constant for converting exp2 to exp + + # Loop over batch and query-head + for b in range(B): + for qh in range(num_qo_heads): + # corresponding key/value head index + kvh = qh // head_ratio + + # Loop over query blocks (rows) + for block in range(num_blocks): + start_m = block * granularity + end_m = min(start_m + granularity, num_tokens) + block_size = end_m - start_m + + # Get bar boundaries for this block & step: + # bar_cnt is assumed to store cumulative indices per block. + bar_l = bar_cnt[b, qh, block, step].item() # starting index (inclusive) + bar_r = bar_cnt[b, qh, block, step + 1].item() # ending index (exclusive) + if bar_l >= bar_r: + continue # nothing to do in this block + + # Initialize accumulators per query token in the block. + # m_i tracks the running maximum (in "log2" domain). + m_i = torch.full((block_size,), -float('inf'), device=q.device, dtype=torch.float32) + # l_i tracks the running sum-of-weights. + l_i = torch.zeros(block_size, device=q.device, dtype=torch.float32) + # acc accumulates the weighted sum of values. + acc = torch.zeros((block_size, head_dim), device=q.device, dtype=torch.float32) + + # Load and scale the q block. + # Shape: [block_size, head_dim] + q_block = q[b, start_m:end_m, qh, :] * qk_scale + + # Loop over key indices in steps of BLOCK_N + for n_start in range(bar_l, bar_r, BLOCK_N): + n_end = min(n_start + BLOCK_N, bar_r) + + # Load column indices from bar_idx. + # bar_idx shape: [nnz_v] for this block. + cols = bar_idx[b, qh, block, n_start:n_end] + cols = cols.long() + + k_selected = k[b, cols, kvh, :] # shape: [n_valid, head_dim] + v_selected = v[b, cols, kvh, :] # shape: [n_valid, head_dim] + + # Compute scaled dot product: [block_size, head_dim] x [head_dim, n_valid] + # Result: [block_size, n_valid] + qk = torch.matmul(q_block, k_selected.T) + + # Numerically stable softmax update in the log2 domain. + # m_i_new = max(m_i, max(qk, dim=1)) + cur_max, _ = qk.max(dim=1) + m_i_new = torch.max(m_i, cur_max) + + alpha = torch.exp2((m_i - m_i_new)) + p = torch.exp2((qk - m_i_new.unsqueeze(1))) + + # Update acc and l_i. + # Scale previous acc by alpha. + acc = acc * alpha.unsqueeze(1) + torch.matmul(p.to(q.dtype), v_selected) + + l_i = l_i * alpha + p.sum(dim=1) + + # Update m_i to the new maximum. + m_i = m_i_new + + # check zeros in l_i, if any, print out the indices + if (l_i == 0).any(): + zero_indices = torch.nonzero(l_i == 0).squeeze() + print(f"Rank {dist.get_rank()} | Zeros in l_i (step = {step}, batch_idx={b}, head_idx={qh}, block={block}): {zero_indices}") + + # Finalize the block output. + # Compute weighted output. 
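+ # The loop above tracks m_i and l_i in the base-2 domain (qk was pre-scaled by
+ # log2(e)), so the block's log-sum-exp in natural log is recovered below as
+ # s_1 = m_i * ln(2) + log(l_i), and acc / l_i is the softmax-normalized block output.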
+ acc_1 = acc / l_i.unsqueeze(1) + s_1 = m_i * ln2 + torch.log(l_i) + # check positive infinity in s_1, if any, print out the indices + if torch.isinf(s_1).any() and ( torch.isinf(s_1) & (s_1 > 0) ).any(): + mask = torch.isinf(s_1) & (s_1 > 0) + print(f"Rank {dist.get_rank()} | Positive infinity in s_1 (step = {step}, batch_idx={b}, head_idx={qh}, block={block}): {torch.nonzero(mask).squeeze()}") + # check negative infinity in s_1, if any, print out the indices + if torch.isinf(s_1).any() and ( torch.isinf(s_1) & (s_1 < 0) ).any(): + mask = torch.isinf(s_1) & (s_1 < 0) + print(f"Rank {dist.get_rank()} | Negative infinity in s_1 (step = {step}, batch_idx={b}, head_idx={qh}, block={block}): {torch.nonzero(mask).squeeze()}") + + # Load previous stored values (accumulated output and LSE). + old_out = out[b, start_m:end_m, qh, :].to(acc_1.dtype) + old_lse = softmax_lse[b, qh, start_m:end_m] + # check positive infinity in old_lse, if any, print out the indices + if torch.isinf(old_lse).any() and ( torch.isinf(old_lse) & (old_lse > 0) ).any(): + mask = torch.isinf(old_lse) & (old_lse > 0) + print(f"Rank {dist.get_rank()} | Positive infinity in old_lse (step = {step}, batch_idx={b}, head_idx={qh}, block={block}): {torch.nonzero(mask).squeeze()}") + + # ------------------------------------------------- + # Logsigmoid solution + # out - old_out, block_out - acc1, lse - old_lse, block_lse - s_1 + new_out = old_out - F.sigmoid(s_1 - old_lse).unsqueeze(1) * (old_out - acc_1) + new_lse = s_1 - F.logsigmoid(s_1 - old_lse) + if torch.isinf(new_lse).any() and ( torch.isinf(new_lse) & (new_lse > 0) ).any(): + mask = torch.isinf(new_lse) & (new_lse > 0) + print(f"Rank {dist.get_rank()} | Positive infinity in new_lse (step = {step}, batch_idx={b}, head_idx={qh}, block={block}): {torch.nonzero(mask).squeeze()}") + + pos_inf_indices = torch.nonzero(mask).squeeze() + print(f"Rank {dist.get_rank()} | Values of (old_lse - s_1) resulting in pos-inf in theta (step = {step}, batch_idx={b}, head_idx={qh}, block={block}): {(old_lse - s_1)[pos_inf_indices]}") + print(f"Rank {dist.get_rank()} | Values of (old_lse) resulting in pos-inf in theta (step = {step}, batch_idx={b}, head_idx={qh}, block={block}): {(old_lse)[pos_inf_indices]}") + print(f"Rank {dist.get_rank()} | Values of (s_1) resulting in pos-inf in theta (step = {step}, batch_idx={b}, head_idx={qh}, block={block}): {(s_1)[pos_inf_indices]}") + + # Store back into out and softmax_lse. 
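+ # Merge identity used above: for partial results with LSEs a (old) and b (new),
+ # logsumexp(a, b) = b + log(1 + exp(a - b)) = b - logsigmoid(b - a), and the merged
+ # output is sigmoid(a - b) * out_old + sigmoid(b - a) * out_new; new_out and new_lse
+ # above compute exactly this in a numerically stable form.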
+ out[b, start_m:end_m, qh, :] = new_out.to(out.dtype) + softmax_lse[b, qh, start_m:end_m] = new_lse + return out, softmax_lse + + +def bar_attn_fwd( + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + o: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + lse: torch.Tensor, # [batch_size, num_qo_heads, num_tokens] + softmax_scale: float, + bar_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, max_v_size] + bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] + granularity: int, + step: int = 0, +): + batch_size, num_tokens, num_qo_heads, head_dim = q.shape + num_kv_heads = k.shape[2] + num_blocks = bar_idx.shape[2] + _triton_bar_attn_fwd_kernel[(num_blocks, num_qo_heads, batch_size)]( + q, k, v, softmax_scale, bar_cnt, bar_idx, o, lse, + q.stride(0), q.stride(2), q.stride(1), q.stride(3), + k.stride(0), k.stride(2), k.stride(1), k.stride(3), + v.stride(0), v.stride(2), v.stride(1), v.stride(3), + o.stride(0), o.stride(2), o.stride(1), o.stride(3), + bar_cnt.stride(0), bar_cnt.stride(1), bar_cnt.stride(2), bar_cnt.stride(3), + bar_idx.stride(0), bar_idx.stride(1), bar_idx.stride(2), bar_idx.stride(3), + lse.stride(0), lse.stride(1), lse.stride(2), + step, num_qo_heads, num_kv_heads, num_tokens, + BLOCK_M=granularity, BLOCK_N=64, BLOCK_DMODEL=head_dim, + num_warps=4, num_stages=2, + ) + return o, lse + + +@triton.jit +def _triton_bar_attn_bwd_kernel( + Q, K, V, O, + DQ, DK, DV, DO, + sm_scale, + bar_cnt, # [BATCH, N_Q_HEADS, NUM_ROWS, WORLD_SIZE + 1] + bar_idx, # [BATCH, N_Q_HEADS, NUM_ROWS, NNZ_V] + softmax_lse, # [BATCH, N_HEADS, N_CTX] + stride_qz, stride_qh, stride_qm, stride_qd, + stride_kz, stride_kh, stride_kn, stride_kd, + stride_vz, stride_vh, stride_vn, stride_vd, + stride_oz, stride_oh, stride_om, stride_od, + stride_dqz, stride_dqh, stride_dqm, stride_dqd, + stride_dkz, stride_dkh, stride_dkn, stride_dkd, + stride_dvz, stride_dvh, stride_dvn, stride_dvd, + stride_doz, stride_doh, stride_dom, stride_dod, + stride_cz, stride_ch, stride_cm, stride_cr, + stride_iz, stride_ih, stride_im, stride_in, + stride_sz, stride_sh, stride_sm, + step, num_qo_heads, num_kv_heads, num_tokens, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + start_m = tl.program_id(0) + qo_head_idx = tl.program_id(1) + batch_idx = tl.program_id(2) + kv_head_idx = qo_head_idx // (num_qo_heads // num_kv_heads) + + if start_m * BLOCK_M >= num_tokens: + return + + qk_scale = sm_scale * 1.44269504 + + # offset pointers for batch/head + Q += batch_idx * stride_qz + qo_head_idx * stride_qh + K += batch_idx * stride_kz + kv_head_idx * stride_kh + V += batch_idx * stride_vz + kv_head_idx * stride_vh + O += batch_idx * stride_oz + qo_head_idx * stride_oh + DQ += batch_idx * stride_dqz + qo_head_idx * stride_dqh + DK += batch_idx * stride_dkz + kv_head_idx * stride_dkh + DV += batch_idx * stride_dvz + kv_head_idx * stride_dvh + DO += batch_idx * stride_doz + qo_head_idx * stride_doh + + # loop over rows + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + m_mask = offs_m < num_tokens + + # initialize pointers to value-like data + q_ptrs = Q + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd + k_ptrs = K + offs_d[None, :] * stride_kd + v_ptrs = V + offs_d[None, :] * stride_vd + o_ptrs = O + offs_m[:, 
None] * stride_om + offs_d[None, :] * stride_od + dq_ptrs = DQ + offs_m[:, None] * stride_dqm + offs_d[None, :] * stride_dqd + dk_ptrs = DK + offs_d[None, :] * stride_dkd + dv_ptrs = DV + offs_d[None, :] * stride_dvd + do_ptrs = DO + offs_m[:, None] * stride_dom + offs_d[None, :] * stride_dod + + l_ptrs = softmax_lse + batch_idx * stride_sz + qo_head_idx * stride_sh + offs_m * stride_sm + + bar_l = tl.load(bar_cnt + batch_idx * stride_cz + qo_head_idx * stride_ch + start_m * stride_cm + step * stride_cr) + bar_r = tl.load(bar_cnt + batch_idx * stride_cz + qo_head_idx * stride_ch + start_m * stride_cm + (step + 1) * stride_cr) + bar_idx_ptr = bar_idx + batch_idx * stride_iz + qo_head_idx * stride_ih + start_m * stride_im + + if bar_l >= bar_r: + return + + o = tl.load(o_ptrs, mask=m_mask[:, None], other=0.).to(tl.float32) + do = tl.load(do_ptrs, mask=m_mask[:, None], other=0.).to(tl.float32) + d_i = tl.sum(o * do, axis=1) + + q = tl.load(q_ptrs, mask=m_mask[:, None], other=0.) + do = do.to(DO.dtype.element_ty) + l_i = tl.load(l_ptrs, mask=m_mask, other=0.) * 1.44269504 + + dq = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + for start_n in range(bar_l, bar_r, BLOCK_N): + n_mask = start_n + offs_n < bar_r + cols = tl.load(bar_idx_ptr + (start_n + offs_n) * stride_in, mask=n_mask, other=0) + + # -- load k, v -- + k = tl.load(k_ptrs + cols[:, None] * stride_kn) + v = tl.load(v_ptrs + cols[:, None] * stride_vn) + + # Computer qk + qk = tl.where(m_mask[:, None] & n_mask[None, :], float(0.), float("-inf")) + qk = qk + tl.dot(q, tl.trans(k)) + qk = qk * qk_scale + p = tl.math.exp2(qk - l_i[:, None]) + + # compute dv + dv_vals = tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do).to(tl.float32) + tl.atomic_add(dv_ptrs + cols[:, None] * stride_dvn, dv_vals, mask=n_mask[:, None], sem="relaxed") + + # compute dp = dot(v, do) + dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - d_i[:, None] + dp = dp + tl.dot(do, tl.trans(v)) + + # compute ds = p * (dp - delta[:, None]) + ds = p * dp * sm_scale + + # compute dk = dot(ds.T, q) + dk_vals = tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q).to(tl.float32) + tl.atomic_add(dk_ptrs + cols[:, None] * stride_dkn, dk_vals, mask=n_mask[:, None], sem="relaxed") + + # compute dq + dq = dq + tl.dot(ds.to(Q.dtype.element_ty), k) + + dq_old = tl.load(dq_ptrs, mask=m_mask[:, None], other=0.).to(tl.float32) + tl.store(dq_ptrs, (dq_old + dq).to(DQ.dtype.element_ty), mask=m_mask[:, None]) + + +def bar_attn_bwd( + grad: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + o: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + dq: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + dk: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + dv: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + softmax_lse: torch.Tensor, # [batch_size, num_qo_heads, num_tokens] + softmax_scale: float, + bar_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, max_v_size] + bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] + granularity: int, + deterministic: bool, + step: int = 0, +): + assert not deterministic + batch_size, num_tokens, num_qo_heads, head_dim = q.shape + num_kv_heads = k.shape[2] + num_blocks = bar_idx.shape[2] + dq = torch.zeros_like(q, dtype=torch.float32) 
if dq is None else dq.to(torch.float32) + dk = torch.zeros_like(k, dtype=torch.float32) if dk is None else dk.to(torch.float32) + dv = torch.zeros_like(v, dtype=torch.float32) if dv is None else dv.to(torch.float32) + _triton_bar_attn_bwd_kernel[(num_blocks, num_qo_heads, batch_size)]( + q, k, v, o, dq, dk, dv, grad, softmax_scale, + bar_cnt, bar_idx, softmax_lse, + q.stride(0), q.stride(2), q.stride(1), q.stride(3), + k.stride(0), k.stride(2), k.stride(1), k.stride(3), + v.stride(0), v.stride(2), v.stride(1), v.stride(3), + o.stride(0), o.stride(2), o.stride(1), o.stride(3), + dq.stride(0), dq.stride(2), dq.stride(1), dq.stride(3), + dk.stride(0), dk.stride(2), dk.stride(1), dk.stride(3), + dv.stride(0), dv.stride(2), dv.stride(1), dv.stride(3), + grad.stride(0), grad.stride(2), grad.stride(1), grad.stride(3), + bar_cnt.stride(0), bar_cnt.stride(1), bar_cnt.stride(2), bar_cnt.stride(3), + bar_idx.stride(0), bar_idx.stride(1), bar_idx.stride(2), bar_idx.stride(3), + softmax_lse.stride(0), softmax_lse.stride(1), softmax_lse.stride(2), + step, num_qo_heads, num_kv_heads, num_tokens, + BLOCK_M=granularity, BLOCK_N=64, BLOCK_DMODEL=head_dim, + num_warps=4, num_stages=2, + ) + return dq, dk.to(dq.dtype), dv.to(dq.dtype) + + + +class MInferenceAttnFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + v_size, + s_size, + softmax_scale, + granularity, + return_softmax, + deterministic, + ): + batch_size, num_tokens, num_qo_heads, head_dim = q.shape + if softmax_scale is None: + softmax_scale = head_dim ** (-0.5) + + block_mask, bar_idx, bar_cnt = build_index_local(q, k, v_size, s_size, num_tokens, granularity) + + # Block Mask + out, softmax_lse = block_attn_fwd( + q, k, v, softmax_scale, + block_mask, + granularity=granularity, + causal=True, + ) + # Bar Mask + out, softmax_lse = bar_attn_fwd( + q, k, v, out, softmax_lse, softmax_scale, + bar_idx, bar_cnt, + granularity=granularity, + step=0, + ) + + ctx.save_for_backward(q, k, v, out, softmax_lse, block_mask, bar_idx, bar_cnt) + ctx.granularity = granularity + ctx.deterministic = deterministic + ctx.softmax_scale = softmax_scale + return (out, softmax_lse, None) if return_softmax else out + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse, block_mask, bar_idx, bar_cnt = ctx.saved_tensors + # Block Mask + dq, dk, dv = block_attn_bwd( + dout, q, k, v, out, + softmax_lse, ctx.softmax_scale, + block_mask, + granularity=ctx.granularity, + deterministic=ctx.deterministic, + causal=True, + ) + + # Bar Mask + dq, dk, dv = bar_attn_bwd( + dout, q, k, v, out, dq, dk, dv, + softmax_lse, ctx.softmax_scale, + bar_idx, bar_cnt, + granularity=ctx.granularity, + deterministic=ctx.deterministic, + step=0, + ) + return dq, dk, dv, None, None, None, None, None, None + + +def minference_flash_attn_qkvpacked_func( + qkv: torch.Tensor, # [batch_size, num_tokens, 3, num_heads, head_dim] + v_size: List[int], # [num_heads] + s_size: List[int], # [num_heads] + dropout_p: int = 0.0, + softmax_scale: float = None, + granularity: int = 128, + causal: bool = True, + window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window + alibi_slopes: Tuple[float, float] = None, + deterministic: bool = False, + return_attn_probs: bool = False, +): + assert dropout_p == 0 + assert causal + assert window_size == (-1, -1) + assert alibi_slopes is None + return MInferenceAttnFunc.apply( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + v_size, + s_size, + softmax_scale, + granularity, + return_attn_probs, + 
deterministic, + ) + + +def minference_flash_attn_kvpacked_func( + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + kv: torch.Tensor, # [batch_size, num_tokens, 2, num_kv_heads, head_dim] + v_size: List[int], # [num_qo_heads] + s_size: List[int], # [num_qo_heads] + dropout_p: int = 0.0, + softmax_scale: float = None, + granularity: int = 128, + causal: bool = True, + window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window + alibi_slopes: Tuple[float, float] = None, + deterministic: bool = False, + return_attn_probs: bool = False, +): + assert dropout_p == 0 + assert causal + assert window_size == (-1, -1) + assert alibi_slopes is None + return MInferenceAttnFunc.apply( + q, + kv[:, :, 0], + kv[:, :, 1], + v_size, + s_size, + softmax_scale, + granularity, + return_attn_probs, + deterministic, + ) + + +def minference_flash_attn_func( + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v_size: List[int], # [num_qo_heads] + s_size: List[int], # [num_qo_heads] + dropout_p: int = 0.0, + softmax_scale: float = None, + granularity: int = 128, + causal: bool = True, + window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window + alibi_slopes: Tuple[float, float] = None, + deterministic: bool = False, + return_attn_probs: bool = False, +): + assert dropout_p == 0 + assert causal + assert window_size == (-1, -1) + assert alibi_slopes is None + return MInferenceAttnFunc.apply( + q, + k, + v, + v_size, + s_size, + softmax_scale, + granularity, + return_attn_probs, + deterministic, + ) + + +def _build_mask_local( + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v_size: List[int], + s_size: List[int], + num_tokens: int, + granularity: int, + world_size: int = 1, + rank: int = 0, +): + with torch.no_grad(): + block_mask, bar_idx, bar_cnt = build_index_local(q, k, v_size, s_size, num_tokens, granularity, world_size, rank) + batch_size, num_tokens, num_heads, head_dim = q.shape + num_blocks = block_mask.shape[-1] + num_tokens_pad = num_blocks * granularity + # Block Mask + mask = block_mask.unsqueeze(3).unsqueeze(5).repeat((1, 1, 1, granularity, 1, granularity)) + mask = mask.reshape((batch_size, num_heads, num_tokens_pad, num_tokens_pad)) + # Bar Mask + for batch_idx in range(batch_size): + for head_idx in range(num_heads): + for row_idx in range(num_blocks): + row_u = row_idx * granularity + row_d = row_u + granularity + bar_l = bar_cnt[batch_idx, head_idx, row_idx, rank] + bar_r = bar_cnt[batch_idx, head_idx, row_idx, rank + 1] + for col_idx in bar_idx[batch_idx, head_idx, row_idx, bar_l:bar_r]: + mask[batch_idx, head_idx, row_u:row_d, col_idx] = True + # Causal Mask + arange = torch.arange(0, num_tokens_pad, dtype=torch.int32, device=q.device) + mask.masked_fill_(arange[None, None, :, None] < arange[None, None, None, :], False) + return mask[:, :, :num_tokens, :num_tokens] + + +def _torch_sparse_attn_func( + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v_size: List[int], # [num_qo_heads] + s_size: List[int], # [num_qo_heads] + dropout_p: int = 0.0, + softmax_scale: float = None, + granularity: int = 128, + causal: bool = True, + 
window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window + alibi_slopes: Tuple[float, float] = None, + deterministic: bool = False, + return_attn_probs: bool = False, +): + assert dropout_p == 0 + assert causal + assert window_size == (-1, -1) + assert alibi_slopes is None + assert not deterministic + batch_size, num_tokens, num_qo_heads, head_dim = q.shape + num_kv_heads = k.shape[2] + group_size = num_qo_heads // num_kv_heads + softmax_scale = head_dim ** -0.5 if softmax_scale is None else softmax_scale + mask = _build_mask_local(q, k, v_size, s_size, num_tokens, granularity) + k = k.repeat_interleave(group_size, dim=2) + v = v.repeat_interleave(group_size, dim=2) + p = torch.einsum('bmhd, bnhd -> bhmn', q * softmax_scale, k) + p = torch.where(mask, p, -torch.inf).to(torch.float32) + m = torch.max(p, dim=-1, keepdim=True).values.to(torch.float32) + p = torch.exp(p - m) + l = torch.sum(p, dim=-1, keepdim=True) + p = (p / l).to(q.dtype) + o = torch.einsum('bhmn, bnhd -> bmhd', p, v) + o = o.reshape((batch_size, num_tokens, num_qo_heads, head_dim)) + if return_attn_probs: + lse = m + l.log() + return o, lse.squeeze(-1), None + return o + + +def _torch_sparse_attn_qkvpacked_func( + qkv: torch.Tensor, # [batch_size, num_tokens, 3, num_heads, head_dim] + v_size: List[int], # [num_heads] + s_size: List[int], # [num_heads] + dropout_p: int = 0.0, + softmax_scale: float = None, + granularity: int = 128, + causal: bool = True, + window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window + alibi_slopes: Tuple[float, float] = None, + deterministic: bool = False, + return_attn_probs: bool = False, +): + return _torch_sparse_attn_func( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + v_size, + s_size, + dropout_p, + softmax_scale, + granularity, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + ) + + +def _torch_sparse_attn_kvpacked_func( + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + kv: torch.Tensor, # [batch_size, num_tokens, 2, num_kv_heads, head_dim] + v_size: List[int], # [num_qo_heads] + s_size: List[int], # [num_qo_heads] + dropout_p: int = 0.0, + softmax_scale: float = None, + granularity: int = 128, + causal: bool = True, + window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window + alibi_slopes: Tuple[float, float] = None, + deterministic: bool = False, + return_attn_probs: bool = False, +): + return _torch_sparse_attn_func( + q, + kv[:, :, 0], + kv[:, :, 1], + v_size, + s_size, + dropout_p, + softmax_scale, + granularity, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + ) diff --git a/minference/ops/minference_attn_triton.py b/minference/ops/minference_attn_triton.py new file mode 100644 index 0000000..0e76ba3 --- /dev/null +++ b/minference/ops/minference_attn_triton.py @@ -0,0 +1,1230 @@ +import os +import sys +import math +import ctypes + +import torch +import torch.nn.functional as F +import torch.distributed as dist + +import triton +import triton.language as tl + +from typing import List, Tuple +from .utils import ( + build_index_local, _build_mask_local, convert_blockmask, + calc_index_local, convert_indices +) + + +@triton.jit +def _triton_block_attn_fwd_kernel( + Q, K, V, sm_scale, + block_cnt, # [BATCH, N_Q_HEADS, NUM_ROWS] + block_idx, # [BATCH, N_Q_HEADS, NUM_ROWS, NUM_COLS] + Out, # [BATCH, N_Q_HEADS, N_CTX, D_HEAD] + softmax_lse, # [BATCH, N_Q_HEADS, N_CTX] + stride_qz, stride_qh, stride_qm, stride_qd, + stride_kz, stride_kh, stride_kn, 
stride_kd, + stride_vz, stride_vh, stride_vn, stride_vd, + stride_oz, stride_oh, stride_om, stride_od, + stride_2cz, stride_2ch, stride_2cm, + stride_2iz, stride_2ih, stride_2im, stride_2in, + stride_sz, stride_sh, stride_sm, + num_qo_heads, num_kv_heads, num_tokens, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + CAUSAL: tl.constexpr, +): + start_m = tl.program_id(0) + qo_head_idx = tl.program_id(1) + batch_idx = tl.program_id(2) + kv_head_idx = qo_head_idx // (num_qo_heads // num_kv_heads) + + if start_m * BLOCK_M >= num_tokens: + return + + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + + qo_offset = batch_idx * stride_qz + qo_head_idx * stride_qh + kv_offset = batch_idx * stride_kz + kv_head_idx * stride_kh + + q_ptrs = Q + qo_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd + k_ptrs = K + kv_offset + offs_d[:, None] * stride_kd + v_ptrs = V + kv_offset + offs_d[None, :] * stride_vd + o_ptrs = Out + qo_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od + lse_ptrs = softmax_lse + batch_idx * stride_sz + qo_head_idx * stride_sh + offs_m * stride_sm + + block_num = tl.load(block_cnt + batch_idx * stride_2cz + qo_head_idx * stride_2ch + start_m * stride_2cm) + if block_num <= 0: + return + + block_idx_ptr = block_idx + batch_idx * stride_2iz + qo_head_idx * stride_2ih + start_m * stride_2im + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + # 1/ln2 = lne/ln2 = log2(e) => 2^(x / ln2) = 2^(x * log2(e)) = (2^(log2(e)))^x = e^x + qk_scale = sm_scale * 1.44269504 + + # load q: it will stay in SRAM throughout + q = tl.load(q_ptrs) + q = (q * qk_scale).to(Q.type.element_ty) + + if CAUSAL: + block_split = block_num - 2 + else: + block_split = block_num + + # Block + for start_n in range(0, block_split): + block_off = tl.load(block_idx_ptr + start_n * stride_2in) * BLOCK_N + + # -- load k, v -- + k = tl.load(k_ptrs + block_off * stride_kn + offs_n[None, :] * stride_kn) + v = tl.load(v_ptrs + block_off * stride_vn + offs_n[:, None] * stride_vn) + + # -- compute qk -- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = qk + tl.dot(q, k) + + # -- compute scaling constant -- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.math.exp2(m_i - m_i_new) + p = tl.math.exp2(qk - m_i_new[:, None]) + + # -- scale and update acc -- + acc_scale = l_i * 0 + alpha # workaround some compiler bug + acc = acc * acc_scale[:, None] + acc = acc + tl.dot(p.to(Q.type.element_ty), v) + + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + + # Block (Causal) + for start_n in range(max(block_split, 0), block_num): + block_off = tl.load(block_idx_ptr + start_n * stride_2in) * BLOCK_N + + # -- load k, v -- + k = tl.load(k_ptrs + block_off * stride_kn + offs_n[None, :] * stride_kn) + v = tl.load(v_ptrs + block_off * stride_vn + offs_n[:, None] * stride_vn) + + # -- compute qk -- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.where(offs_m[:, None] >= offs_n[None, :] + block_off, qk, float("-inf")) + qk = qk + tl.dot(q, k) + + # -- compute scaling constant -- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + alpha = 
tl.math.exp2(m_i - m_i_new) + p = tl.math.exp2(qk - m_i_new[:, None]) + + # -- scale and update acc -- + acc_scale = l_i * 0 + alpha # workaround some compiler bug + acc = acc * acc_scale[:, None] + acc = acc + tl.dot(p.to(Q.type.element_ty), v) + + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + + # write back O and LSE + acc_1 = acc / l_i[:, None] + s_1 = m_i * 0.69314718 + tl.math.log(l_i) + acc_0 = tl.load(o_ptrs).to(tl.float32) + s_0 = tl.load(lse_ptrs) + + overflow_mask = (s_0 - s_1) < 88.0 + + theta = tl.math.exp(s_0 - s_1) + alpha_0 = 1 / (1 + 1 / theta) + alpha_1 = 1 / (1 + theta) + acc = alpha_0[:, None] * acc_0 + alpha_1[:, None] * acc_1 + s = s_1 - tl.math.log(alpha_1) + + tl.store(o_ptrs, acc.to(Out.type.element_ty)) + tl.store(lse_ptrs, s, mask=overflow_mask) + +def triton_block_attn_fwd( + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + softmax_scale: float, + block_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, num_blocks] + block_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks] + granularity: int, + step: int = 0, + causal: bool = True, +): + batch_size, num_tokens, num_qo_heads, head_dim = q.shape + num_kv_heads = k.shape[2] + num_blocks = block_idx.shape[2] + + o = torch.zeros_like(q) + lse = torch.zeros((batch_size, num_qo_heads, num_tokens), dtype=torch.float32, device=q.device) - torch.inf + + _triton_block_attn_fwd_kernel[(num_blocks, num_qo_heads, batch_size)]( + q, k, v, softmax_scale, + block_cnt, block_idx, + o, lse, + q.stride(0), q.stride(2), q.stride(1), q.stride(3), + k.stride(0), k.stride(2), k.stride(1), k.stride(3), + v.stride(0), v.stride(2), v.stride(1), v.stride(3), + o.stride(0), o.stride(2), o.stride(1), o.stride(3), + block_cnt.stride(0), block_cnt.stride(1), block_cnt.stride(2), + block_idx.stride(0), block_idx.stride(1), block_idx.stride(2), block_idx.stride(3), + lse.stride(0), lse.stride(1), lse.stride(2), + num_qo_heads, num_kv_heads, num_tokens, + BLOCK_M=granularity, BLOCK_N=64, BLOCK_DMODEL=head_dim, CAUSAL=causal, + num_warps=4, num_stages=2, + ) + return o, lse + +@triton.jit +def _triton_block_attn_bwd_kernel( + Q, K, V, O, + DQ, DK, DV, DO, + sm_scale, + block_cnt, # [BATCH, N_Q_HEADS, NUM_ROWS] + block_idx, # [BATCH, N_Q_HEADS, NUM_ROWS, NUM_COLS] + softmax_lse, # [BATCH, N_HEADS, N_CTX] + stride_qz, stride_qh, stride_qm, stride_qd, + stride_kz, stride_kh, stride_kn, stride_kd, + stride_vz, stride_vh, stride_vn, stride_vd, + stride_oz, stride_oh, stride_om, stride_od, + stride_dqz, stride_dqh, stride_dqm, stride_dqd, + stride_dkz, stride_dkh, stride_dkn, stride_dkd, + stride_dvz, stride_dvh, stride_dvn, stride_dvd, + stride_doz, stride_doh, stride_dom, stride_dod, + stride_2cz, stride_2ch, stride_2cm, + stride_2iz, stride_2ih, stride_2im, stride_2in, + stride_sz, stride_sh, stride_sm, + num_qo_heads, num_kv_heads, num_tokens, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + CAUSAL: tl.constexpr, +): + start_m = tl.program_id(0) + qo_head_idx = tl.program_id(1) + batch_idx = tl.program_id(2) + kv_head_idx = qo_head_idx // (num_qo_heads // num_kv_heads) + + if start_m * BLOCK_M >= num_tokens: + return + + qk_scale = sm_scale * 1.44269504 + + # offset pointers for batch/head + Q += batch_idx * stride_qz + qo_head_idx * stride_qh + K += batch_idx * stride_kz + kv_head_idx * stride_kh + V += 
batch_idx * stride_vz + kv_head_idx * stride_vh + O += batch_idx * stride_oz + qo_head_idx * stride_oh + DQ += batch_idx * stride_dqz + qo_head_idx * stride_dqh + DK += batch_idx * stride_dkz + kv_head_idx * stride_dkh + DV += batch_idx * stride_dvz + kv_head_idx * stride_dvh + DO += batch_idx * stride_doz + qo_head_idx * stride_doh + + # loop over rows + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + + # initialize pointers to value-like data + q_ptrs = Q + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd + k_ptrs = K + offs_d[None, :] * stride_kd + v_ptrs = V + offs_d[None, :] * stride_vd + o_ptrs = O + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od + dq_ptrs = DQ + offs_m[:, None] * stride_dqm + offs_d[None, :] * stride_dqd + dk_ptrs = DK + offs_d[None, :] * stride_dkd + dv_ptrs = DV + offs_d[None, :] * stride_dvd + do_ptrs = DO + offs_m[:, None] * stride_dom + offs_d[None, :] * stride_dod + l_ptrs = softmax_lse + batch_idx * stride_sz + qo_head_idx * stride_sh + offs_m * stride_sm + + block_num = tl.load(block_cnt + batch_idx * stride_2cz + qo_head_idx * stride_2ch + start_m * stride_2cm) + block_idx_ptr = block_idx + batch_idx * stride_2iz + qo_head_idx * stride_2ih + start_m * stride_2im + + o = tl.load(o_ptrs).to(tl.float32) + do = tl.load(do_ptrs).to(tl.float32) + d_i = tl.sum(o * do, axis=1) + + q = tl.load(q_ptrs) + do = do.to(DO.dtype.element_ty) + l_i = tl.load(l_ptrs) * 1.44269504 + + dq = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + if CAUSAL: + block_split = block_num - 2 + else: + block_split = block_num + + # Block + for start_n in range(0, block_split): + block_off = tl.load(block_idx_ptr + start_n * stride_2in) * BLOCK_N + + # -- load k, v -- + k = tl.load(k_ptrs + block_off * stride_kn + offs_n[:, None] * stride_kn) + v = tl.load(v_ptrs + block_off * stride_vn + offs_n[:, None] * stride_vn) + + # Computer qk + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = qk + tl.dot(q, tl.trans(k)) + qk = qk * qk_scale + p = tl.math.exp2(qk - l_i[:, None]) + + # compute dv + dv_vals = tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do).to(tl.float32) + tl.atomic_add(dv_ptrs + block_off * stride_dvn + offs_n[:, None] * stride_dvn, dv_vals, sem="relaxed") + + # compute dp = dot(v, do) + dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - d_i[:, None] + dp = dp + tl.dot(do, tl.trans(v)) + + # compute ds = p * (dp - delta[:, None]) + ds = p * dp * sm_scale + + # compute dk = dot(ds.T, q) + dk_vals = tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q).to(tl.float32) + tl.atomic_add(dk_ptrs + block_off * stride_dkn + offs_n[:, None] * stride_dkn, dk_vals, sem="relaxed") + + # compute dq + dq = dq + tl.dot(ds.to(Q.dtype.element_ty), k) + + # Block (Causal) + for start_n in range(max(block_split, 0), block_num): + block_off = tl.load(block_idx_ptr + start_n * stride_2in) * BLOCK_N + + # -- load k, v -- + k = tl.load(k_ptrs + block_off * stride_kn + offs_n[:, None] * stride_kn) + v = tl.load(v_ptrs + block_off * stride_vn + offs_n[:, None] * stride_vn) + + # Computer qk + qk = tl.where(offs_m[:, None] >= offs_n[None, :] + block_off, float(0.), float("-inf")) + qk = qk + tl.dot(q, tl.trans(k)) + qk = qk * qk_scale + p = tl.math.exp2(qk - l_i[:, None]) + + # compute dv + dv_vals = tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do).to(tl.float32) + tl.atomic_add(dv_ptrs + block_off * stride_dvn + offs_n[:, None] * stride_dvn, dv_vals, sem="relaxed") + + # compute dp = dot(v, do) + dp = 
tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - d_i[:, None] + dp = dp + tl.dot(do, tl.trans(v)) + + # compute ds = p * (dp - delta[:, None]) + ds = p * dp * sm_scale + + # compute dk = dot(ds.T, q) + dk_vals = tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q).to(tl.float32) + tl.atomic_add(dk_ptrs + block_off * stride_dkn + offs_n[:, None] * stride_dkn, dk_vals, sem="relaxed") + + # compute dq + dq = dq + tl.dot(ds.to(Q.dtype.element_ty), k) + + dq_old = tl.load(dq_ptrs).to(tl.float32) + tl.store(dq_ptrs, (dq_old + dq).to(DQ.dtype.element_ty)) + + +def triton_block_attn_bwd( + grad: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + o: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + softmax_lse: torch.Tensor, # [batch_size, num_qo_heads, num_tokens] + softmax_scale: float, + block_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, num_blocks] + block_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks] + granularity: int, + deterministic: bool, + step: int = 0, + causal: bool = True, +): + assert not deterministic + batch_size, num_tokens, num_qo_heads, head_dim = q.shape + num_kv_heads = k.shape[2] + num_blocks = block_idx.shape[2] + + dq = torch.zeros_like(q) + dk = torch.zeros_like(k, dtype=torch.float32) + dv = torch.zeros_like(v, dtype=torch.float32) + + _triton_block_attn_bwd_kernel[(num_blocks, num_qo_heads, batch_size)]( + q, k, v, o, dq, dk, dv, grad, softmax_scale, + block_cnt, block_idx, softmax_lse, + q.stride(0), q.stride(2), q.stride(1), q.stride(3), + k.stride(0), k.stride(2), k.stride(1), k.stride(3), + v.stride(0), v.stride(2), v.stride(1), v.stride(3), + o.stride(0), o.stride(2), o.stride(1), o.stride(3), + dq.stride(0), dq.stride(2), dq.stride(1), dq.stride(3), + dk.stride(0), dk.stride(2), dk.stride(1), dk.stride(3), + dv.stride(0), dv.stride(2), dv.stride(1), dv.stride(3), + grad.stride(0), grad.stride(2), grad.stride(1), grad.stride(3), + block_cnt.stride(0), block_cnt.stride(1), block_cnt.stride(2), + block_idx.stride(0), block_idx.stride(1), block_idx.stride(2), block_idx.stride(3), + softmax_lse.stride(0), softmax_lse.stride(1), softmax_lse.stride(2), + num_qo_heads, num_kv_heads, num_tokens, + BLOCK_M=granularity, BLOCK_N=64, BLOCK_DMODEL=head_dim, CAUSAL=causal, + num_warps=4, num_stages=2, + ) + return dq, dk.to(dq.dtype), dv.to(dq.dtype) + + +@triton.jit +def _triton_block_bar_attn_fwd_kernel( + Q, K, V, sm_scale, + bar_cnt, # [BATCH, N_Q_HEADS, NUM_ROWS, WORLD_SIZE + 1] + bar_idx, # [BATCH, N_Q_HEADS, NUM_ROWS, NNZ_V] + block_cnt, # [BATCH, N_Q_HEADS, NUM_ROWS] + block_idx, # [BATCH, N_Q_HEADS, NUM_ROWS, NUM_COLS] + Out, # [BATCH, N_Q_HEADS, N_CTX, D_HEAD] + softmax_lse, # [BATCH, N_Q_HEADS, N_CTX] + stride_qz, stride_qh, stride_qm, stride_qd, + stride_kz, stride_kh, stride_kn, stride_kd, + stride_vz, stride_vh, stride_vn, stride_vd, + stride_oz, stride_oh, stride_om, stride_od, + stride_1cz, stride_1ch, stride_1cm, stride_1cr, + stride_1iz, stride_1ih, stride_1im, stride_1in, + stride_2cz, stride_2ch, stride_2cm, + stride_2iz, stride_2ih, stride_2im, stride_2in, + stride_sz, stride_sh, stride_sm, + step, num_qo_heads, num_kv_heads, num_tokens, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + CAUSAL: tl.constexpr, +): + start_m = tl.program_id(0) + qo_head_idx = 
tl.program_id(1) + batch_idx = tl.program_id(2) + kv_head_idx = qo_head_idx // (num_qo_heads // num_kv_heads) + + if start_m * BLOCK_M >= num_tokens: + return + + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + + qo_offset = batch_idx * stride_qz + qo_head_idx * stride_qh + kv_offset = batch_idx * stride_kz + kv_head_idx * stride_kh + + q_ptrs = Q + qo_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd + k_ptrs = K + kv_offset + offs_d[:, None] * stride_kd + v_ptrs = V + kv_offset + offs_d[None, :] * stride_vd + o_ptrs = Out + qo_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od + + lse_ptrs = softmax_lse + batch_idx * stride_sz + qo_head_idx * stride_sh + offs_m * stride_sm + + bar_l = tl.load(bar_cnt + batch_idx * stride_1cz + qo_head_idx * stride_1ch + start_m * stride_1cm + step * stride_1cr) + bar_r = tl.load(bar_cnt + batch_idx * stride_1cz + qo_head_idx * stride_1ch + start_m * stride_1cm + (step + 1) * stride_1cr) + bar_idx_ptr = bar_idx + batch_idx * stride_1iz + qo_head_idx * stride_1ih + start_m * stride_1im + + block_num = tl.load(block_cnt + batch_idx * stride_2cz + qo_head_idx * stride_2ch + start_m * stride_2cm) + block_idx_ptr = block_idx + batch_idx * stride_2iz + qo_head_idx * stride_2ih + start_m * stride_2im + + if (bar_l >= bar_r) and (block_num <= 0): + return + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + # 1/ln2 = lne/ln2 = log2(e) => 2^(x / ln2) = 2^(x * log2(e)) = (2^(log2(e)))^x = e^x + qk_scale = sm_scale * 1.44269504 + + # load q: it will stay in SRAM throughout + q = tl.load(q_ptrs) + q = (q * qk_scale).to(Q.type.element_ty) + + if CAUSAL: + block_split = block_num - 2 + else: + block_split = block_num + + # Block + for start_n in range(0, block_split): + block_off = tl.load(block_idx_ptr + start_n * stride_2in) * BLOCK_N + + # -- load k, v -- + k = tl.load(k_ptrs + block_off * stride_kn + offs_n[None, :] * stride_kn) + v = tl.load(v_ptrs + block_off * stride_vn + offs_n[:, None] * stride_vn) + + # -- compute qk -- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = qk + tl.dot(q, k) + + # -- compute scaling constant -- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.math.exp2(m_i - m_i_new) + p = tl.math.exp2(qk - m_i_new[:, None]) + + # -- scale and update acc -- + acc_scale = l_i * 0 + alpha # workaround some compiler bug + acc = acc * acc_scale[:, None] + acc = acc + tl.dot(p.to(Q.type.element_ty), v) + + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + + # Block (Causal) + for start_n in range(max(block_split, 0), block_num): + block_off = tl.load(block_idx_ptr + start_n * stride_2in) * BLOCK_N + + # -- load k, v -- + k = tl.load(k_ptrs + block_off * stride_kn + offs_n[None, :] * stride_kn) + v = tl.load(v_ptrs + block_off * stride_vn + offs_n[:, None] * stride_vn) + + # -- compute qk -- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.where(offs_m[:, None] >= offs_n[None, :] + block_off, qk, float("-inf")) + qk = qk + tl.dot(q, k) + + # -- compute scaling constant -- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.math.exp2(m_i - m_i_new) + p = 
tl.math.exp2(qk - m_i_new[:, None]) + + # -- scale and update acc -- + acc_scale = l_i * 0 + alpha # workaround some compiler bug + acc = acc * acc_scale[:, None] + acc = acc + tl.dot(p.to(Q.type.element_ty), v) + + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + + # Bar + for start_n in range(bar_l, bar_r, BLOCK_N): + n_mask = start_n + offs_n < bar_r + cols = tl.load(bar_idx_ptr + (start_n + offs_n) * stride_1in, mask=n_mask, other=0) + + # -- load k, v -- + k = tl.load(k_ptrs + cols[None, :] * stride_kn) + v = tl.load(v_ptrs + cols[:, None] * stride_vn) + + # -- compute qk -- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.where(n_mask[None, :], qk, float("-inf")) + qk = qk + tl.dot(q, k) + + # -- compute scaling constant -- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.math.exp2(m_i - m_i_new) + p = tl.math.exp2(qk - m_i_new[:, None]) + + # -- scale and update acc -- + acc_scale = l_i * 0 + alpha # workaround some compiler bug + acc = acc * acc_scale[:, None] + acc = acc + tl.dot(p.to(Q.type.element_ty), v) + + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + + # write back O and LSE + acc_1 = acc / l_i[:, None] + s_1 = m_i * 0.69314718 + tl.math.log(l_i) + acc_0 = tl.load(o_ptrs).to(tl.float32) + s_0 = tl.load(lse_ptrs) + + overflow_mask = (s_0 - s_1) < 88.0 + + theta = tl.math.exp(s_0 - s_1) + alpha_0 = 1 / (1 + 1 / theta) + alpha_1 = 1 / (1 + theta) + acc = alpha_0[:, None] * acc_0 + alpha_1[:, None] * acc_1 + s = s_1 - tl.math.log(alpha_1) + + tl.store(o_ptrs, acc.to(Out.type.element_ty)) + tl.store(lse_ptrs, s, mask=overflow_mask) + + +def block_bar_attn_fwd( + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + o: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + lse: torch.Tensor, # [batch_size, num_qo_heads, num_tokens] + softmax_scale: float, + bar_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, max_v_size] + bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] + block_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, num_blocks] + block_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks] + granularity: int, + step: int = 0, + causal: bool = True, +): + batch_size, num_tokens, num_qo_heads, head_dim = q.shape + num_kv_heads = k.shape[2] + num_blocks = bar_idx.shape[2] + if o is None: + o = torch.zeros_like(q) + lse = torch.zeros((batch_size, num_qo_heads, num_tokens), dtype=torch.float32, device=q.device) - torch.inf + _triton_block_bar_attn_fwd_kernel[(num_blocks, num_qo_heads, batch_size)]( + q, k, v, softmax_scale, bar_cnt, bar_idx, block_cnt, block_idx, o, lse, + q.stride(0), q.stride(2), q.stride(1), q.stride(3), + k.stride(0), k.stride(2), k.stride(1), k.stride(3), + v.stride(0), v.stride(2), v.stride(1), v.stride(3), + o.stride(0), o.stride(2), o.stride(1), o.stride(3), + bar_cnt.stride(0), bar_cnt.stride(1), bar_cnt.stride(2), bar_cnt.stride(3), + bar_idx.stride(0), bar_idx.stride(1), bar_idx.stride(2), bar_idx.stride(3), + block_cnt.stride(0), block_cnt.stride(1), block_cnt.stride(2), + block_idx.stride(0), block_idx.stride(1), block_idx.stride(2), block_idx.stride(3), + lse.stride(0), lse.stride(1), lse.stride(2), + step, num_qo_heads, num_kv_heads, num_tokens, + BLOCK_M=granularity, BLOCK_N=64, BLOCK_DMODEL=head_dim, CAUSAL=causal, + num_warps=4, 
num_stages=2, + ) + return o, lse + + +@triton.jit +def _triton_block_bar_attn_bwd_kernel( + Q, K, V, O, + DQ, DK, DV, DO, + sm_scale, + bar_cnt, # [BATCH, N_Q_HEADS, NUM_ROWS, WORLD_SIZE + 1] + bar_idx, # [BATCH, N_Q_HEADS, NUM_ROWS, NNZ_V] + block_cnt, # [BATCH, N_Q_HEADS, NUM_ROWS] + block_idx, # [BATCH, N_Q_HEADS, NUM_ROWS, NUM_COLS] + softmax_lse, # [BATCH, N_HEADS, N_CTX] + stride_qz, stride_qh, stride_qm, stride_qd, + stride_kz, stride_kh, stride_kn, stride_kd, + stride_vz, stride_vh, stride_vn, stride_vd, + stride_oz, stride_oh, stride_om, stride_od, + stride_dqz, stride_dqh, stride_dqm, stride_dqd, + stride_dkz, stride_dkh, stride_dkn, stride_dkd, + stride_dvz, stride_dvh, stride_dvn, stride_dvd, + stride_doz, stride_doh, stride_dom, stride_dod, + stride_1cz, stride_1ch, stride_1cm, stride_1cr, + stride_1iz, stride_1ih, stride_1im, stride_1in, + stride_2cz, stride_2ch, stride_2cm, + stride_2iz, stride_2ih, stride_2im, stride_2in, + stride_sz, stride_sh, stride_sm, + step, num_qo_heads, num_kv_heads, num_tokens, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + CAUSAL: tl.constexpr, +): + start_m = tl.program_id(0) + qo_head_idx = tl.program_id(1) + batch_idx = tl.program_id(2) + kv_head_idx = qo_head_idx // (num_qo_heads // num_kv_heads) + + if start_m * BLOCK_M >= num_tokens: + return + + qk_scale = sm_scale * 1.44269504 + + # offset pointers for batch/head + Q += batch_idx * stride_qz + qo_head_idx * stride_qh + K += batch_idx * stride_kz + kv_head_idx * stride_kh + V += batch_idx * stride_vz + kv_head_idx * stride_vh + O += batch_idx * stride_oz + qo_head_idx * stride_oh + DQ += batch_idx * stride_dqz + qo_head_idx * stride_dqh + DK += batch_idx * stride_dkz + kv_head_idx * stride_dkh + DV += batch_idx * stride_dvz + kv_head_idx * stride_dvh + DO += batch_idx * stride_doz + qo_head_idx * stride_doh + + # loop over rows + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + + # initialize pointers to value-like data + q_ptrs = Q + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd + k_ptrs = K + offs_d[None, :] * stride_kd + v_ptrs = V + offs_d[None, :] * stride_vd + o_ptrs = O + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od + dq_ptrs = DQ + offs_m[:, None] * stride_dqm + offs_d[None, :] * stride_dqd + dk_ptrs = DK + offs_d[None, :] * stride_dkd + dv_ptrs = DV + offs_d[None, :] * stride_dvd + do_ptrs = DO + offs_m[:, None] * stride_dom + offs_d[None, :] * stride_dod + + l_ptrs = softmax_lse + batch_idx * stride_sz + qo_head_idx * stride_sh + offs_m * stride_sm + + bar_l = tl.load(bar_cnt + batch_idx * stride_1cz + qo_head_idx * stride_1ch + start_m * stride_1cm + step * stride_1cr) + bar_r = tl.load(bar_cnt + batch_idx * stride_1cz + qo_head_idx * stride_1ch + start_m * stride_1cm + (step + 1) * stride_1cr) + bar_idx_ptr = bar_idx + batch_idx * stride_1iz + qo_head_idx * stride_1ih + start_m * stride_1im + + block_num = tl.load(block_cnt + batch_idx * stride_2cz + qo_head_idx * stride_2ch + start_m * stride_2cm) + block_idx_ptr = block_idx + batch_idx * stride_2iz + qo_head_idx * stride_2ih + start_m * stride_2im + + if (bar_l >= bar_r) and (block_num <= 0): + return + + o = tl.load(o_ptrs).to(tl.float32) + do = tl.load(do_ptrs).to(tl.float32) + d_i = tl.sum(o * do, axis=1) + + q = tl.load(q_ptrs) + do = do.to(DO.dtype.element_ty) + l_i = tl.load(l_ptrs) * 1.44269504 + + dq = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + if CAUSAL: + block_split = 
block_num - 2 + else: + block_split = block_num + + # Block + for start_n in range(0, block_split): + block_off = tl.load(block_idx_ptr + start_n * stride_2in) * BLOCK_N + + # -- load k, v -- + k = tl.load(k_ptrs + block_off * stride_kn + offs_n[:, None] * stride_kn) + v = tl.load(v_ptrs + block_off * stride_vn + offs_n[:, None] * stride_vn) + + # Computer qk + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = qk + tl.dot(q, tl.trans(k)) + qk = qk * qk_scale + p = tl.math.exp2(qk - l_i[:, None]) + + # compute dv + dv_vals = tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do).to(tl.float32) + tl.atomic_add(dv_ptrs + block_off * stride_dvn + offs_n[:, None] * stride_dvn, dv_vals, sem="relaxed") + + # compute dp = dot(v, do) + dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - d_i[:, None] + dp = dp + tl.dot(do, tl.trans(v)) + + # compute ds = p * (dp - delta[:, None]) + ds = p * dp * sm_scale + + # compute dk = dot(ds.T, q) + dk_vals = tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q).to(tl.float32) + tl.atomic_add(dk_ptrs + block_off * stride_dkn + offs_n[:, None] * stride_dkn, dk_vals, sem="relaxed") + + # compute dq + dq = dq + tl.dot(ds.to(Q.dtype.element_ty), k) + + # Block (Causal) + for start_n in range(max(block_split, 0), block_num): + block_off = tl.load(block_idx_ptr + start_n * stride_2in) * BLOCK_N + + # -- load k, v -- + k = tl.load(k_ptrs + block_off * stride_kn + offs_n[:, None] * stride_kn) + v = tl.load(v_ptrs + block_off * stride_vn + offs_n[:, None] * stride_vn) + + # Computer qk + qk = tl.where(offs_m[:, None] >= offs_n[None, :] + block_off, float(0.), float("-inf")) + qk = qk + tl.dot(q, tl.trans(k)) + qk = qk * qk_scale + p = tl.math.exp2(qk - l_i[:, None]) + + # compute dv + dv_vals = tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do).to(tl.float32) + tl.atomic_add(dv_ptrs + block_off * stride_dvn + offs_n[:, None] * stride_dvn, dv_vals, sem="relaxed") + + # compute dp = dot(v, do) + dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - d_i[:, None] + dp = dp + tl.dot(do, tl.trans(v)) + + # compute ds = p * (dp - delta[:, None]) + ds = p * dp * sm_scale + + # compute dk = dot(ds.T, q) + dk_vals = tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q).to(tl.float32) + tl.atomic_add(dk_ptrs + block_off * stride_dkn + offs_n[:, None] * stride_dkn, dk_vals, sem="relaxed") + + # compute dq + dq = dq + tl.dot(ds.to(Q.dtype.element_ty), k) + + # Bar + for start_n in range(bar_l, bar_r, BLOCK_N): + n_mask = start_n + offs_n < bar_r + cols = tl.load(bar_idx_ptr + (start_n + offs_n) * stride_1in, mask=n_mask, other=0) + + # -- load k, v -- + k = tl.load(k_ptrs + cols[:, None] * stride_kn) + v = tl.load(v_ptrs + cols[:, None] * stride_vn) + + # Computer qk + qk = tl.where(n_mask[None, :], float(0.), float("-inf")) + qk = qk + tl.dot(q, tl.trans(k)) + qk = qk * qk_scale + p = tl.math.exp2(qk - l_i[:, None]) + + # compute dv + dv_vals = tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do).to(tl.float32) + tl.atomic_add(dv_ptrs + cols[:, None] * stride_dvn, dv_vals, mask=n_mask[:, None], sem="relaxed") + + # compute dp = dot(v, do) + dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - d_i[:, None] + dp = dp + tl.dot(do, tl.trans(v)) + + # compute ds = p * (dp - delta[:, None]) + ds = p * dp * sm_scale + + # compute dk = dot(ds.T, q) + dk_vals = tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q).to(tl.float32) + tl.atomic_add(dk_ptrs + cols[:, None] * stride_dkn, dk_vals, mask=n_mask[:, None], sem="relaxed") + + # compute dq + dq = dq + tl.dot(ds.to(Q.dtype.element_ty), k) + + dq_old = 
tl.load(dq_ptrs).to(tl.float32) + tl.store(dq_ptrs, (dq_old + dq).to(DQ.dtype.element_ty)) + + +def block_bar_attn_bwd( + grad: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + o: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + dq: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + dk: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + dv: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + softmax_lse: torch.Tensor, # [batch_size, num_qo_heads, num_tokens] + softmax_scale: float, + bar_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, max_v_size] + bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] + block_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, num_blocks] + block_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks] + granularity: int, + deterministic: bool, + step: int = 0, + causal: bool = True, +): + assert not deterministic + batch_size, num_tokens, num_qo_heads, head_dim = q.shape + num_kv_heads = k.shape[2] + num_blocks = bar_idx.shape[2] + dq = torch.zeros_like(q) if dq is None else dq + dk = torch.zeros_like(k, dtype=torch.float32) if dk is None else dk.to(torch.float32) + dv = torch.zeros_like(v, dtype=torch.float32) if dv is None else dv.to(torch.float32) + _triton_block_bar_attn_bwd_kernel[(num_blocks, num_qo_heads, batch_size)]( + q, k, v, o, dq, dk, dv, grad, softmax_scale, + bar_cnt, bar_idx, block_cnt, block_idx, softmax_lse, + q.stride(0), q.stride(2), q.stride(1), q.stride(3), + k.stride(0), k.stride(2), k.stride(1), k.stride(3), + v.stride(0), v.stride(2), v.stride(1), v.stride(3), + o.stride(0), o.stride(2), o.stride(1), o.stride(3), + dq.stride(0), dq.stride(2), dq.stride(1), dq.stride(3), + dk.stride(0), dk.stride(2), dk.stride(1), dk.stride(3), + dv.stride(0), dv.stride(2), dv.stride(1), dv.stride(3), + grad.stride(0), grad.stride(2), grad.stride(1), grad.stride(3), + bar_cnt.stride(0), bar_cnt.stride(1), bar_cnt.stride(2), bar_cnt.stride(3), + bar_idx.stride(0), bar_idx.stride(1), bar_idx.stride(2), bar_idx.stride(3), + block_cnt.stride(0), block_cnt.stride(1), block_cnt.stride(2), + block_idx.stride(0), block_idx.stride(1), block_idx.stride(2), block_idx.stride(3), + softmax_lse.stride(0), softmax_lse.stride(1), softmax_lse.stride(2), + step, num_qo_heads, num_kv_heads, num_tokens, + BLOCK_M=granularity, BLOCK_N=64, BLOCK_DMODEL=head_dim, CAUSAL=causal, + num_warps=4, num_stages=2, + ) + return dq, dk.to(dq.dtype), dv.to(dq.dtype) + + +def build_index_local( + q: torch.Tensor, + k: torch.Tensor, + v_size: List[int], + s_size: List[int], + num_tokens: int, + granularity: int, + world_size: int = 1, + rank: int = 0, +): + if type(v_size) is list: + assert len(v_size) == q.shape[2] + assert len(s_size) == q.shape[2] + v_idx, s_idx = calc_index_local(q, k, v_size, s_size, last_q_size=64) + else: + v_idx, s_idx = v_size, s_size + num_blocks = triton.cdiv(num_tokens, granularity) + block_mask, bar_idx, bar_cnt = convert_indices(v_idx, s_idx, world_size, rank, num_blocks, granularity) + block_mask = block_mask[rank] + return block_mask, bar_idx, bar_cnt + + +class MInferenceAttnTritonFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + v_size, + s_size, + softmax_scale, + granularity, 
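+        # (added note) return_softmax: also return (out, lse, None) instead of just out;
+        # deterministic: saved for the backward pass, which currently requires it to be False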
+ return_softmax, + deterministic, + ): + batch_size, num_tokens, num_qo_heads, head_dim = q.shape + if softmax_scale is None: + softmax_scale = head_dim ** (-0.5) + + block_mask, bar_idx, bar_cnt = build_index_local(q, k, v_size, s_size, num_tokens, granularity) + block_idx, block_cnt = convert_blockmask(block_mask, block_size_M=granularity, block_size_N=64) + + out, softmax_lse = block_bar_attn_fwd( + q, k, v, None, None, softmax_scale, + bar_idx, bar_cnt, block_idx, block_cnt, + granularity=granularity, + step=0, + ) + + ctx.save_for_backward(q, k, v, out, softmax_lse, block_idx, block_cnt, bar_idx, bar_cnt) + ctx.granularity = granularity + ctx.deterministic = deterministic + ctx.softmax_scale = softmax_scale + return (out, softmax_lse, None) if return_softmax else out + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse, block_idx, block_cnt, bar_idx, bar_cnt = ctx.saved_tensors + + # Bar Mask + dq, dk, dv = block_bar_attn_bwd( + dout, q, k, v, out, None, None, None, + softmax_lse, ctx.softmax_scale, + bar_idx, bar_cnt, block_idx, block_cnt, + granularity=ctx.granularity, + deterministic=ctx.deterministic, + step=0, + ) + + return dq, dk, dv, None, None, None, None, None, None + +def minference_flash_attn_triton_qkvpacked_func( + qkv: torch.Tensor, # [batch_size, num_tokens, 3, num_heads, head_dim] + v_size: List[int], # [num_heads] + s_size: List[int], # [num_heads] + dropout_p: int = 0.0, + softmax_scale: float = None, + granularity: int = 128, + causal: bool = True, + window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window + alibi_slopes: Tuple[float, float] = None, + deterministic: bool = False, + return_attn_probs: bool = False, +): + assert dropout_p == 0 + assert causal + assert window_size == (-1, -1) + assert alibi_slopes is None + return MInferenceAttnTritonFunc.apply( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + v_size, + s_size, + softmax_scale, + granularity, + return_attn_probs, + deterministic, + ) + + +def minference_flash_attn_triton_kvpacked_func( + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + kv: torch.Tensor, # [batch_size, num_tokens, 2, num_kv_heads, head_dim] + v_size: List[int], # [num_qo_heads] + s_size: List[int], # [num_qo_heads] + dropout_p: int = 0.0, + softmax_scale: float = None, + granularity: int = 128, + causal: bool = True, + window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window + alibi_slopes: Tuple[float, float] = None, + deterministic: bool = False, + return_attn_probs: bool = False, +): + assert dropout_p == 0 + assert causal + assert window_size == (-1, -1) + assert alibi_slopes is None + return MInferenceAttnTritonFunc.apply( + q, + kv[:, :, 0], + kv[:, :, 1], + v_size, + s_size, + softmax_scale, + granularity, + return_attn_probs, + deterministic, + ) + + +def minference_flash_attn_triton_func( + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v_size: List[int], # [num_qo_heads] + s_size: List[int], # [num_qo_heads] + dropout_p: int = 0.0, + softmax_scale: float = None, + granularity: int = 128, + causal: bool = True, + window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window + alibi_slopes: Tuple[float, float] = None, + deterministic: bool = False, + return_attn_probs: bool = False, +): + assert dropout_p == 0 + assert causal + assert window_size == (-1, -1) + assert 
alibi_slopes is None + return MInferenceAttnTritonFunc.apply( + q, + k, + v, + v_size, + s_size, + softmax_scale, + granularity, + return_attn_probs, + deterministic, + ) + + +def _torch_sparse_attn_func( + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v_size: List[int], # [num_qo_heads] + s_size: List[int], # [num_qo_heads] + dropout_p: int = 0.0, + softmax_scale: float = None, + granularity: int = 128, + causal: bool = True, + window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window + alibi_slopes: Tuple[float, float] = None, + deterministic: bool = False, + return_attn_probs: bool = False, +): + assert dropout_p == 0 + assert causal + assert window_size == (-1, -1) + assert alibi_slopes is None + assert not deterministic + batch_size, num_tokens, num_qo_heads, head_dim = q.shape + num_kv_heads = k.shape[2] + group_size = num_qo_heads // num_kv_heads + softmax_scale = head_dim ** -0.5 if softmax_scale is None else softmax_scale + mask = _build_mask_local(q, k, v_size, s_size, num_tokens, granularity) + k = k.repeat_interleave(group_size, dim=2) + v = v.repeat_interleave(group_size, dim=2) + p = torch.einsum('bmhd, bnhd -> bhmn', q * softmax_scale, k) + p = torch.where(mask, p, -torch.inf).to(torch.float32) + m = torch.max(p, dim=-1, keepdim=True).values.to(torch.float32) + p = torch.exp(p - m) + l = torch.sum(p, dim=-1, keepdim=True) + p = (p / l).to(q.dtype) + o = torch.einsum('bhmn, bnhd -> bmhd', p, v) + o = o.reshape((batch_size, num_tokens, num_qo_heads, head_dim)) + if return_attn_probs: + lse = m + l.log() + return o, lse.squeeze(-1), None + return o + + +def _torch_sparse_attn_qkvpacked_func( + qkv: torch.Tensor, # [batch_size, num_tokens, 3, num_heads, head_dim] + v_size: List[int], # [num_heads] + s_size: List[int], # [num_heads] + dropout_p: int = 0.0, + softmax_scale: float = None, + granularity: int = 128, + causal: bool = True, + window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window + alibi_slopes: Tuple[float, float] = None, + deterministic: bool = False, + return_attn_probs: bool = False, +): + return _torch_sparse_attn_func( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + v_size, + s_size, + dropout_p, + softmax_scale, + granularity, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + ) + + +def _torch_sparse_attn_kvpacked_func( + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + kv: torch.Tensor, # [batch_size, num_tokens, 2, num_kv_heads, head_dim] + v_size: List[int], # [num_qo_heads] + s_size: List[int], # [num_qo_heads] + dropout_p: int = 0.0, + softmax_scale: float = None, + granularity: int = 128, + causal: bool = True, + window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window + alibi_slopes: Tuple[float, float] = None, + deterministic: bool = False, + return_attn_probs: bool = False, +): + return _torch_sparse_attn_func( + q, + kv[:, :, 0], + kv[:, :, 1], + v_size, + s_size, + dropout_p, + softmax_scale, + granularity, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + ) + + +def profile(func, inputs, num_warmups=10, num_iters=10): + torch.cuda.synchronize() + for _ in range(num_warmups): + func(*inputs) + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + for _ in 
range(num_iters): + func(*inputs) + end.record() + torch.cuda.synchronize() + latency = start.elapsed_time(end) / num_iters + return latency + + +def print_compute_sparsity( + q: torch.Tensor, + k: torch.Tensor, + batch_size: int, + num_tokens: int, + num_qo_heads: int, + v_size: List[int], + s_size: List[int], + sparsity: float, + granularity: int = 128, +): + block_mask, bar_idx, bar_cnt = build_index_local(q, k, v_size, s_size, num_tokens, granularity) + num_blocks = block_mask.shape[-1] + causal_blocks = batch_size * num_qo_heads * num_blocks * (num_blocks + 1) / 2 + avg_blocks = block_mask.sum(dim=-1, dtype=torch.float32).mean().item() + block_ratio = block_mask.sum(dtype=torch.float32).item() / causal_blocks + avg_bars = bar_cnt[..., -1].mean(dtype=torch.float32).item() + bar_ratio = avg_bars / (causal_blocks / (num_qo_heads * num_blocks * num_blocks) * num_tokens) + compute_sparsity = 1 - block_ratio - bar_ratio + print(f"Max {max(v_size)} V Lines => {avg_bars:.2f} / {num_tokens} = {100 * bar_ratio:.1f}% Bar Ratio") + print(f"Max {max(s_size)} S Lines => {avg_blocks:.2f} / {num_blocks} = {100 * block_ratio:.1f}% Block Ratio") + print(f"Mask Sparsity = {100 * sparsity:.1f}%, Compute Sparsity = {100 * compute_sparsity:.1f}% ({(1 - compute_sparsity):.3f})") + + +def test_minference_attn( + batch_size: int, + num_tokens: int, + num_qo_heads: int, + num_kv_heads: int, + head_dim: int, + sparsity: float = 0.0, + granularity: int = 128, + check_results: bool = False, + profile_latency: bool = False, + dtype: torch.dtype = torch.bfloat16, + device: torch.device = 'cuda', + seed: int = 2025, +): + assert not (num_tokens > 8192 and check_results) + torch.manual_seed(seed) + q = torch.randn((batch_size, num_tokens, num_qo_heads, head_dim), requires_grad=True, dtype=dtype, device=device) + kv = torch.randn((batch_size, num_tokens, 2, num_kv_heads, head_dim), requires_grad=True, dtype=dtype, device=device) + grad = torch.randn((batch_size, num_tokens, num_qo_heads, head_dim), requires_grad=False, dtype=dtype, device=device) + v_size = [int((1 - sparsity) * 0.5 * num_tokens)] * num_qo_heads + s_size = [int((1 - sparsity) * 0.5 * num_tokens)] * num_qo_heads + print(f"[B, Hq, Hk, N, D] = [{batch_size}, {num_qo_heads}, {num_kv_heads}, {num_tokens}, {head_dim}]") + print_compute_sparsity(q, kv[:, :, 0], batch_size, num_tokens, num_qo_heads, v_size, s_size, sparsity, granularity) + + def call_attn(attn, inputs, grad=None, backward=False): + o, lse, _ = attn(**inputs) + if backward: + q.grad = None + kv.grad = None + o.backward(grad) + dq = q.grad.clone() + dkv = kv.grad.clone() + return o, lse, dq, dkv + return o, lse + + sparse_inputs = { + 'q': q, 'kv': kv, 'v_size': v_size, 's_size': s_size, + 'granularity': granularity, 'return_attn_probs': True, + } + dense_inputs = { + 'q': q, 'kv': kv, + 'return_attn_probs': True, 'causal': True, + } + + if check_results: + o, lse, dq, dkv = call_attn(minference_flash_attn_triton_kvpacked_func, sparse_inputs, grad=grad, backward=True) + o_ref, lse_ref, dq_ref, dkv_ref = call_attn(_torch_sparse_attn_kvpacked_func, sparse_inputs, grad=grad, backward=True) + # import ipdb; ipdb.set_trace() + htol, stol = { torch.float16: (1e-2, 1e-3), torch.bfloat16: (5e-2, 1e-2) }[dtype] + torch.testing.assert_close(o, o_ref, atol=htol, rtol=htol) + torch.testing.assert_close(lse, lse_ref, atol=stol, rtol=stol) + torch.testing.assert_close(dq, dq_ref, atol=htol, rtol=htol) + torch.testing.assert_close(dkv, dkv_ref, atol=htol, rtol=htol) + + if profile_latency: + from flash_attn 
import flash_attn_kvpacked_func + flash_latency = profile(call_attn, [flash_attn_kvpacked_func, dense_inputs, grad, True]) + flash_fwd_latency = profile(call_attn, [flash_attn_kvpacked_func, dense_inputs, None, False]) + flash_bwd_latency = flash_latency - flash_fwd_latency + minfer_latency = profile(call_attn, [minference_flash_attn_triton_kvpacked_func, sparse_inputs, grad, True]) + minfer_fwd_latency = profile(call_attn, [minference_flash_attn_triton_kvpacked_func, sparse_inputs, None, False]) + minfer_idx_latency = profile(build_index_local, [q, kv[:, :, 0], v_size, s_size, num_tokens, granularity]) + minfer_bwd_latency = minfer_latency - minfer_fwd_latency + minfer_fwd_latency = minfer_fwd_latency - minfer_idx_latency + import pandas as pd + df = pd.DataFrame( + data=[ + [minfer_idx_latency, minfer_fwd_latency, minfer_bwd_latency], + [0, flash_fwd_latency, flash_bwd_latency], + [None, minfer_fwd_latency / flash_fwd_latency, minfer_bwd_latency / flash_bwd_latency] + ], + index=['MInfer', 'Flash', 'Ratio'], + columns=['Index', 'Forward', 'Backward'], + ).round(2) + print("-" * 64) + print(df) + + +if __name__ == '__main__': + print("=" * 64) + test_minference_attn(1, 131072, 4, 2, 128, sparsity=0.998, check_results=False, profile_latency=True) + \ No newline at end of file diff --git a/minference/ops/utils.py b/minference/ops/utils.py new file mode 100644 index 0000000..8e64722 --- /dev/null +++ b/minference/ops/utils.py @@ -0,0 +1,932 @@ +import os +import numpy as np +from typing import List + +import torch +import torch.distributed as dist + +import triton +import triton.language as tl + + +def set_seed(seed=42): + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +@triton.jit +def _triton_extract_kv_kernel( + local_k, local_v, bar_k, bar_v, v_idx, v_cnt, + stride_lz, stride_ln, stride_lh, stride_ld, + stride_bz, stride_bn, stride_bh, stride_bd, + stride_iz, stride_ih, stride_in, + stride_cz, stride_ch, stride_cr, + step, num_tokens, num_qo_heads, num_kv_heads, + BLOCK_N: tl.constexpr, BLOCK_D: tl.constexpr, +): + start_n = tl.program_id(0) + qo_head_idx = tl.program_id(1) + batch_idx = tl.program_id(2) + kv_head_idx = qo_head_idx // (num_qo_heads // num_kv_heads) + + v_cnt_ptr = v_cnt + batch_idx * stride_cz + qo_head_idx * stride_ch + min_n = tl.load(v_cnt_ptr + step * stride_cr) + max_n = tl.load(v_cnt_ptr + (step + 1) * stride_cr) + start_n = start_n * BLOCK_N + end_n = start_n + BLOCK_N + if start_n >= max_n or end_n <= min_n: + return + + offs_d = tl.arange(0, BLOCK_D) + offs_n = start_n + tl.arange(0, BLOCK_N) + mask_n = (offs_n >= min_n) & (offs_n < max_n) + + v_idx_ptr = v_idx + batch_idx * stride_iz + qo_head_idx * stride_ih + local_k_ptr = local_k + batch_idx * stride_lz + kv_head_idx * stride_lh + offs_d[None, :] * stride_ld + local_v_ptr = local_v + batch_idx * stride_lz + kv_head_idx * stride_lh + offs_d[None, :] * stride_ld + bar_k_ptr = bar_k + batch_idx * stride_bz + qo_head_idx * stride_bh + offs_d[None, :] * stride_bd + bar_v_ptr = bar_v + batch_idx * stride_bz + qo_head_idx * stride_bh + offs_d[None, :] * stride_bd + + # idx = tl.load(v_idx_ptr + offs_n * stride_in, mask=mask_n, other=0) - step * num_tokens + idx = tl.load(v_idx_ptr + offs_n * stride_in, mask=mask_n, other=0) % num_tokens + k = tl.load(local_k_ptr + idx[:, None] * stride_ln, mask=mask_n[:, None], other=0.) + v = tl.load(local_v_ptr + idx[:, None] * stride_ln, mask=mask_n[:, None], other=0.) 
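+    # (added note) scatter the gathered K/V rows into the contiguous per-query-head bar buffers
+    # (GQA expansion: local_k/local_v are indexed by kv head, bar_k/bar_v by qo head); positions
+    # outside [min_n, max_n) are masked so padded slots in bar_k / bar_v are left untouched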
+ tl.store(bar_k_ptr + offs_n[:, None] * stride_bn, k, mask=mask_n[:, None]) + tl.store(bar_v_ptr + offs_n[:, None] * stride_bn, v, mask=mask_n[:, None]) + + +def extract_kv( + local_k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + local_v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + bar_k: torch.Tensor, # [batch_size, max_v_size, num_qo_heads, head_dim] + bar_v: torch.Tensor, # [batch_size, max_v_size, num_qo_heads, head_dim] + v_idx: torch.Tensor, # [batch_size, num_qo_heads, max_v_size] + v_cnt: torch.Tensor, # [batch_size, num_qo_heads, world_size + 1] + step: int, +): + batch_size, max_v_size, num_qo_heads, head_dim = bar_k.shape + _, num_tokens, num_kv_heads, _ = local_k.shape + block_N = 128 + block_D = head_dim + _triton_extract_kv_kernel[(triton.cdiv(max_v_size, block_N), num_qo_heads, batch_size)]( + local_k, local_v, bar_k, bar_v, v_idx, v_cnt, + local_k.stride(0), local_k.stride(1), local_k.stride(2), local_k.stride(3), + bar_k.stride(0), bar_k.stride(1), bar_k.stride(2), bar_k.stride(3), + v_idx.stride(0), v_idx.stride(1), v_idx.stride(2), + v_cnt.stride(0), v_cnt.stride(1), v_cnt.stride(2), + step, num_tokens, num_qo_heads, num_kv_heads, + BLOCK_N=block_N, BLOCK_D=block_D, + num_warps=4, num_stages=1, + ) + + +@triton.jit +def _triton_merge_kv_kernel( + local_k, local_v, bar_k, bar_v, v_idx, v_cnt, + stride_lz, stride_ln, stride_lh, stride_ld, + stride_bz, stride_bn, stride_bh, stride_bd, + stride_iz, stride_ih, stride_in, + stride_cz, stride_ch, stride_cr, + step, num_tokens, num_qo_heads, num_kv_heads, + BLOCK_N: tl.constexpr, BLOCK_D: tl.constexpr, +): + start_n = tl.program_id(0) + qo_head_idx = tl.program_id(1) + batch_idx = tl.program_id(2) + kv_head_idx = qo_head_idx // (num_qo_heads // num_kv_heads) + + v_cnt_ptr = v_cnt + batch_idx * stride_cz + qo_head_idx * stride_ch + min_n = tl.load(v_cnt_ptr + step * stride_cr) + max_n = tl.load(v_cnt_ptr + (step + 1) * stride_cr) + start_n = start_n * BLOCK_N + end_n = start_n + BLOCK_N + if start_n >= max_n or end_n <= min_n: + return + + offs_d = tl.arange(0, BLOCK_D) + offs_n = start_n + tl.arange(0, BLOCK_N) + mask_n = (offs_n >= min_n) & (offs_n < max_n) + + v_idx_ptr = v_idx + batch_idx * stride_iz + qo_head_idx * stride_ih + local_k_ptr = local_k + batch_idx * stride_lz + kv_head_idx * stride_lh + offs_d[None, :] * stride_ld + local_v_ptr = local_v + batch_idx * stride_lz + kv_head_idx * stride_lh + offs_d[None, :] * stride_ld + bar_k_ptr = bar_k + batch_idx * stride_bz + qo_head_idx * stride_bh + offs_d[None, :] * stride_bd + bar_v_ptr = bar_v + batch_idx * stride_bz + qo_head_idx * stride_bh + offs_d[None, :] * stride_bd + + # idx = tl.load(v_idx_ptr + offs_n * stride_in, mask=mask_n, other=0) - step * num_tokens + idx = tl.load(v_idx_ptr + offs_n * stride_in, mask=mask_n, other=0) % num_tokens + k = tl.load(bar_k_ptr + offs_n[:, None] * stride_bn, mask=mask_n[:, None], other=0.).to(local_k.type.element_ty) + v = tl.load(bar_v_ptr + offs_n[:, None] * stride_bn, mask=mask_n[:, None], other=0.).to(local_v.type.element_ty) + tl.atomic_add(local_k_ptr + idx[:, None] * stride_ln, k, mask=mask_n[:, None], sem="relaxed") + tl.atomic_add(local_v_ptr + idx[:, None] * stride_ln, v, mask=mask_n[:, None], sem="relaxed") + + +def merge_kv( + local_k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + local_v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + bar_k: torch.Tensor, # [batch_size, max_v_size, num_qo_heads, head_dim] + bar_v: torch.Tensor, # 
[batch_size, max_v_size, num_qo_heads, head_dim] + v_idx: torch.Tensor, # [batch_size, num_qo_heads, max_v_size] + v_cnt: torch.Tensor, # [batch_size, num_qo_heads, world_size + 1] + step: int, +): + batch_size, max_v_size, num_qo_heads, head_dim = bar_k.shape + _, num_tokens, num_kv_heads, _ = local_k.shape + block_N = 128 + block_D = head_dim + _triton_merge_kv_kernel[(triton.cdiv(max_v_size, block_N), num_qo_heads, batch_size)]( + local_k, local_v, bar_k, bar_v, v_idx, v_cnt, + local_k.stride(0), local_k.stride(1), local_k.stride(2), local_k.stride(3), + bar_k.stride(0), bar_k.stride(1), bar_k.stride(2), bar_k.stride(3), + v_idx.stride(0), v_idx.stride(1), v_idx.stride(2), + v_cnt.stride(0), v_cnt.stride(1), v_cnt.stride(2), + step, num_tokens, num_qo_heads, num_kv_heads, + BLOCK_N=block_N, BLOCK_D=block_D, + num_warps=4, num_stages=1, + ) + + +# triton.cdiv(world_size * num_blocks, BLOCK_N), num_heads, batch_size +# block_mask: [batch_size, num_heads, num_blocks_global] +@triton.jit +def _calc_block_mask_kernel( + s_idx, block_mask, + stride_sz, stride_sh, stride_sk, + stride_bz, stride_bh, stride_bn, + max_s_size, num_tokens, granularity, + BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, +): + batch_idx = tl.program_id(2) + head_idx = tl.program_id(1) + group_idx = tl.program_id(0) + + block_offs = tl.arange(0, BLOCK_N) + slash_offs = tl.arange(0, BLOCK_K) + + s_idx_ptr = s_idx + batch_idx * stride_sz + head_idx * stride_sh + block_mask_ptr = block_mask + batch_idx * stride_bz + head_idx * stride_bh + block_idx = group_idx * BLOCK_N + block_offs + + blocks = tl.zeros([BLOCK_N], dtype=tl.uint8) + for s_off in range(0, max_s_size, BLOCK_K): + s = tl.load(s_idx_ptr + (s_off + slash_offs) * stride_sk) + left = (num_tokens - granularity - s) // granularity + right = (num_tokens - 1 - s) // granularity + + # mask is generated by checking if a block's index falls between the calculated ranges + blocks |= tl.max((block_idx[None, :] >= left[:, None]) & (block_idx[None, :] <= right[:, None]), 0).to(tl.uint8) + + b_mask = (group_idx * BLOCK_N + block_offs) * granularity < num_tokens + tl.store(block_mask_ptr + (group_idx * BLOCK_N + block_offs) * stride_bn, blocks, mask=b_mask) + + +@triton.jit +def _striped_convert_indices_kernel( + last_row_mask, v_idx, v_cnt, + block_mask, bar_idx, bar_pos, bar_cnt, + stride_rz, stride_rh, stride_rn, + stride_vz, stride_vh, stride_vk, + stride_nz, stride_nh, stride_nt, + stride_bt, stride_bz, stride_bh, stride_bm, stride_bn, + stride_iz, stride_ih, stride_im, stride_ik, + stride_cz, stride_ch, stride_cm, stride_ct, + max_v_size, num_blocks, granularity, world_size, rank, + BLOCK_N: tl.constexpr, +): + batch_idx = tl.program_id(2) + head_idx = tl.program_id(1) + block_idx_q_local = tl.program_id(0) + + block_idx_q_global = block_idx_q_local * world_size + rank + + num_tokens_local = num_blocks * granularity + num_blocks_global = world_size * num_blocks + shift = num_blocks_global - 1 - block_idx_q_global + + block_offs = tl.arange(0, BLOCK_N) + + last_row_mask_ptr = last_row_mask + batch_idx * stride_rz + head_idx * stride_rh + v_idx_ptr = v_idx + batch_idx * stride_vz + head_idx * stride_vh + v_cnt_ptr = v_cnt + batch_idx * stride_nz + head_idx * stride_nh + block_mask_ptr = block_mask + batch_idx * stride_bz + head_idx * stride_bh + block_idx_q_local * stride_bm + bar_idx_ptr = bar_idx + batch_idx * stride_iz + head_idx * stride_ih + block_idx_q_local * stride_im + bar_pos_ptr = bar_pos + batch_idx * stride_iz + head_idx * stride_ih + block_idx_q_local * 
stride_im + bar_cnt_ptr = bar_cnt + batch_idx * stride_cz + head_idx * stride_ch + block_idx_q_local * stride_cm + + cnt_valid = 0 + cnt_all = 0 + v_off = 0 + v = tl.load(v_idx_ptr + cnt_all * stride_vk) + cnt_all += 1 + + tl.store(bar_cnt_ptr, cnt_valid) + bar_cnt_ptr += stride_ct + if block_idx_q_local == tl.num_programs(0) - 1: + tl.store(v_cnt_ptr, cnt_all - 1) + v_cnt_ptr += stride_nt + + for step in range(world_size): + for block_off_k in range(0, num_blocks, BLOCK_N): + block_idx_k_local = block_off_k + block_offs + block_idx_k_global = (block_off_k + block_offs) * world_size + step + mask_local = tl.load( + last_row_mask_ptr + (block_idx_k_global + shift) * stride_rn, + mask=(block_idx_k_global + shift < num_blocks_global), + other=0, + ) + tl.store( + block_mask_ptr + block_idx_k_local * stride_bn, + mask_local, + mask=(block_idx_k_local < num_blocks), + ) + block_left = v_off + block_idx_k_local * granularity + block_right = block_left + granularity + max_blocks = block_idx_q_local + 1 if step <= rank else block_idx_q_local + v_max = v_off + min(block_off_k + BLOCK_N, max_blocks) * granularity + while v < v_max and cnt_all < max_v_size: + if tl.max(((v >= block_left) & (v < block_right)) & (~mask_local), 0): + tl.store(bar_idx_ptr + cnt_valid * stride_ik, v - v_off) + tl.store(bar_pos_ptr + cnt_valid * stride_ik, cnt_all - 1) + cnt_valid += 1 + v = tl.load(v_idx_ptr + cnt_all * stride_vk) + cnt_all += 1 + block_mask_ptr += stride_bt + tl.store(bar_cnt_ptr, cnt_valid) + bar_cnt_ptr += stride_ct + v_off += num_tokens_local + if block_idx_q_local == tl.num_programs(0) - 1: + tl.store(v_cnt_ptr, cnt_all - 1) + v_cnt_ptr += stride_nt + + +@triton.jit +def _zigzag_convert_indices_kernel( + last_row_mask, v_idx, v_cnt, + block_mask, bar_idx, bar_pos, bar_cnt, + stride_rz, stride_rh, stride_rn, + stride_vz, stride_vh, stride_vk, + stride_nz, stride_nh, stride_nt, + stride_bt, stride_bz, stride_bh, stride_bm, stride_bn, + stride_iz, stride_ih, stride_im, stride_ik, + stride_cz, stride_ch, stride_cm, stride_ct, + max_v_size, num_blocks, granularity, world_size, rank, + BLOCK_N: tl.constexpr, +): + batch_idx = tl.program_id(2) + head_idx = tl.program_id(1) + block_idx_q_local = tl.program_id(0) + + if rank < world_size // 2: + revert_rank = rank * 2 + else: + revert_rank = (world_size - 1 - rank) * 2 + 1 + if block_idx_q_local < num_blocks // 2: + block_idx_q_global = revert_rank * (num_blocks // 2) + block_idx_q_local + else: + block_idx_q_global = (world_size * 2 - 1 - revert_rank) * (num_blocks // 2) + block_idx_q_local - (num_blocks // 2) + + num_blocks_global = world_size * num_blocks + shift = num_blocks_global - 1 - block_idx_q_global + + block_offs = tl.arange(0, BLOCK_N) + + last_row_mask_ptr = last_row_mask + batch_idx * stride_rz + head_idx * stride_rh + v_idx_ptr = v_idx + batch_idx * stride_vz + head_idx * stride_vh + v_cnt_ptr = v_cnt + batch_idx * stride_nz + head_idx * stride_nh + block_mask_ptr = block_mask + batch_idx * stride_bz + head_idx * stride_bh + block_idx_q_local * stride_bm + bar_idx_ptr = bar_idx + batch_idx * stride_iz + head_idx * stride_ih + block_idx_q_local * stride_im + bar_pos_ptr = bar_pos + batch_idx * stride_iz + head_idx * stride_ih + block_idx_q_local * stride_im + bar_cnt_ptr = bar_cnt + batch_idx * stride_cz + head_idx * stride_ch + block_idx_q_local * stride_cm + + cnt_valid = 0 + cnt_all = 0 + v = tl.load(v_idx_ptr + cnt_all * stride_vk) + cnt_all += 1 + + tl.store(bar_cnt_ptr, cnt_valid) + bar_cnt_ptr += stride_ct + if block_idx_q_local == 
tl.num_programs(0) - 1: + tl.store(v_cnt_ptr, cnt_all - 1) + v_cnt_ptr += stride_nt + + for step in range(world_size): + v_off = step * num_blocks * granularity + v_end = v_off + num_blocks * granularity + for block_off_k in range(0, num_blocks, BLOCK_N): + block_idx_k_local = block_off_k + block_offs + # assert BLOCK_N <= num_blocks // 2 + if block_off_k < num_blocks // 2: + v_off_global = step * (num_blocks // 2) * granularity + block_idx_k_global = step * (num_blocks // 2) + block_idx_k_local + else: + v_off_global = (world_size * 2 - 2 - step) * (num_blocks // 2) * granularity + block_idx_k_global = (world_size * 2 - 1 - step) * (num_blocks // 2) + block_idx_k_local - (num_blocks // 2) + mask_local = tl.load( + last_row_mask_ptr + (block_idx_k_global + shift) * stride_rn, + mask=(block_idx_k_global + shift < num_blocks_global), + other=0, + ) + tl.store( + block_mask_ptr + block_idx_k_local * stride_bn, + mask_local, + mask=(block_idx_k_local < num_blocks), + ) + # block_left = block_idx_k_global * granularity - v_off_global + v_off + # block_right = block_left + granularity + block_left = v_off + block_idx_k_local * granularity + block_right = block_left + granularity + v_max = (block_idx_q_global + 1) * granularity - v_off_global + v_off + while v < v_end and cnt_all <= max_v_size: + if v < v_max: + if tl.max(((v >= block_left) & (v < block_right)) & (~mask_local), 0): + tl.store(bar_idx_ptr + cnt_valid * stride_ik, v - v_off) + tl.store(bar_pos_ptr + cnt_valid * stride_ik, cnt_all - 1) + cnt_valid += 1 + v = tl.load(v_idx_ptr + cnt_all * stride_vk) + cnt_all += 1 + block_mask_ptr += stride_bt + tl.store(bar_cnt_ptr, cnt_valid) + bar_cnt_ptr += stride_ct + if block_idx_q_local == tl.num_programs(0) - 1: + tl.store(v_cnt_ptr, cnt_all - 1) + v_cnt_ptr += stride_nt + + +def convert_indices( + v_idx: torch.Tensor, # [batch_size, num_heads, max_v_size] + s_idx: torch.Tensor, # [batch_size, num_heads, max_s_size] + world_size: int, + rank: int, + num_blocks: int, + granularity: int, + num_tokens: int = None, + stripe_transform: bool = False, + zigzag_transform: bool = False, +): + num_blocks_global = world_size * num_blocks + if num_tokens is None: + # Note that for each invokation of `convert_indices`, `num_tokens` is None and becomes the **global number of tokens** + num_tokens = num_blocks_global * granularity + batch_size, num_heads, max_v_size = v_idx.shape + batch_size, num_heads, max_s_size = s_idx.shape + last_row_mask = torch.zeros((batch_size, num_heads, num_blocks_global), dtype=torch.bool, device=s_idx.device) + + BLOCK_N, BLOCK_K = 128, 128 + assert max_s_size <= BLOCK_K * BLOCK_K, f"max_s_size={max_s_size} > BLOCK_K * BLOCK_K={BLOCK_K * BLOCK_K}" + _calc_block_mask_kernel[(triton.cdiv(num_blocks_global, BLOCK_N), num_heads, batch_size)]( + s_idx, last_row_mask, + s_idx.stride(0), s_idx.stride(1), s_idx.stride(2), + last_row_mask.stride(0), last_row_mask.stride(1), last_row_mask.stride(2), + max_s_size, num_tokens, granularity, + BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, + num_warps=4, num_stages=2, + ) + + block_mask = torch.zeros((world_size, batch_size, num_heads, num_blocks, num_blocks), dtype=torch.bool, device=v_idx.device) + bar_idx = torch.zeros((batch_size, num_heads, num_blocks, max_v_size), dtype=torch.int32, device=v_idx.device) + bar_cnt = torch.empty((batch_size, num_heads, num_blocks, world_size + 1), dtype=torch.int32, device=v_idx.device) + v_cnt = torch.empty((batch_size, num_heads, world_size + 1), dtype=torch.int32, device=v_idx.device) + bar_pos = 
torch.zeros_like(bar_idx) + if zigzag_transform: + convert_indices_kernel = _zigzag_convert_indices_kernel + assert num_blocks % 2 == 0 + BLOCK_N = max(num_blocks // 2, 128) + else: + convert_indices_kernel = _striped_convert_indices_kernel + BLOCK_N = 128 + convert_indices_kernel[(num_blocks, num_heads, batch_size)]( + last_row_mask, v_idx, v_cnt, block_mask, bar_idx, bar_pos, bar_cnt, + last_row_mask.stride(0), last_row_mask.stride(1), last_row_mask.stride(2), + v_idx.stride(0), v_idx.stride(1), v_idx.stride(2), + v_cnt.stride(0), v_cnt.stride(1), v_cnt.stride(2), + block_mask.stride(0), block_mask.stride(1), block_mask.stride(2), block_mask.stride(3), block_mask.stride(4), + bar_idx.stride(0), bar_idx.stride(1), bar_idx.stride(2), bar_idx.stride(3), + bar_cnt.stride(0), bar_cnt.stride(1), bar_cnt.stride(2), bar_cnt.stride(3), + max_v_size, num_blocks, granularity, world_size, rank, BLOCK_N=BLOCK_N, + num_warps=1, num_stages=1, + ) + # if zigzag_transform: + # if rank == 0: + # import ipdb; ipdb.set_trace() + # torch.save(block_mask, f'./output/data/block_mask_{rank}.pt') + # torch.save(bar_idx, f'./output/data/bar_idx_{rank}.pt') + # torch.save(bar_cnt, f'./output/data/bar_cnt_{rank}.pt') + # elif rank == 0: + # torch.save(block_mask, f'./output/data/block_mask.pt') + # torch.save(bar_idx, f'./output/data/bar_idx.pt') + # torch.save(bar_cnt, f'./output/data/bar_cnt.pt') + # bar_cnt = torch.zeros_like(bar_cnt) + return block_mask, bar_idx, bar_cnt, bar_pos, v_cnt + + +def _torch_convert_indices( + v_idx: torch.Tensor, # [batch_size, num_heads, max_v_size] + s_idx: torch.Tensor, # [batch_size, num_heads, max_s_size] + world_size: int, + rank: int, + num_blocks: int, + granularity: int, +): + batch_size, num_heads, max_v_size = v_idx.shape + num_tokens = world_size * num_blocks * granularity + block_mask = torch.zeros((world_size, batch_size, num_heads, num_blocks, num_blocks), dtype=torch.bool, device=v_idx.device) + bar_idx = torch.zeros((batch_size, num_heads, num_blocks, max_v_size), dtype=torch.int32, device=v_idx.device) + bar_cnt = torch.zeros((batch_size, num_heads, num_blocks, world_size + 1), dtype=torch.int32, device=v_idx.device) + for batch_idx in range(batch_size): + for head_idx in range(num_heads): + for block_idx_q in range(num_blocks): + block_idx_q_global = block_idx_q * world_size + rank + cnt_all, cnt_valid = 0, 0 + for step in range(world_size): + for block_idx_k in range(block_idx_q + 1): + block_idx_k_global = block_idx_k * world_size + step + s_min = max((block_idx_q_global - block_idx_k_global - 1) * granularity, 0) + s_max = (block_idx_q_global - block_idx_k_global + 1) * granularity + flag = torch.any((s_idx[batch_idx, head_idx] > s_min) & (s_idx[batch_idx, head_idx] < s_max)) + block_mask[step, batch_idx, head_idx, block_idx_q, block_idx_k] = flag + v_min = (step * num_blocks + block_idx_k) * granularity + max_blocks = block_idx_q + 1 if step <= rank else block_idx_q + v_max = (step * num_blocks + min(block_idx_k + 1, max_blocks)) * granularity + while cnt_all < max_v_size and v_idx[batch_idx, head_idx, cnt_all] < v_min: + cnt_all += 1 + while cnt_all < max_v_size and v_idx[batch_idx, head_idx, cnt_all] < v_max: + if not flag: + bar_idx[batch_idx, head_idx, block_idx_q, cnt_valid] = \ + v_idx[batch_idx, head_idx, cnt_all] - step * num_blocks * granularity + cnt_valid += 1 + cnt_all += 1 + bar_cnt[batch_idx, head_idx, block_idx_q, step + 1] = cnt_valid + return block_mask, bar_idx, bar_cnt + + + +def sum_all_diagonal_matrix(mat: torch.Tensor): + b, h, m, n = 
mat.shape + + # Pads the matrix on left and right (on the last dimension) + mat_padded = torch.nn.functional.pad(mat, (m, m), "constant", 0.) # shape: [b, h, m, 2 * m + n] + # Change the strides + mat_strided = mat_padded.as_strided((b, h, m, m + n), (m * (2 * m + n) * h, m * (2 * m + n), 2 * m + n + 1, 1)) + # Sums the resulting matrix's columns + sum_diags = torch.sum(mat_strided, 2) # shape: [b, h, m + n] + return sum_diags[:, :, 1:].contiguous() + +def calc_index( + q: torch.Tensor, + k: torch.Tensor, + v_size: List[int], + s_size: List[int], + last_q_size: int = 64, + sink_tokens: int = 30, + sliding_window: int = 100, + group: dist.group = None, + stripe_transform: bool = False, + zigzag_transform: bool = False, + granularity: int = 128, +): + # TODO: adapt naturely striped inputs + # TODO: flex-prefill (top-P) + # TODO: reduce bubble + # TODO: support total_num_tokens % world_size != 0 + batch_size, num_tokens, num_qo_heads, head_dim = q.shape + num_kv_heads = k.shape[2] + rank = dist.get_rank(group) + world_size = dist.get_world_size(group) + + if all([type(x) is list for x in v_size]) and all([type(x) is list for x in s_size]): + flex_prefill = True + v_p = [x[0] for x in v_size] + v_size = [x[1] for x in v_size] + s_p = [x[0] for x in s_size] + s_size = [x[1] for x in s_size] + else: + flex_prefill = False + assert all([type(x) is int for x in v_size]) and all([type(x) is int for x in s_size]) + + max_v_size = min(triton.cdiv(max(v_size), 128), num_tokens // 128) * 128 + max_s_size = min(triton.cdiv(max(s_size), 128), num_tokens // 128) * 128 + + last_rank = world_size - 1 + if rank == last_rank: + last_q = q[:, -last_q_size:, :, :].detach().clone().reshape((batch_size, last_q_size, num_kv_heads, -1, head_dim)) + else: + last_q = torch.zeros((batch_size, last_q_size, num_kv_heads, num_qo_heads // num_kv_heads, head_dim), device=q.device, dtype=q.dtype) + + if os.getenv("COMM_DEBUG", False): + # For debugging purposes, print the rank and tensor shapes + rank = dist.get_rank(group) + print(f"Rank {rank} | calc_index | before invoking broadcast last_q from rank={last_rank}", flush=True) + dist.broadcast(last_q, src=last_rank, group=group, async_op=False) + + qk = torch.einsum('bmghd, bngd -> bghmn', last_q, k) * (k.shape[-1] ** -0.5) + qk = qk.reshape((batch_size, num_qo_heads, last_q_size, num_tokens)) + + if rank == last_rank: + # Causal Mask, requires num_tokens // world_size >= last_q + arange = torch.arange(last_q_size, device=k.device) + mask = arange[None, None, :, None] >= arange[None, None, None, :] + qk[:, :, :, -last_q_size:] = torch.where(mask, qk[:, :, :, -last_q_size:], -torch.inf) + if flex_prefill: # qk = torch.softmax(qk, dim=-1) / last_q_size + qk_max = torch.max(qk, dim=-1, keepdim=True).values + qk_max_list = [torch.empty_like(qk_max) for _ in range(world_size)] + dist.all_gather(qk_max_list, qk_max, group=group, async_op=False) + qk_max = torch.max(torch.stack(qk_max_list), dim=0).values + qk = torch.exp(qk - qk_max) + qk_sum = torch.sum(qk, dim=-1, keepdim=True) + qk_sum_list = [torch.empty_like(qk_sum) for _ in range(world_size)] + dist.all_gather(qk_sum_list, qk_sum, group=group, async_op=False) + qk_sum = torch.sum(torch.stack(qk_sum_list), dim=0) + qk /= (qk_sum * last_q_size) + + v_gather_rank = 0 + vertical = qk.sum(-2, keepdim=False) # [B, H, N_LOCAL] + if rank == 0 and not flex_prefill: + vertical[..., :sink_tokens] = torch.inf + if rank == v_gather_rank: + gathered_vertical = [torch.empty_like(vertical) for _ in range(world_size)] + else: + 
gathered_vertical = None + if os.getenv("COMM_DEBUG", False): + # For debugging purposes, print the rank and tensor shapes + rank = dist.get_rank(group) + print(f"Rank {rank} | calc_index | before invoking gather vertical to {v_gather_rank}", flush=True) + dist.gather(vertical, gathered_vertical, dst=v_gather_rank, group=group, async_op=False) + + if rank == v_gather_rank: + vertical: torch.Tensor = torch.cat(gathered_vertical, dim=-1) + if stripe_transform: + vertical = vertical.reshape((batch_size, num_qo_heads, -1, world_size, granularity)) + vertical = vertical.swapaxes(2, 3) + vertical = vertical.reshape((batch_size, num_qo_heads, -1)) + elif zigzag_transform: + vertical = vertical.reshape((batch_size, num_qo_heads, 2, world_size, -1)) + chunks = [] + for step in range(world_size): + chunks.append(vertical[:, :, 0, step]) + chunks.append(vertical[:, :, 1, world_size - 1 - step]) + vertical = torch.concat(chunks, dim=2).reshape((batch_size, num_qo_heads, -1)) + + v_topk = torch.topk(vertical, max_v_size, -1, sorted=True) + v_indices = v_topk.indices.to(torch.int32) + if flex_prefill: + v_cumsum = v_topk.values.cumsum_(dim=-1) + v_size = (v_cumsum < torch.tensor(v_p, device=k.device)[None, :, None]).sum(dim=-1, keepdim=True) + else: + v_size = torch.tensor(v_size, device=k.device)[None, :, None] + v_arange = torch.arange(max_v_size, device=k.device) + v_indices.masked_fill_(v_arange[None, None, :] >= v_size, num_tokens * world_size) + v_indices = v_indices.sort(dim=-1, descending=False).values + else: + v_indices = torch.empty((batch_size, num_qo_heads, max_v_size), dtype=torch.int32, device=k.device) + if os.getenv("COMM_DEBUG", False): + # For debugging purposes, print the rank and tensor shapes + rank = dist.get_rank(group) + print(f"Rank {rank} | calc_index | before invoking broadcast v_indices from rank={v_gather_rank}", flush=True) + dist.broadcast(v_indices, src=v_gather_rank, group=group, async_op=False) # async + + s_gather_rank = 0 + slash = sum_all_diagonal_matrix(qk) # shape: [B, H, N_LOCAL + LAST_Q_SIZE - 1] + if rank == world_size - 1 and not flex_prefill: + # -> index starting from the left bottom corner to right upper corner + # (sliding_window) from -(last_q_size-1) is the sliding window close to the main diagonal + slash[..., -(last_q_size - 1 + sliding_window):] = torch.inf + + + if rank == s_gather_rank: + gathered_slash = [torch.empty_like(slash) for _ in range(world_size)] + else: + gathered_slash = None + + if os.getenv("COMM_DEBUG", False): + # For debugging purposes, print the rank and tensor shapes + rank = dist.get_rank(group) + print(f"Rank {rank} | calc_index | before invoking gather slash to rank=0", flush=True) + dist.gather(slash, gathered_slash, dst=s_gather_rank, group=group, async_op=False) + + if rank == s_gather_rank: + slash = gathered_slash[0] + for next_slash in gathered_slash[1:]: + slash[..., -last_q_size + 1:] += next_slash[..., :last_q_size - 1] + slash = torch.cat((slash, next_slash[..., last_q_size - 1:]), dim=-1) + + # slash presents the sum of attention from 0-th to (num_tokens_global - last_q_size - 1), where 0 represents the diagonal at bottom left corner + slash = slash[..., :-last_q_size + 1] + s_topk = torch.topk(slash, max_s_size, -1, sorted=True) + + # s_indices contain indices starting from the right upper corner to left bottom corner + s_indices = (num_tokens * world_size - 1) - s_topk.indices.to(torch.int32) + if flex_prefill: + s_cumsum = s_topk.values.cumsum_(dim=-1) + s_size = (s_cumsum < torch.tensor(s_p, 
device=k.device)[None, :, None]).sum(dim=-1, keepdim=True) + else: + s_size = torch.tensor(s_size, device=k.device)[None, :, None] + s_arange = torch.arange(max_s_size, device=k.device) + s_indices.masked_fill_(s_arange[None, None, :] >= s_size, -1) + s_indices = s_indices.sort(dim=-1, descending=True).values + else: + s_indices = torch.empty((batch_size, num_qo_heads, max_s_size), dtype=torch.int32, device=k.device) + if os.getenv("COMM_DEBUG", False): + # For debugging purposes, print the rank and tensor shapes + rank = dist.get_rank(group) + print(f"Rank {rank} | calc_index | before invoking broadcast s_indices from rank={s_gather_rank}", flush=True) + dist.broadcast(s_indices, src=s_gather_rank, group=group, async_op=False) + + return v_indices.to(torch.int32), s_indices.to(torch.int32) + +def calc_index_local( + q: torch.Tensor, + k: torch.Tensor, + v_size: List[int], + s_size: List[int], + last_q_size: int = 64, + sink_tokens: int = 30, + sliding_window: int = 100, + group: dist.group = None, + stripe_transform: bool = False, + zigzag_transform: bool = False, + granularity: int = 128, +): + batch_size, num_tokens, num_qo_heads, head_dim = q.shape + num_kv_heads = k.shape[2] + + if all([type(x) is list for x in v_size]) and all([type(x) is list for x in s_size]): + flex_prefill = True + v_p = [x[0] for x in v_size] + v_size = [x[1] for x in v_size] + s_p = [x[0] for x in s_size] + s_size = [x[1] for x in s_size] + else: + flex_prefill = False + assert all([type(x) is int for x in v_size]) and all([type(x) is int for x in s_size]) + + qk = torch.einsum( + f'bmghd, bngd -> bghmn', + q[:, -last_q_size:, :, :].reshape((batch_size, last_q_size, num_kv_heads, -1, head_dim)), + k, + ).reshape((batch_size, num_qo_heads, last_q_size, num_tokens)) * (head_dim ** -0.5) + + arange = torch.arange(last_q_size, device=k.device) + mask = arange[None, None, :, None] >= arange[None, None, None, :] + qk[:, :, :, -last_q_size:] = torch.where(mask, qk[:, :, :, -last_q_size:], -torch.inf) + if flex_prefill: + qk = torch.softmax(qk, dim=-1) / last_q_size + + max_v_size = min(max(v_size), num_tokens) + max_v_size = triton.cdiv(max_v_size, 128) * 128 + vertical = qk.sum(-2, keepdim=False) + if not flex_prefill: + vertical[..., :sink_tokens] = torch.inf + if stripe_transform: + vertical = vertical.reshape((batch_size, num_qo_heads, -1, dist.get_world_size(group), granularity)) + vertical = vertical.swapaxes(2, 3) + vertical = vertical.reshape((batch_size, num_qo_heads, -1)) + elif zigzag_transform: + vertical = vertical.reshape((batch_size, num_qo_heads, 2, dist.get_world_size(group), -1)) + chunks = [] + for step in range(dist.get_world_size(group)): + chunks.append(vertical[:, :, 0, step]) + chunks.append(vertical[:, :, 1, dist.get_world_size(group) - 1 - step]) + vertical = torch.concat(chunks, dim=2).reshape((batch_size, num_qo_heads, -1)) + v_topk = torch.topk(vertical, max_v_size, -1, sorted=True) + v_indices = v_topk.indices + if flex_prefill: + v_cumsum = v_topk.values.cumsum_(dim=-1) + v_size = (v_cumsum < torch.tensor(v_p, device=k.device)[None, :, None]).sum(dim=-1, keepdim=True) + else: + v_size = torch.tensor(v_size, device=k.device)[None, :, None] + + max_s_size = min(max(s_size), num_tokens) + max_s_size = triton.cdiv(max_s_size, 128) * 128 + slash = sum_all_diagonal_matrix(qk)[..., :-last_q_size + 1] + if not flex_prefill: + slash[..., -sliding_window:] = torch.inf + s_topk = torch.topk(slash, max_s_size, -1, sorted=True) + s_indices = (num_tokens - 1) - s_topk.indices + if flex_prefill: + 
s_cumsum = s_topk.values.cumsum_(dim=-1) + s_size = (s_cumsum < torch.tensor(s_p, device=k.device)[None, :, None]).sum(dim=-1, keepdim=True) + else: + s_size = torch.tensor(s_size, device=k.device)[None, :, None] + + v_arange = torch.arange(max_v_size, device=k.device) + v_idx = v_indices.to(torch.int32).reshape((batch_size, num_qo_heads, -1)) + v_idx.masked_fill_(v_arange[None, None, :] >= v_size, 2147483647) + v_idx = v_idx.sort(dim=-1, descending=False).values + + s_arange = torch.arange(max_s_size, device=k.device) + s_idx = s_indices.to(torch.int32).reshape((batch_size, num_qo_heads, -1)) + s_idx.masked_fill_(s_arange[None, None, :] >= s_size, -1) + s_idx = s_idx.sort(dim=-1, descending=True).values + + return v_idx, s_idx + + +def profile(func, inputs, num_warmups=100, num_iters=100): + torch.cuda.synchronize() + for _ in range(num_warmups): + func(*inputs) + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + for _ in range(num_iters): + func(*inputs) + end.record() + torch.cuda.synchronize() + latency = start.elapsed_time(end) / num_iters + return latency + +def build_index_local( + q: torch.Tensor, + k: torch.Tensor, + v_size: List[int], + s_size: List[int], + num_tokens: int, + granularity: int, + world_size: int = 1, + rank: int = 0, +): + if type(v_size) is list: + assert len(v_size) == q.shape[2] + assert len(s_size) == q.shape[2] + v_idx, s_idx = calc_index_local(q, k, v_size, s_size, last_q_size=64) + else: + v_idx, s_idx = v_size, s_size + + num_blocks = triton.cdiv(num_tokens, granularity) + block_mask, bar_idx, bar_cnt, _, _ = convert_indices(v_idx, s_idx, world_size, rank, num_blocks, granularity) + block_mask = block_mask[rank] + return block_mask, bar_idx, bar_cnt + +def build_index( + q: torch.Tensor, + k: torch.Tensor, + v_size: List[int], + s_size: List[int], + num_tokens: int, # num_tokens_local + granularity: int, + stripe_transform: bool = True, + zigzag_transform: bool = False, + group: dist.group = None, +): + """ + Input: (all inputs correspond to the local part for each rank) + q: shape [batch_size, num_tokens_local, num_qo_heads, head_dim] + k: shape [batch_size, num_tokens_local, num_kv_heads, head_dim] + v_size: shape [num_qo_heads] + s_size: shape [num_qo_heads] + num_tokens: number of tokens in the local part of QK + Returns: + block_mask: shape [world_size, batch_size, num_heads, num_blocks, num_blocks] + bar_idx: shape [batch_size, num_heads, num_blocks, max_v_size] + bar_cnt: shape [batch_size, num_heads, num_blocks, world_size + 1], each entry is the cumulative number of selected bars corresponding a rank + """ + rank = dist.get_rank(group) + world_size = dist.get_world_size(group) + + if isinstance(v_size, list): + v_idx, s_idx = calc_index( + q, k, v_size, s_size, last_q_size=64, group=group, + stripe_transform=stripe_transform, + zigzag_transform=zigzag_transform, + granularity=granularity + ) + else: + v_idx, s_idx = v_size, s_size + + num_blocks = triton.cdiv(num_tokens, granularity) # num_blocks_local + + # Note that block_mask is a 5D tensor with shape [world_size, batch_size, num_heads, num_blocks, num_blocks] + # with each block_mask[i] is to a mask corresponding the num_tokens_local x num_tokens_local matmul for each step + block_mask, bar_idx, bar_cnt, bar_pos, v_cnt = convert_indices( + v_idx, s_idx, world_size, rank, num_blocks, granularity, + stripe_transform=stripe_transform, + zigzag_transform=zigzag_transform, + ) + return block_mask, bar_idx, 
bar_cnt, bar_pos, v_idx, v_cnt + + +def _build_mask_local( + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v_size: List[int], + s_size: List[int], + num_tokens: int, + granularity: int, + world_size: int = 1, + rank: int = 0, +): + with torch.no_grad(): + block_mask, bar_idx, bar_cnt = build_index_local(q, k, v_size, s_size, num_tokens, granularity, world_size, rank) + batch_size, num_tokens, num_heads, head_dim = q.shape + num_blocks = block_mask.shape[-1] + num_tokens_pad = num_blocks * granularity + # Block Mask + mask = block_mask.unsqueeze(3).unsqueeze(5).repeat((1, 1, 1, granularity, 1, granularity)) + mask = mask.reshape((batch_size, num_heads, num_tokens_pad, num_tokens_pad)) + # Bar Mask + for batch_idx in range(batch_size): + for head_idx in range(num_heads): + for row_idx in range(num_blocks): + row_u = row_idx * granularity + row_d = row_u + granularity + bar_l = bar_cnt[batch_idx, head_idx, row_idx, rank] + bar_r = bar_cnt[batch_idx, head_idx, row_idx, rank + 1] + for col_idx in bar_idx[batch_idx, head_idx, row_idx, bar_l:bar_r]: + mask[batch_idx, head_idx, row_u:row_d, col_idx] = True + # Causal Mask + arange = torch.arange(0, num_tokens_pad, dtype=torch.int32, device=q.device) + mask.masked_fill_(arange[None, None, :, None] < arange[None, None, None, :], False) + return mask[:, :, :num_tokens, :num_tokens] + + +def convert_blockmask( + blockmask: torch.Tensor, # [world_size, batch_size, num_heads, num_blocks, num_blocks] + block_size_M: int, + block_size_N: int, +): + ratio = block_size_M // block_size_N + original_shape = blockmask.shape + blockmask = blockmask.to(dtype=torch.uint8) + blockmask = blockmask.unsqueeze(-1).tile([1] * len(original_shape) + [ratio]).reshape((*original_shape[:-1], -1)) + + # now block_mask is [world_size, batch_size, num_heads, num_blocks, num_blocks * ratio] + nonzero_val, nonzero_idx = blockmask.sort(dim=-1, stable=True, descending=True) + + nonzero_rowcnt = blockmask.sum(dim=-1, dtype=torch.int32) + return nonzero_idx.contiguous().to(dtype=torch.int32), nonzero_rowcnt.contiguous() + +def compute_sr_flops( + block_mask_offset: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, num_blocks] + bar_cnt_offset: torch.Tensor, # [batch_size, num_qo_heads, num_blocks] + granularity: int, + q_len: int, + head_dim: int, + shift: bool, + fwd: bool=True, +): + num_blocks = triton.cdiv(q_len, granularity) + bh = block_mask_offset.shape[0] * block_mask_offset.shape[1] + + if not shift: + total_num_blocks = bh * num_blocks * (num_blocks + 1) / 2 + else: + total_num_blocks = bh * num_blocks * (num_blocks - 1) / 2 + + block_ratio = block_mask_offset.sum(dtype=torch.float32).item() / total_num_blocks + bar_ratio = bar_cnt_offset.sum(dtype=torch.float32).item() / (granularity * total_num_blocks) + sparsity_ratio = 1 - block_ratio - bar_ratio + + block_flops = block_mask_offset.sum(dtype=torch.float32).item() * (granularity * granularity) * head_dim * 2 * 2 + bar_flops = bar_cnt_offset.sum(dtype=torch.float32).item() * granularity * head_dim * 2 * 2 + flops = block_flops + bar_flops + + if not fwd: + flops, block_flops, bar_flops = 2.5 * flops, 2.5 * block_flops, 2.5 * bar_flops + return block_ratio, bar_ratio, sparsity_ratio, block_flops, bar_flops, flops + +def compute_sr_by_heads( + block_mask_offset: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, num_blocks] + bar_cnt_offset: torch.Tensor, # [batch_size, num_qo_heads, num_blocks] + granularity: int, + 
q_len: int, + head_dim: int, + shift: bool, + fwd: bool=True, +): + num_heads = block_mask_offset.shape[1] + num_blocks = triton.cdiv(q_len, granularity) + if not shift: + total_num_blocks = num_blocks * (num_blocks + 1) / 2 + else: + total_num_blocks = num_blocks * (num_blocks - 1) / 2 + total_num_blocks_by_heads = torch.tensor([total_num_blocks for _ in range(num_heads)], dtype=torch.float32).to(block_mask_offset.device) + + block_ratio_by_heads = block_mask_offset.sum(dim=-1).sum(dim=-1).sum(0, dtype=torch.float32) / total_num_blocks_by_heads + bar_ratio_by_heads = bar_cnt_offset.sum(dim=-1).sum(0, dtype=torch.float32) / total_num_blocks_by_heads / granularity + sparsity_ratio_by_heads = 1 - block_ratio_by_heads - bar_ratio_by_heads + + return sparsity_ratio_by_heads + +def get_compute_sparsity( + q: torch.Tensor, + k: torch.Tensor, + batch_size: int, + num_tokens: int, + num_qo_heads: int, + v_size: List[int], + s_size: List[int], + granularity: int = 128, +): + block_mask, bar_idx, bar_cnt = build_index_local(q, k, v_size, s_size, num_tokens, granularity) + num_blocks = block_mask.shape[-1] + causal_blocks = batch_size * num_qo_heads * num_blocks * (num_blocks + 1) / 2 + block_ratio = block_mask.sum(dtype=torch.float32).item() / causal_blocks + + avg_bars = bar_cnt[..., -1].mean(dtype=torch.float32).item() + bar_ratio = avg_bars / (causal_blocks / (num_qo_heads * num_blocks * num_blocks) * num_tokens) + + compute_sparsity = 1 - block_ratio - bar_ratio + + return compute_sparsity diff --git a/mtraining/.gitignore b/mtraining/.gitignore new file mode 100644 index 0000000..6fc96e7 --- /dev/null +++ b/mtraining/.gitignore @@ -0,0 +1,12 @@ +**/__pycache__/ +MTraining.egg-info/ +**.ipynb +**/prof_logs/ +.vscode/ +.pytest_cache/ +*.log +**/draft/ +output/ +**/ring_attn_comp_data/ +**/ring_attn_comp_data/ +**/ring_attn_pt_logs/ \ No newline at end of file diff --git a/mtraining/README.md b/mtraining/README.md new file mode 100644 index 0000000..6f08299 --- /dev/null +++ b/mtraining/README.md @@ -0,0 +1 @@ +# MTraining diff --git a/mtraining/__init__.py b/mtraining/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mtraining/models/__init__.py b/mtraining/models/__init__.py new file mode 100644 index 0000000..e65a160 --- /dev/null +++ b/mtraining/models/__init__.py @@ -0,0 +1,29 @@ +from .phi3 import ( + Phi3Config, Phi3ForCausalLM, Phi3Attention, + apply_rotary_pos_emb, repeat_kv, + PHI_ATTN_FUNCS +) + +from .qwen2 import ( + Qwen2Config, Qwen2ForCausalLM, Qwen2Attention, + apply_rotary_pos_emb, repeat_kv, + QWEN_ATTN_FUNCS +) + + +MODEL_TO_ATTN_FUNC = { + "microsoft/Phi-3-mini-4k-instruct": PHI_ATTN_FUNCS, + "Qwen/Qwen2.5-3B": QWEN_ATTN_FUNCS +} + + +MODEL_ID_TO_MODEL_CLS = { + "microsoft/Phi-3-mini-4k-instruct": Phi3ForCausalLM, + "Qwen/Qwen2.5-3B": Qwen2ForCausalLM +} + +MODEL_ID_TO_PREFIX = { + "microsoft/Phi-3-mini-4k-instruct": "Phi3", + "Qwen/Qwen2.5-3B": "Qwen2", +} + diff --git a/mtraining/models/active_param_configs/attn_only.txt b/mtraining/models/active_param_configs/attn_only.txt new file mode 100644 index 0000000..52aef4b --- /dev/null +++ b/mtraining/models/active_param_configs/attn_only.txt @@ -0,0 +1 @@ +self_attn \ No newline at end of file diff --git a/mtraining/models/active_param_configs/qk_proj_only.txt b/mtraining/models/active_param_configs/qk_proj_only.txt new file mode 100644 index 0000000..1a23205 --- /dev/null +++ b/mtraining/models/active_param_configs/qk_proj_only.txt @@ -0,0 +1,2 @@ +self_attn.q_proj +self_attn.k_proj diff --git 
a/mtraining/models/phi3/__init__.py b/mtraining/models/phi3/__init__.py new file mode 100644 index 0000000..85b6f8f --- /dev/null +++ b/mtraining/models/phi3/__init__.py @@ -0,0 +1,5 @@ +from .modelling_phi import ( + Phi3Config, Phi3ForCausalLM, Phi3Attention, + apply_rotary_pos_emb, repeat_kv, + PHI_ATTN_FUNCS +) \ No newline at end of file diff --git a/mtraining/models/phi3/configuration_phi3.py b/mtraining/models/phi3/configuration_phi3.py new file mode 100644 index 0000000..7804010 --- /dev/null +++ b/mtraining/models/phi3/configuration_phi3.py @@ -0,0 +1,227 @@ +# coding=utf-8 +# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" Phi-3 model configuration""" + + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + +PHI3_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "microsoft/Phi-3-mini-4k-instruct": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/resolve/main/config.json", + "microsoft/Phi-3-mini-128k-instruct": "https://huggingface.co/microsoft/Phi-3-mini-128k-instruct/resolve/main/config.json", +} + + +class Phi3Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Phi3Model`]. It is used to instantiate a Phi-3 + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the + [microsoft/Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 32064): + Vocabulary size of the Phi-3 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Phi3Model`]. + hidden_size (`int`, *optional*, defaults to 3072): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 8192): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. 
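The mean-pooling conversion mentioned in the `num_key_value_heads` docstring just above can be illustrated with a small sketch. The shapes and the `[num_heads * head_dim, hidden]` weight layout below are assumptions chosen for illustration, not Phi-3's actual checkpoint format.

```python
import torch

# Collapsing an MHA key projection (one KV head per query head) into a GQA
# projection with fewer KV heads by mean-pooling each group of heads.
num_heads, num_kv_heads, head_dim, hidden = 8, 2, 16, 128
group = num_heads // num_kv_heads   # 4 query heads share one KV head

# k_proj weight of an MHA checkpoint: [num_heads * head_dim, hidden]
k_mha = torch.randn(num_heads * head_dim, hidden)

# Reshape to [num_kv_heads, group, head_dim, hidden] and average over the group.
k_gqa = k_mha.view(num_kv_heads, group, head_dim, hidden).mean(dim=1)
k_gqa = k_gqa.reshape(num_kv_heads * head_dim, hidden)
print(k_gqa.shape)   # torch.Size([32, 128])
```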
For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + resid_pdrop (`float`, *optional*, defaults to 0.0): + Dropout probability for mlp outputs. + embd_pdrop (`int`, *optional*, defaults to 0.0): + The dropout ratio for the embeddings. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio after computing the attention scores. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 4096): + The maximum sequence length that this model might ever be used with. + original_max_position_embeddings (`int`, *optional*, defaults to 4096): + The maximum sequence length that this model was trained with. This is used to determine the size of the + original RoPE embeddings when using long scaling. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon value used for the RMSNorm. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. Whether to tie weight embeddings or not. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`dict`, *optional*): + The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must + contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be `longrope` and + the `short_factor` and `long_factor` must be lists of numbers with the same length as the hidden size + divided by the number of attention heads divided by 2. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 32000): + The id of the "end-of-sequence" token. + pad_token_id (`int`, *optional*, defaults to 32000): + The id of the padding token. + sliding_window (`int`, *optional*): + Sliding window attention window size. If `None`, no sliding window is applied. 
+ + Example: + + ```python + >>> from transformers import Phi3Model, Phi3Config + + >>> # Initializing a Phi-3 style configuration + >>> configuration = Phi3Config.from_pretrained("microsoft/Phi-3-mini-4k-instruct") + + >>> # Initializing a model from the configuration + >>> model = Phi3Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "phi3" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32064, + hidden_size=3072, + intermediate_size=8192, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + resid_pdrop=0.0, + embd_pdrop=0.0, + attention_dropout=0.0, + hidden_act="silu", + max_position_embeddings=4096, + original_max_position_embeddings=4096, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + bos_token_id=1, + eos_token_id=32000, + pad_token_id=32000, + sliding_window=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.resid_pdrop = resid_pdrop + self.embd_pdrop = embd_pdrop + self.attention_dropout = attention_dropout + self.hidden_act = hidden_act + self.max_position_embeddings = max_position_embeddings + self.original_max_position_embeddings = original_max_position_embeddings + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self._rope_scaling_adjustment() + self._rope_scaling_validation() + self.sliding_window = sliding_window + + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + def _rope_scaling_adjustment(self): + """ + Adjust the `type` of the `rope_scaling` configuration for backward compatibility. + """ + if self.rope_scaling is None: + return + + rope_scaling_type = self.rope_scaling.get("type", None) + + # For backward compatibility if previous version used "su" or "yarn" + if rope_scaling_type is not None and rope_scaling_type in ["su", "yarn"]: + self.rope_scaling["type"] = "longrope" + + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. 
+ """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 3: + raise ValueError( + "`rope_scaling` must be a dictionary with three fields, `type`, `short_factor` and `long_factor`, " + f"got {self.rope_scaling}" + ) + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_short_factor = self.rope_scaling.get("short_factor", None) + rope_scaling_long_factor = self.rope_scaling.get("long_factor", None) + if rope_scaling_type is None or rope_scaling_type not in ["longrope"]: + raise ValueError(f"`rope_scaling`'s type field must be one of ['longrope'], got {rope_scaling_type}") + if not ( + isinstance(rope_scaling_short_factor, list) + and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor) + ): + raise ValueError( + f"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}" + ) + if not len(rope_scaling_short_factor) == self.hidden_size // self.num_attention_heads // 2: + raise ValueError( + f"`rope_scaling`'s short_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_short_factor)}" + ) + if not ( + isinstance(rope_scaling_long_factor, list) + and all(isinstance(x, (int, float)) for x in rope_scaling_long_factor) + ): + raise ValueError( + f"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}" + ) + if not len(rope_scaling_long_factor) == self.hidden_size // self.num_attention_heads // 2: + raise ValueError( + f"`rope_scaling`'s long_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_long_factor)}" + ) diff --git a/mtraining/models/phi3/lc_config/configuration_phi3.py b/mtraining/models/phi3/lc_config/configuration_phi3.py new file mode 100644 index 0000000..7804010 --- /dev/null +++ b/mtraining/models/phi3/lc_config/configuration_phi3.py @@ -0,0 +1,227 @@ +# coding=utf-8 +# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" Phi-3 model configuration""" + + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + +PHI3_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "microsoft/Phi-3-mini-4k-instruct": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/resolve/main/config.json", + "microsoft/Phi-3-mini-128k-instruct": "https://huggingface.co/microsoft/Phi-3-mini-128k-instruct/resolve/main/config.json", +} + + +class Phi3Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Phi3Model`]. It is used to instantiate a Phi-3 + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the + [microsoft/Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct). 
+ + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 32064): + Vocabulary size of the Phi-3 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Phi3Model`]. + hidden_size (`int`, *optional*, defaults to 3072): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 8192): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + resid_pdrop (`float`, *optional*, defaults to 0.0): + Dropout probability for mlp outputs. + embd_pdrop (`int`, *optional*, defaults to 0.0): + The dropout ratio for the embeddings. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio after computing the attention scores. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 4096): + The maximum sequence length that this model might ever be used with. + original_max_position_embeddings (`int`, *optional*, defaults to 4096): + The maximum sequence length that this model was trained with. This is used to determine the size of the + original RoPE embeddings when using long scaling. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon value used for the RMSNorm. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. Whether to tie weight embeddings or not. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`dict`, *optional*): + The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must + contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be `longrope` and + the `short_factor` and `long_factor` must be lists of numbers with the same length as the hidden size + divided by the number of attention heads divided by 2. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the "beginning-of-sequence" token. 
+ eos_token_id (`int`, *optional*, defaults to 32000): + The id of the "end-of-sequence" token. + pad_token_id (`int`, *optional*, defaults to 32000): + The id of the padding token. + sliding_window (`int`, *optional*): + Sliding window attention window size. If `None`, no sliding window is applied. + + Example: + + ```python + >>> from transformers import Phi3Model, Phi3Config + + >>> # Initializing a Phi-3 style configuration + >>> configuration = Phi3Config.from_pretrained("microsoft/Phi-3-mini-4k-instruct") + + >>> # Initializing a model from the configuration + >>> model = Phi3Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "phi3" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32064, + hidden_size=3072, + intermediate_size=8192, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + resid_pdrop=0.0, + embd_pdrop=0.0, + attention_dropout=0.0, + hidden_act="silu", + max_position_embeddings=4096, + original_max_position_embeddings=4096, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + bos_token_id=1, + eos_token_id=32000, + pad_token_id=32000, + sliding_window=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.resid_pdrop = resid_pdrop + self.embd_pdrop = embd_pdrop + self.attention_dropout = attention_dropout + self.hidden_act = hidden_act + self.max_position_embeddings = max_position_embeddings + self.original_max_position_embeddings = original_max_position_embeddings + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self._rope_scaling_adjustment() + self._rope_scaling_validation() + self.sliding_window = sliding_window + + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + def _rope_scaling_adjustment(self): + """ + Adjust the `type` of the `rope_scaling` configuration for backward compatibility. + """ + if self.rope_scaling is None: + return + + rope_scaling_type = self.rope_scaling.get("type", None) + + # For backward compatibility if previous version used "su" or "yarn" + if rope_scaling_type is not None and rope_scaling_type in ["su", "yarn"]: + self.rope_scaling["type"] = "longrope" + + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. 
+ """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 3: + raise ValueError( + "`rope_scaling` must be a dictionary with three fields, `type`, `short_factor` and `long_factor`, " + f"got {self.rope_scaling}" + ) + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_short_factor = self.rope_scaling.get("short_factor", None) + rope_scaling_long_factor = self.rope_scaling.get("long_factor", None) + if rope_scaling_type is None or rope_scaling_type not in ["longrope"]: + raise ValueError(f"`rope_scaling`'s type field must be one of ['longrope'], got {rope_scaling_type}") + if not ( + isinstance(rope_scaling_short_factor, list) + and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor) + ): + raise ValueError( + f"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}" + ) + if not len(rope_scaling_short_factor) == self.hidden_size // self.num_attention_heads // 2: + raise ValueError( + f"`rope_scaling`'s short_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_short_factor)}" + ) + if not ( + isinstance(rope_scaling_long_factor, list) + and all(isinstance(x, (int, float)) for x in rope_scaling_long_factor) + ): + raise ValueError( + f"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}" + ) + if not len(rope_scaling_long_factor) == self.hidden_size // self.num_attention_heads // 2: + raise ValueError( + f"`rope_scaling`'s long_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_long_factor)}" + ) diff --git a/mtraining/models/phi3/lc_config_mini/configuration_phi3.py b/mtraining/models/phi3/lc_config_mini/configuration_phi3.py new file mode 100644 index 0000000..7804010 --- /dev/null +++ b/mtraining/models/phi3/lc_config_mini/configuration_phi3.py @@ -0,0 +1,227 @@ +# coding=utf-8 +# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" Phi-3 model configuration""" + + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + +PHI3_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "microsoft/Phi-3-mini-4k-instruct": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/resolve/main/config.json", + "microsoft/Phi-3-mini-128k-instruct": "https://huggingface.co/microsoft/Phi-3-mini-128k-instruct/resolve/main/config.json", +} + + +class Phi3Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Phi3Model`]. It is used to instantiate a Phi-3 + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the + [microsoft/Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct). 
+ + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 32064): + Vocabulary size of the Phi-3 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Phi3Model`]. + hidden_size (`int`, *optional*, defaults to 3072): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 8192): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + resid_pdrop (`float`, *optional*, defaults to 0.0): + Dropout probability for mlp outputs. + embd_pdrop (`int`, *optional*, defaults to 0.0): + The dropout ratio for the embeddings. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio after computing the attention scores. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 4096): + The maximum sequence length that this model might ever be used with. + original_max_position_embeddings (`int`, *optional*, defaults to 4096): + The maximum sequence length that this model was trained with. This is used to determine the size of the + original RoPE embeddings when using long scaling. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon value used for the RMSNorm. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. Whether to tie weight embeddings or not. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`dict`, *optional*): + The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must + contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be `longrope` and + the `short_factor` and `long_factor` must be lists of numbers with the same length as the hidden size + divided by the number of attention heads divided by 2. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the "beginning-of-sequence" token. 
+ eos_token_id (`int`, *optional*, defaults to 32000): + The id of the "end-of-sequence" token. + pad_token_id (`int`, *optional*, defaults to 32000): + The id of the padding token. + sliding_window (`int`, *optional*): + Sliding window attention window size. If `None`, no sliding window is applied. + + Example: + + ```python + >>> from transformers import Phi3Model, Phi3Config + + >>> # Initializing a Phi-3 style configuration + >>> configuration = Phi3Config.from_pretrained("microsoft/Phi-3-mini-4k-instruct") + + >>> # Initializing a model from the configuration + >>> model = Phi3Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "phi3" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32064, + hidden_size=3072, + intermediate_size=8192, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + resid_pdrop=0.0, + embd_pdrop=0.0, + attention_dropout=0.0, + hidden_act="silu", + max_position_embeddings=4096, + original_max_position_embeddings=4096, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + bos_token_id=1, + eos_token_id=32000, + pad_token_id=32000, + sliding_window=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.resid_pdrop = resid_pdrop + self.embd_pdrop = embd_pdrop + self.attention_dropout = attention_dropout + self.hidden_act = hidden_act + self.max_position_embeddings = max_position_embeddings + self.original_max_position_embeddings = original_max_position_embeddings + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self._rope_scaling_adjustment() + self._rope_scaling_validation() + self.sliding_window = sliding_window + + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + def _rope_scaling_adjustment(self): + """ + Adjust the `type` of the `rope_scaling` configuration for backward compatibility. + """ + if self.rope_scaling is None: + return + + rope_scaling_type = self.rope_scaling.get("type", None) + + # For backward compatibility if previous version used "su" or "yarn" + if rope_scaling_type is not None and rope_scaling_type in ["su", "yarn"]: + self.rope_scaling["type"] = "longrope" + + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. 
+ """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 3: + raise ValueError( + "`rope_scaling` must be a dictionary with three fields, `type`, `short_factor` and `long_factor`, " + f"got {self.rope_scaling}" + ) + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_short_factor = self.rope_scaling.get("short_factor", None) + rope_scaling_long_factor = self.rope_scaling.get("long_factor", None) + if rope_scaling_type is None or rope_scaling_type not in ["longrope"]: + raise ValueError(f"`rope_scaling`'s type field must be one of ['longrope'], got {rope_scaling_type}") + if not ( + isinstance(rope_scaling_short_factor, list) + and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor) + ): + raise ValueError( + f"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}" + ) + if not len(rope_scaling_short_factor) == self.hidden_size // self.num_attention_heads // 2: + raise ValueError( + f"`rope_scaling`'s short_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_short_factor)}" + ) + if not ( + isinstance(rope_scaling_long_factor, list) + and all(isinstance(x, (int, float)) for x in rope_scaling_long_factor) + ): + raise ValueError( + f"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}" + ) + if not len(rope_scaling_long_factor) == self.hidden_size // self.num_attention_heads // 2: + raise ValueError( + f"`rope_scaling`'s long_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_long_factor)}" + ) diff --git a/mtraining/models/phi3/modelling_phi.py b/mtraining/models/phi3/modelling_phi.py new file mode 100644 index 0000000..0965eaf --- /dev/null +++ b/mtraining/models/phi3/modelling_phi.py @@ -0,0 +1,1185 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/phi3/modular_phi3.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_phi3.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from typing import Callable, List, Optional, Tuple, Union + +import torch +from torch import nn + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from transformers.generation import GenerationMixin +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from transformers.processing_utils import Unpack +from transformers.utils import ( + LossKwargs, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from transformers.utils.deprecation import deprecate_kwarg + +from .configuration_phi3 import Phi3Config + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "microsoft/Phi-3-mini-4k-instruct" +_CONFIG_FOR_DOC = "Phi3Config" + +PHI_ATTN_FUNCS = ALL_ATTENTION_FUNCTIONS.copy() + +class Phi3MLP(nn.Module): + def __init__(self, config): + super().__init__() + + self.config = config + self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False) + self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) + self.activation_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + up_states = self.gate_up_proj(hidden_states) + + gate, up_states = up_states.chunk(2, dim=-1) + up_states = up_states * self.activation_fn(gate) + + return self.down_proj(up_states) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class Phi3Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Phi3Config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.num_key_value_heads = config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + op_size = config.num_attention_heads * self.head_dim + 2 * (config.num_key_value_heads * self.head_dim) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + self.qkv_proj = nn.Linear(config.hidden_size, op_size, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + qkv = self.qkv_proj(hidden_states) + query_pos = self.config.num_attention_heads * self.head_dim + query_states = qkv[..., :query_pos] + key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim] + value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :] + + query_states = query_states.view(hidden_shape).transpose(1, 2) + key_states = key_states.view(hidden_shape).transpose(1, 2) + value_states = value_states.view(hidden_shape).transpose(1, 2) + 
+ cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = PHI_ATTN_FUNCS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=getattr(self.config, "sliding_window", None), + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Phi3RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Phi3RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Phi3DecoderLayer(nn.Module): + def __init__(self, config: Phi3Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Phi3Attention(config=config, layer_idx=layer_idx) + self.mlp = Phi3MLP(config) + self.input_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.config = config + self.resid_attn_dropout = nn.Dropout(config.resid_pdrop) + self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): + input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range + `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + past_key_value (`Cache`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + self.resid_attn_dropout(hidden_states) # main diff with Llama + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + self.resid_mlp_dropout(hidden_states) # main diff with Llama + + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +class Phi3RotaryEmbedding(nn.Module): + def __init__(self, config: Phi3Config, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + # This .to() is needed if the model has been moved to a device after being initialized (because + # the buffer is automatically moved, but not the original copy) + self.original_inv_freq = 
self.original_inv_freq.to(device) + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + elif self.rope_type == "longrope": + self._longrope_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + def _longrope_frequency_update(self, position_ids, device): + """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise.""" + seq_len = torch.max(position_ids) + 1 + if hasattr(self.config, "original_max_position_embeddings"): + original_max_position_embeddings = self.config.original_max_position_embeddings + else: + original_max_position_embeddings = self.config.max_position_embeddings + if seq_len > original_max_position_embeddings: + if not hasattr(self, "long_inv_freq"): + self.long_inv_freq, _ = self.rope_init_fn( + self.config, device, seq_len=original_max_position_embeddings + 1 + ) + self.register_buffer("inv_freq", self.long_inv_freq, persistent=False) + else: + # This .to() is needed if the model has been moved to a device after being initialized (because + # the buffer is automatically moved, but not the original copy) + self.original_inv_freq = self.original_inv_freq.to(device) + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + + +PHI3_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Phi3Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
+""" + + +@add_start_docstrings( + "The bare Phi3 Model outputting raw hidden-states without any specific head on top.", + PHI3_START_DOCSTRING, +) +class Phi3PreTrainedModel(PreTrainedModel): + config_class = Phi3Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Phi3DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_attention_backend = True + _version = "0.0.5" + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +PHI3_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. 
If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare Phi3 Model outputting raw hidden-states without any specific head on top.", + PHI3_START_DOCSTRING, +) +class Phi3Model(Phi3PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`Phi3DecoderLayer`] + + Args: + config: Phi3Config + """ + + def __init__(self, config: Phi3Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Phi3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Phi3RotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
+ ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + output = BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + return output if return_dict else output.to_tuple() + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and past_key_values is not None: + is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Phi3. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: + target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + config: Phi3Config, + past_key_values: Cache, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. 
+ config (`Phi3Config`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + if config.sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: + sliding_attend_mask = torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - config.sliding_window + ) + diagonal_attend_mask.bitwise_or_(sliding_attend_mask) + causal_mask *= diagonal_attend_mask + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + return causal_mask + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
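+# Illustrative sketch (comment only, not executed): how the static helper defined above,
+# `_prepare_4d_causal_attention_mask_with_cache_position`, could be called on a toy input.
+# The values are hypothetical and only meant to show the expected shapes; `config` stands for
+# any `Phi3Config` instance, and masked positions end up at `torch.finfo(dtype).min`.
+#
+#   attention_mask = torch.tensor([[0, 1, 1, 1]])   # one left-padded sequence of length 4
+#   causal_mask = Phi3Model._prepare_4d_causal_attention_mask_with_cache_position(
+#       attention_mask,
+#       sequence_length=4,
+#       target_length=4,
+#       dtype=torch.float32,
+#       device=attention_mask.device,
+#       cache_position=torch.arange(4),
+#       batch_size=1,
+#       config=config,
+#       past_key_values=None,
+#   )
+#   # causal_mask.shape == (1, 1, 4, 4)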
+ + +class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + + def __init__(self, config): + super().__init__(config) + self.model = Phi3Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, Phi3ForCausalLM + + >>> model = Phi3ForCausalLM.from_pretrained("meta-phi3/Phi3-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-phi3/Phi3-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + logits_to_keep=None, + **kwargs, + ): + # Overwritten -- this model may need to switch between short and long rope, invalidating the cache in the + # process + + # When the first time input length reached long and short factor switching point, enforce re-compute cache + # It will cause downside of slower at this single token position, however, better than current failure. + if ( + past_key_values + and self.config.rope_scaling + and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1 + ): + past_length = cache_position[0] + if past_length <= self.config.original_max_position_embeddings: + past_key_values = None + + model_inputs = super().prepare_inputs_for_generation( + input_ids=input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + use_cache=use_cache, + logits_to_keep=logits_to_keep, + **kwargs, + ) + return model_inputs + + +@add_start_docstrings( + """ + The Phi3 Model transformer with a sequence classification head on top (linear layer). + + [`Phi3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). 
+ """, + PHI3_START_DOCSTRING, +) +class Phi3ForSequenceClassification(Phi3PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = Phi3Model(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) + + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + The Phi3 Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. 
+ """, + PHI3_START_DOCSTRING, +) +class Phi3ForTokenClassification(Phi3PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = Phi3Model(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.config) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "Phi3PreTrainedModel", + "Phi3Model", + "Phi3ForCausalLM", + "Phi3ForSequenceClassification", + "Phi3ForTokenClassification", +] \ No newline at end of file diff --git a/mtraining/models/phi3/modelling_phi_legacy.py b/mtraining/models/phi3/modelling_phi_legacy.py new file mode 100644 index 0000000..10d1c7d --- /dev/null +++ b/mtraining/models/phi3/modelling_phi_legacy.py @@ -0,0 +1,1568 @@ +# coding=utf-8 +# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" PyTorch Phi-3 model.""" + +import inspect +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from .configuration_phi3 import Phi3Config + +from nnscaler.graph.parser.register import register_op +from nnscaler.ir import IRTensor + + + +logger = logging.get_logger(__name__) + +# Transformers scans dependencies in the modeling file, causing issues on conditional loading. The regex only ignores try/catch blocks, but not if statements +# if is_flash_attn_2_available(): +_flash_supports_window_size = False +try: + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) +except ImportError as error: + logger.warning( + f"`flash-attention` package not found, consider installing for better performance: {error}." + ) + if not _flash_supports_window_size: + logger.warning( + "Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`." 
+ ) + +_CHECKPOINT_FOR_DOC = "microsoft/Phi-3-mini-4k-instruct" +_CONFIG_FOR_DOC = "Phi3Config" + +PHI3_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "microsoft/Phi-3-mini-4k-instruct", + "microsoft/Phi-3-mini-128k-instruct", + # See all Phi-3 models at https://huggingface.co/models?filter=Phi-3 +] + + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Phi3 +class Phi3RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Phi3RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +# Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with gemma->phi3, Gemma->Phi3 +class Phi3RotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.register_buffer("inv_freq", None, persistent=False) + + @torch.no_grad() + def forward(self, x, position_ids, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if self.inv_freq is None: + self.inv_freq = 1.0 / ( + self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim) + ) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class Phi3LongRoPEScaledRotaryEmbedding(Phi3RotaryEmbedding): + def __init__(self, dim, config, device=None): + super().__init__(dim, config.max_position_embeddings, config.rope_theta, device) + + self.short_factor = config.rope_scaling["short_factor"] + self.long_factor = config.rope_scaling["long_factor"] + self.original_max_position_embeddings = config.original_max_position_embeddings + + @torch.no_grad() + def forward(self, x, position_ids, seq_len=None): + seq_len = torch.max(position_ids) + 1 + if seq_len > self.original_max_position_embeddings: + ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device) + else: + ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device) + + inv_freq_shape = 
torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim + self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape) + + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + + scale = self.max_position_embeddings / self.original_max_position_embeddings + if scale <= 1.0: + scaling_factor = 1.0 + else: + scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings)) + + cos = emb.cos() * scaling_factor + sin = emb.sin() * scaling_factor + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class Phi3MLP(nn.Module): + def __init__(self, config): + super().__init__() + + self.config = config + self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False) + self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) + + self.activation_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + up_states = self.gate_up_proj(hidden_states) + + gate, up_states = up_states.chunk(2, dim=-1) + up_states = up_states * self.activation_fn(gate) + + return self.down_proj(up_states) + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class Phi3Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Phi3Config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.original_max_position_embeddings = config.original_max_position_embeddings + self.rope_theta = config.rope_theta + self.rope_scaling = config.rope_scaling + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." 
+ ) + + op_size = self.num_heads * self.head_dim + 2 * (self.num_key_value_heads * self.head_dim) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + self.qkv_proj = nn.Linear(self.hidden_size, op_size, bias=False) + self._init_rope() + + def _init_rope(self): + if self.rope_scaling is None: + self.rotary_emb = Phi3RotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + if scaling_type == "longrope": + self.rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(self.head_dim, self.config) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + logger.warning_once("You are not running the flash-attention implementation, expect numerical differences.") + + bsz, q_len, _ = hidden_states.size() + + qkv = self.qkv_proj(hidden_states) + query_pos = self.num_heads * self.head_dim + query_states = qkv[..., :query_pos] + key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim] + value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :] + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." 
+ ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class Phi3FlashAttention2(Phi3Attention): + """ + Phi-3 flash attention module. This module inherits from `Phi3Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
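+        # In practice the flag below is only True for flash_attn < 2.1; `_flash_attention_forward`
+        # uses it to drop the causal flag when a single query token is decoded (query_length == 1).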
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # Phi3FlashAttention2 attention does not support output_attentions + + if not _flash_supports_window_size: + logger.warning_once( + "The current flash attention version does not support sliding window attention. Please use `attn_implementation='eager'` or upgrade flash-attn library." + ) + raise ValueError("The current flash attention version does not support sliding window attention.") + + output_attentions = False + + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop("padding_mask") + + bsz, q_len, _ = hidden_states.size() + + qkv = self.qkv_proj(hidden_states) + query_pos = self.num_heads * self.head_dim + query_states = qkv[..., :query_pos] + key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim] + value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :] + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + # Because the input can be padded, the absolute sequence length depends on the max position id. 
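+        # i.e. the rotary cos/sin tables are sized by whichever is larger: the key/value length
+        # (including cache) or the last position id seen in the batch, plus one.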
+ rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 + cos, sin = self.rotary_emb(value_states, position_ids, seq_len=rotary_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + use_sliding_windows = ( + _flash_supports_window_size + and getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + ) + + if past_key_value is not None: + # Activate slicing cache only if the config has a value `sliding_windows` attribute + cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 + if ( + getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + and cache_has_contents + ): + slicing_tokens = 1 - self.config.sliding_window + + past_key = past_key_value[self.layer_idx][0] + past_value = past_key_value[self.layer_idx][1] + + past_key = past_key[:, :, slicing_tokens:, :].contiguous() + past_value = past_value[:, :, slicing_tokens:, :].contiguous() + + if past_key.shape[-2] != self.config.sliding_window - 1: + raise ValueError( + f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" + f" {past_key.shape}" + ) + + if attention_mask is not None: + attention_mask = attention_mask[:, slicing_tokens:] + attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) + + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_dropout = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. + + if query_states.dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.qkv_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=attn_dropout, + use_sliding_windows=use_sliding_windows, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._flash_attention_forward + def _flash_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None, + use_sliding_windows=False, + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + use_sliding_windows (`bool`, *optional*): + Whether to activate sliding window attention. + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. 
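+            # With the old top-left aligned mask, a causal mask over a single decoded query row
+            # would hide every cached key except the first, hence the `query_length != 1` guard.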
+ causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + if not use_sliding_windows: + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + else: + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + window_size=(self.config.sliding_window, self.config.sliding_window), + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + if not use_sliding_windows: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + else: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + window_size=(self.config.sliding_window, self.config.sliding_window), + ) + + return attn_output + + # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._upad_input + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape + + # On the first iteration we need to properly re-create the padding mask + # by slicing it on the proper place + if kv_seq_len != attention_mask.shape[-1]: + attention_mask_num_tokens = attention_mask.shape[-1] + attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :] + + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + + key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) + value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) + + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. 
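+            # Keeping only the last `query_length` columns aligns the 2D padding mask with the
+            # query tokens before `unpad_input` is applied below.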
+ attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +# copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Phi3 +# TODO @Arthur no longer copied from LLama after static cache +class Phi3SdpaAttention(Phi3Attention): + """ + Phi3 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `Phi3Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from Phi3Attention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "Phi3Model is using Phi3SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = hidden_states.size() + + qkv = self.qkv_proj(hidden_states) + query_pos = self.num_heads * self.head_dim + query_states = qkv[..., :query_pos] + key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim] + value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :] + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with 
custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + is_causal=self.is_causal and attention_mask is None and q_len > 1, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +PHI3_ATTENTION_CLASSES = { + "eager": Phi3Attention, + "flash_attention_2": Phi3FlashAttention2, + "sdpa": Phi3SdpaAttention, +} + +class Phi3DecoderLayer(nn.Module): + def __init__(self, config: Phi3Config, layer_idx: int): + super().__init__() + + self.config = config + self.self_attn = PHI3_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx) + + self.mlp = Phi3MLP(config) + self.input_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.resid_attn_dropout = nn.Dropout(config.resid_pdrop) + self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop) + self.post_attention_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + """ + Args: + hidden_states (`torch.FloatTensor`): + input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range + `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). 
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + attn_outputs, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = residual + self.resid_attn_dropout(attn_outputs) + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + self.resid_mlp_dropout(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +PHI3_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Phi3Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Phi-3 model outputting raw hidden-states without any specific head on top.", + PHI3_START_DOCSTRING, +) +class Phi3PreTrainedModel(PreTrainedModel): + config_class = Phi3Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Phi3DecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = False + _supports_cache_class = True + + _version = "0.0.5" + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +PHI3_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. 
+ + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare Phi-3 model outputting raw hidden-states without any specific head on top.", + PHI3_START_DOCSTRING, +) +class Phi3Model(Phi3PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`Phi3DecoderLayer`] + + Args: + config: Phi3Config + """ + + def __init__(self, config: Phi3Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.embed_dropout = nn.Dropout(config.embd_pdrop) + self.layers = nn.ModuleList( + [Phi3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._attn_implementation = config._attn_implementation + self.norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.shape[:2] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + past_key_values_length = 0 + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: + is_padding_right = attention_mask[:, -1].sum().item() != batch_size + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Phi3. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + + if self._attn_implementation == "flash_attention_2": + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + sliding_window=self.config.sliding_window, + ) + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class Phi3ForCausalLM(Phi3PreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi3 + def __init__(self, config): + super().__init__(config) + self.model = Phi3Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, 
bias=False) + + # Initialize weights and apply final processing + self.post_init() + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings + def get_input_embeddings(self): + return self.model.embed_tokens + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings + def get_output_embeddings(self): + return self.lm_head + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder + def set_decoder(self, decoder): + self.model = decoder + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder + def get_decoder(self): + return self.model + + # Ignore copy + @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, Phi3ForCausalLM + + >>> model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct") + >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct") + + >>> prompt = "This is an example script ." + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + 'This is an example script .\n Certainly! 
Below is a sample script that demonstrates a simple task, such as calculating the sum' + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + # Copied from transformers.models.persimmon.modeling_persimmon.PersimmonForCausalLM.prepare_inputs_for_generation + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + ): + if past_key_values is not None: + if isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as + # input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 
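+            # (i.e. when the cache has a fixed maximum size, keep only the most recent `max_cache_length`
+            # mask positions so the mask stays aligned with the keys/values actually held in the cache)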
+ if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM._reorder_cache + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + +@add_start_docstrings( + """ + The [`Phi3Model`] with a sequence classification head on top (linear layer). + + [`Phi3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + PHI3_START_DOCSTRING, +) +# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Phi3, LLAMA->PHI3, self.transformer->self.model, transformer_outputs->model_outputs +class Phi3ForSequenceClassification(Phi3PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = Phi3Model(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. 
Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + model_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = model_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + model_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=model_outputs.past_key_values, + hidden_states=model_outputs.hidden_states, + attentions=model_outputs.attentions, + ) + + +@add_start_docstrings( + """ + [`Phi3Model`] with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. 
+    """,
+    PHI3_START_DOCSTRING,
+)
+# Copied from transformers.models.mpt.modeling_mpt.MptForTokenClassification with Mpt->Phi3,MPT->PHI3,self.transformer->self.model,transformer_outputs->model_outputs
+class Phi3ForTokenClassification(Phi3PreTrainedModel):
+    def __init__(self, config: Phi3Config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+
+        self.model = Phi3Model(config)
+        if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
+            classifier_dropout = config.classifier_dropout
+        elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
+            classifier_dropout = config.hidden_dropout
+        else:
+            classifier_dropout = 0.1
+        self.dropout = nn.Dropout(classifier_dropout)
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=TokenClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **deprecated_arguments,
+    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the token classification loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`; a Cross-Entropy loss is computed over the individual tokens.
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + model_outputs = self.model( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = model_outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + batch_size, seq_length = labels.shape + loss_fct = CrossEntropyLoss() + loss = loss_fct( + logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length) + ) + + if not return_dict: + output = (logits,) + model_outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=model_outputs.hidden_states, + attentions=model_outputs.attentions, + ) diff --git a/mtraining/models/qwen2/__init__.py b/mtraining/models/qwen2/__init__.py new file mode 100644 index 0000000..f9c07a0 --- /dev/null +++ b/mtraining/models/qwen2/__init__.py @@ -0,0 +1,7 @@ +from .modeling_qwen2 import ( + Qwen2ForCausalLM, Qwen2Attention, + apply_rotary_pos_emb, repeat_kv, + QWEN_ATTN_FUNCS +) + +from .configuration_qwen2 import Qwen2Config \ No newline at end of file diff --git a/mtraining/models/qwen2/configuration_qwen2.py b/mtraining/models/qwen2/configuration_qwen2.py new file mode 100644 index 0000000..1c85806 --- /dev/null +++ b/mtraining/models/qwen2/configuration_qwen2.py @@ -0,0 +1,196 @@ +# coding=utf-8 +# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Qwen2 model configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_rope_utils import rope_config_validation +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +class Qwen2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a + Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of + Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 151936): + Vocabulary size of the Qwen2 model. 
Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Qwen2Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 22016): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 32): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 32768): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. 
If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + use_sliding_window (`bool`, *optional*, defaults to `False`): + Whether to use sliding window attention. + sliding_window (`int`, *optional*, defaults to 4096): + Sliding window attention (SWA) window size. If not specified, will default to `4096`. + max_window_layers (`int`, *optional*, defaults to 28): + The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + + ```python + >>> from transformers import Qwen2Model, Qwen2Config + + >>> # Initializing a Qwen2 style configuration + >>> configuration = Qwen2Config() + + >>> # Initializing a model from the Qwen2-7B style configuration + >>> model = Qwen2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen2" + keys_to_ignore_at_inference = ["past_key_values"] + + # Default tensor parallel plan for base model `Qwen2` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + + def __init__( + self, + vocab_size=151936, + hidden_size=4096, + intermediate_size=22016, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=32, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=28, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.use_sliding_window = use_sliding_window + self.sliding_window = sliding_window if use_sliding_window else None + self.max_window_layers = max_window_layers + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + 
self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_dropout = attention_dropout + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) \ No newline at end of file diff --git a/mtraining/models/qwen2/lc_config/configuration_qwen2.py b/mtraining/models/qwen2/lc_config/configuration_qwen2.py new file mode 100644 index 0000000..1c85806 --- /dev/null +++ b/mtraining/models/qwen2/lc_config/configuration_qwen2.py @@ -0,0 +1,196 @@ +# coding=utf-8 +# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Qwen2 model configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_rope_utils import rope_config_validation +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +class Qwen2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a + Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of + Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 151936): + Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Qwen2Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 22016): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 32): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. 
For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 32768): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. 
Scaling factor applied to high frequency components of the RoPE + use_sliding_window (`bool`, *optional*, defaults to `False`): + Whether to use sliding window attention. + sliding_window (`int`, *optional*, defaults to 4096): + Sliding window attention (SWA) window size. If not specified, will default to `4096`. + max_window_layers (`int`, *optional*, defaults to 28): + The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + + ```python + >>> from transformers import Qwen2Model, Qwen2Config + + >>> # Initializing a Qwen2 style configuration + >>> configuration = Qwen2Config() + + >>> # Initializing a model from the Qwen2-7B style configuration + >>> model = Qwen2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen2" + keys_to_ignore_at_inference = ["past_key_values"] + + # Default tensor parallel plan for base model `Qwen2` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + + def __init__( + self, + vocab_size=151936, + hidden_size=4096, + intermediate_size=22016, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=32, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=28, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.use_sliding_window = use_sliding_window + self.sliding_window = sliding_window if use_sliding_window else None + self.max_window_layers = max_window_layers + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_dropout = attention_dropout + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) \ No newline at end of file diff --git a/mtraining/models/qwen2/lc_config_mini/configuration_qwen2.py b/mtraining/models/qwen2/lc_config_mini/configuration_qwen2.py new file mode 100644 index 0000000..1c85806 --- /dev/null +++ b/mtraining/models/qwen2/lc_config_mini/configuration_qwen2.py @@ -0,0 +1,196 @@ +# coding=utf-8 +# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Qwen2 model configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_rope_utils import rope_config_validation +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +class Qwen2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a + Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of + Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 151936): + Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Qwen2Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 22016): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 32): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 32768): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. 
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + use_sliding_window (`bool`, *optional*, defaults to `False`): + Whether to use sliding window attention. + sliding_window (`int`, *optional*, defaults to 4096): + Sliding window attention (SWA) window size. If not specified, will default to `4096`. + max_window_layers (`int`, *optional*, defaults to 28): + The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. 
+ + ```python + >>> from transformers import Qwen2Model, Qwen2Config + + >>> # Initializing a Qwen2 style configuration + >>> configuration = Qwen2Config() + + >>> # Initializing a model from the Qwen2-7B style configuration + >>> model = Qwen2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen2" + keys_to_ignore_at_inference = ["past_key_values"] + + # Default tensor parallel plan for base model `Qwen2` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + + def __init__( + self, + vocab_size=151936, + hidden_size=4096, + intermediate_size=22016, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=32, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=28, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.use_sliding_window = use_sliding_window + self.sliding_window = sliding_window if use_sliding_window else None + self.max_window_layers = max_window_layers + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_dropout = attention_dropout + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) \ No newline at end of file diff --git a/mtraining/models/qwen2/mi_config/configuration_qwen2.py b/mtraining/models/qwen2/mi_config/configuration_qwen2.py new file mode 100644 index 0000000..78e7d61 --- /dev/null +++ b/mtraining/models/qwen2/mi_config/configuration_qwen2.py @@ -0,0 +1,185 @@ +# coding=utf-8 +# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
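+
+# Note: this `mi_config` copy of the Qwen2 configuration is the one imported by the MInference-enabled
+# `mi_config/modeling_qwen2.py`; compared with the top-level `configuration_qwen2.py`, it appears to differ
+# mainly in omitting the `base_model_tp_plan` tensor-parallel plan.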
+"""Qwen2 model configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_rope_utils import rope_config_validation +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +class Qwen2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a + Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of + Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 151936): + Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Qwen2Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 22016): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 32): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 32768): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. 
Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + use_sliding_window (`bool`, *optional*, defaults to `False`): + Whether to use sliding window attention. + sliding_window (`int`, *optional*, defaults to 4096): + Sliding window attention (SWA) window size. If not specified, will default to `4096`. + max_window_layers (`int`, *optional*, defaults to 28): + The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. 
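+
+    For illustration only (the values below are not tuned recommendations), the `rope_scaling` dictionary
+    described above could be used to enable YaRN scaling together with a matching `max_position_embeddings`:
+
+    ```python
+    >>> from transformers import Qwen2Config
+
+    >>> # Stretch a 32768-token pre-training window by 4x with YaRN (illustrative values)
+    >>> long_context_config = Qwen2Config(
+    ...     max_position_embeddings=131072,
+    ...     rope_scaling={"rope_type": "yarn", "factor": 4.0},
+    ... )
+    ```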
+ + ```python + >>> from transformers import Qwen2Model, Qwen2Config + + >>> # Initializing a Qwen2 style configuration + >>> configuration = Qwen2Config() + + >>> # Initializing a model from the Qwen2-7B style configuration + >>> model = Qwen2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen2" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=151936, + hidden_size=4096, + intermediate_size=22016, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=32, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=28, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.use_sliding_window = use_sliding_window + self.sliding_window = sliding_window if use_sliding_window else None + self.max_window_layers = max_window_layers + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_dropout = attention_dropout + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/mtraining/models/qwen2/mi_config/modeling_qwen2.py b/mtraining/models/qwen2/mi_config/modeling_qwen2.py new file mode 100644 index 0000000..253215f --- /dev/null +++ b/mtraining/models/qwen2/mi_config/modeling_qwen2.py @@ -0,0 +1,1490 @@ +# coding=utf-8 +# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
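+
+# NOTE: This file is a vendored copy of the Hugging Face Qwen2 modeling code, lightly adapted for
+# MInference-style sparse attention: the per-layer sparse pattern config
+# ("Qwen2.5_3B_kv_out_v32_fit_o_best_pattern.json") is loaded into MINFERENCE_CONFIG at import time
+# and consumed by Qwen2FlashAttention2.forward during prefill.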
+"""PyTorch Qwen2 model.""" + +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from transformers.generation import GenerationMixin +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from .configuration_qwen2 import Qwen2Config + +assert is_flash_attn_2_available() +from flash_attn import flash_attn_with_kvcache + +from mtraining_sparse_ops import get_minference_config, minference_flash_attn_func +MINFERENCE_CONFIG = get_minference_config("Qwen2.5_3B_kv_out_v32_fit_o_best_pattern.json") + + +logger = logging.get_logger(__name__) + + +_CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B" +_CONFIG_FOR_DOC = "Qwen2Config" + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Qwen2 +class Qwen2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Qwen2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Qwen2 +class Qwen2RotaryEmbedding(nn.Module): + def __init__( + self, + dim=None, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + rope_type="default", + config: Optional[Qwen2Config] = None, + ): + super().__init__() + # TODO (joao): remove the `if` below, only used for BC + self.rope_kwargs = {} + if config is None: + logger.warning_once( + "`Qwen2RotaryEmbedding` can now be fully parameterized by passing the model config through the " + "`config` argument. 
All other arguments will be removed in v4.46" + ) + self.rope_kwargs = { + "rope_type": rope_type, + "factor": scaling_factor, + "dim": dim, + "base": base, + "max_position_embeddings": max_position_embeddings, + } + self.rope_type = rope_type + self.max_seq_len_cached = max_position_embeddings + self.original_max_seq_len = max_position_embeddings + else: + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. 
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+    """
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2
+class Qwen2MLP(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        self.act_fn = ACT2FN[config.hidden_act]
+
+    def forward(self, hidden_state):
+        return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
+
+
+# Copied from transformers.models.llama.modeling_llama.repeat_kv
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+class Qwen2Attention(nn.Module):
+    """
+    Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
+    and "Generating Long Sequences with Sparse Transformers".
+    """
+
+    def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
+        super().__init__()
+        self.config = config
+        self.layer_idx = layer_idx
+        if layer_idx is None:
+            logger.warning_once(
+                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
+                "lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
+                "when creating this class."
+ ) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.attention_dropout = config.attention_dropout + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + self.rotary_emb = Qwen2RotaryEmbedding(config=self.config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." 
+ ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class Qwen2FlashAttention2(Qwen2Attention): + """ + Qwen2 flash attention module, following Qwen2 attention module. This module inherits from `Qwen2Attention` + as the weights of the module stays untouched. The only required change would be on the forward pass + where it needs to correctly call the public API of flash attention and deal with padding tokens + in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom + config.max_window_layers layers. + """ + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
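+
+        # In this MInference variant, `forward` dispatches on sequence length: full-length prefill
+        # calls (q_len == kv_len) go through `minference_flash_attn_func` with the per-layer
+        # `v_size` / `s_size` pattern sizes from MINFERENCE_CONFIG, while decode steps fall back to
+        # `flash_attn_with_kvcache`.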
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + ): + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + if ( + self.config.use_sliding_window + and getattr(self.config, "sliding_window", None) is not None + and self.layer_idx >= self.config.max_window_layers + ): + sliding_window = self.config.sliding_window + else: + sliding_window = None + + if query_states.shape[1] == key_states.shape[1]: # Prefilling + v_size, s_size = MINFERENCE_CONFIG[self.layer_idx] + attn_output = minference_flash_attn_func(query_states, key_states, value_states, v_size, s_size) + else: + attn_output = flash_attn_with_kvcache(query_states, key_states, value_states) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class Qwen2SdpaAttention(Qwen2Attention): + """ + Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from Qwen2Attention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). 
In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +QWEN2_ATTENTION_CLASSES = { + "eager": Qwen2Attention, + "flash_attention_2": Qwen2FlashAttention2, + "sdpa": Qwen2SdpaAttention, +} + + +class Qwen2DecoderLayer(nn.Module): + def __init__(self, config: Qwen2Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + if config.sliding_window and config._attn_implementation != "flash_attention_2": + logger.warning_once( + f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " + "unexpected results may be encountered." 
+ ) + self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + + self.mlp = Qwen2MLP(config) + self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +QWEN2_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. 
+ + Parameters: + config ([`Qwen2Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.", + QWEN2_START_DOCSTRING, +) +class Qwen2PreTrainedModel(PreTrainedModel): + config_class = Qwen2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen2DecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +QWEN2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. 
+ + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.", + QWEN2_START_DOCSTRING, +) +class Qwen2Model(Qwen2PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`Qwen2DecoderLayer`] + + Args: + config: Qwen2Config + """ + + def __init__(self, config: Qwen2Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._attn_implementation = config._attn_implementation + self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen2RotaryEmbedding(config=config) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + if past_key_values is None: + past_key_values = DynamicCache() + else: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " + "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " + "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" + ) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + # Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
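+        # (The flash_attention_2 branch has already returned above, so everything below only builds
+        # the 4D mask needed by the sdpa / eager code paths.)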
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: + target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Qwen2 + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + config: Qwen2Config, + past_key_values: Cache, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. 
+ batch_size (`torch.Tensor`): + Batch size. + config (`Qwen2Config`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + if config.sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: + sliding_attend_mask = torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - config.sliding_window + ) + diagonal_attend_mask |= sliding_attend_mask + causal_mask *= diagonal_attend_mask + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + return causal_mask + + +class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = Qwen2Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + **loss_kwargs, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. 
Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, Qwen2ForCausalLM + + >>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + The Qwen2 Model transformer with a sequence classification head on top (linear layer). + + [`Qwen2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). 
+ """, + QWEN2_START_DOCSTRING, +) +class Qwen2ForSequenceClassification(Qwen2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = Qwen2Model(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif 
self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + The Qwen2 Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. + """, + QWEN2_START_DOCSTRING, +) +# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Qwen2, LLAMA->QWEN2 +class Qwen2ForTokenClassification(Qwen2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = Qwen2Model(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.config) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ +The Qwen2 Model transformer with a span classification head on top for extractive question-answering tasks like +SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + QWEN2_START_DOCSTRING, +) +# Copied from transformers.models.mistral.modeling_mistral.MistralForQuestionAnswering with Mistral->Qwen2, MISTRAL->QWEN2 +class Qwen2ForQuestionAnswering(Qwen2PreTrainedModel): + base_model_prefix = "model" + + # Copied from models.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Qwen2 + def __init__(self, config): + super().__init__(config) + self.model = Qwen2Model(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + loss = None + if start_positions is not None and end_positions is not None: + loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return QuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/mtraining/models/qwen2/modeling_qwen2.py b/mtraining/models/qwen2/modeling_qwen2.py new file mode 100644 index 0000000..0b7175a --- /dev/null +++ b/mtraining/models/qwen2/modeling_qwen2.py @@ -0,0 +1,1136 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/qwen2/modular_qwen2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_qwen2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +from typing import Callable, List, Optional, Tuple, Union + +import torch +from torch import nn + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from transformers.generation import GenerationMixin +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, + QuestionAnsweringModelOutput, + TokenClassifierOutput, +) +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from transformers.processing_utils import Unpack +from transformers.utils import ( + LossKwargs, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from transformers.utils.deprecation import deprecate_kwarg +from .configuration_qwen2 import Qwen2Config + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "meta-qwen2/Qwen2-2-7b-hf" +_CONFIG_FOR_DOC = "Qwen2Config" + +QWEN_ATTN_FUNCS = ALL_ATTENTION_FUNCTIONS.copy() + +class Qwen2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = 
ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. # (B, H, S, D) + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) # (B, S, D) -> (B, 1, S, D) + sin = sin.unsqueeze(unsqueeze_dim) # (B, S, D) -> (B, 1, S, D) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class Qwen2Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Qwen2Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + sliding_window = None + if ( + self.config.use_sliding_window + and getattr(self.config, "sliding_window", None) is not None + and 
self.layer_idx >= self.config.max_window_layers + ): + sliding_window = self.config.sliding_window + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = QWEN_ATTN_FUNCS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=sliding_window, # main diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Qwen2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Qwen2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Qwen2DecoderLayer(nn.Module): + def __init__(self, config: Qwen2Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx) + self.mlp = Qwen2MLP(config) + self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + if config.sliding_window and config._attn_implementation != "flash_attention_2": + logger.warning_once( + f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " + "unexpected results may be encountered." 
+ ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +class Qwen2RotaryEmbedding(nn.Module): + def __init__(self, config: Qwen2Config, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + # This .to() is needed if the model has been moved to a device after being initialized (because + # the buffer is automatically moved, but not the original copy) + self.original_inv_freq = self.original_inv_freq.to(device) + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) # (B, D // 2, 1) + position_ids_expanded = position_ids[:, 
None, :].float() # (B, 1, S) + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) # (B, S, D // 2) + emb = torch.cat((freqs, freqs), dim=-1) # (B, S, D) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +QWEN2_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Qwen2Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.", + QWEN2_START_DOCSTRING, +) +class Qwen2PreTrainedModel(PreTrainedModel): + config_class = Qwen2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen2DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_attention_backend = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +QWEN2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. 
+ + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.", + QWEN2_START_DOCSTRING, +) +class Qwen2Model(Qwen2PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`Qwen2DecoderLayer`] + + Args: + config: Qwen2Config + """ + + def __init__(self, config: Qwen2Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen2RotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, # example: torch.arange(0, 1024).unsqueeze(0) (shape: [B, S]) + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
+ ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + output = BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + return output if return_dict else output.to_tuple() + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
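+        # Summary of the dispatch below: with flash_attention_2 the 2D mask (or None) has already been
+        # returned above; with sdpa, no static cache and output_attentions=False, None is returned whenever
+        # the mask can be ignored so SDPA's own `is_causal` fast path applies; in every other case a full
+        # 4D float mask of shape (batch_size, 1, query_length, kv_length) is built below, with
+        # `torch.finfo(dtype).min` at masked positions.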
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
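+            # ("Inverted form" means the additive convention: 0.0 at positions that may be attended and a
+            # large negative value such as `torch.finfo(dtype).min` at masked positions, so the mask can be
+            # added directly to the attention scores.)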
+ causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + +class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + + def __init__(self, config): + super().__init__(config) + self.model = Qwen2Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. 
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, Qwen2ForCausalLM + + >>> model = Qwen2ForCausalLM.from_pretrained("meta-qwen2/Qwen2-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-qwen2/Qwen2-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + The Qwen2 Model transformer with a sequence classification head on top (linear layer). + + [`Qwen2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). 
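+
+    As a rough illustration of the pooling rule (a minimal sketch, not part of the model API; the token ids
+    below are made up), the index that gets pooled for a padded row can be computed as follows:
+
+    ```python
+    import torch
+
+    pad_token_id = 0
+    input_ids = torch.tensor([[5, 6, 7, 0, 0]])  # one row with two padding tokens
+    lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
+    lengths = lengths % input_ids.shape[-1]  # tensor([2]) -> the last non-padding position
+    ```
+
+    A row without padding gives `-1 % sequence_length == sequence_length - 1`, i.e. the last position.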
+ """, + QWEN2_START_DOCSTRING, +) +class Qwen2ForSequenceClassification(Qwen2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = Qwen2Model(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) + + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + The Qwen2 Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. 
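+
+    (The head is a `nn.Dropout` followed by `nn.Linear(config.hidden_size, config.num_labels)` applied to
+    every token's hidden state.)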
+ """, + QWEN2_START_DOCSTRING, +) +class Qwen2ForTokenClassification(Qwen2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = Qwen2Model(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.config) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ +The Qwen2 Model transformer with a span classification head on top for extractive question-answering tasks like +SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). 
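+
+(Concretely, `qa_outputs` is a single `nn.Linear(config.hidden_size, 2)` whose two output channels are split
+into per-token start and end logits.)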
+ """, + QWEN2_START_DOCSTRING, +) +class Qwen2ForQuestionAnswering(Qwen2PreTrainedModel): + base_model_prefix = "transformer" + + def __init__(self, config): + super().__init__(config) + self.transformer = Qwen2Model(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.transformer.embed_tokens + + def set_input_embeddings(self, value): + self.transformer.embed_tokens = value + + @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + loss = None + if start_positions is not None and end_positions is not None: + loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return QuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) \ No newline at end of file diff --git a/mtraining/models/qwen2/vllm_sparse_qwen2.py b/mtraining/models/qwen2/vllm_sparse_qwen2.py new file mode 100644 index 0000000..4f28686 --- /dev/null +++ b/mtraining/models/qwen2/vllm_sparse_qwen2.py @@ -0,0 +1,465 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py +# Copyright 2024 The Qwen team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 
+# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Qwen2 model compatible with HuggingFace weights.""" +from typing import Iterable, List, Optional, Tuple + +import torch +from torch import nn +from transformers import Qwen2Config + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm import _custom_ops as ops + +from .interfaces import SupportsLoRA +from .utils import is_pp_missing_parameter, make_layers + +from flash_attn import flash_attn_func +from mtraining_sparse_ops import get_minference_config, minference_flash_attn_func +MINFERENCE_CONFIG = get_minference_config() + + +class Qwen2MLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. 
" + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class Qwen2Attention(nn.Module): + + def __init__(self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + rope_scaling: Optional[Tuple] = None) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.layer_idx = 0 + self.max_num_tokens = max_position + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=True, + quant_config=quant_config, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position, + base=self.rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + if attn_metadata.prefill_metadata and q.shape[0] < self.max_num_tokens: + # BATCH_SIZE == 1 + if kv_cache is not None: + key_cache = kv_cache[0] + value_cache = kv_cache[1] + ops.reshape_and_cache_flash( + k.view(-1, self.num_kv_heads, self.head_dim), + v.view(-1, self.num_kv_heads, self.head_dim), + key_cache, + value_cache, + attn_metadata.slot_mapping.flatten(), + self.attn.kv_cache_dtype, + 1.0, + 1.0, + ) + v_size, s_size = MINFERENCE_CONFIG[self.layer_idx] + attn_output = minference_flash_attn_func( + q.reshape((1, -1, self.num_heads, self.head_dim)), + k.reshape((1, -1, self.num_kv_heads, self.head_dim)), + v.reshape((1, -1, self.num_kv_heads, self.head_dim)), + v_size, s_size, + ).reshape(q.shape) + else: + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class Qwen2DecoderLayer(nn.Module): + + def __init__( + self, + config: Qwen2Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + 
self.hidden_size = config.hidden_size + # Requires transformers > 4.32.0 + rope_theta = getattr(config, "rope_theta", 1000000) + rope_scaling = getattr(config, "rope_scaling", None) + self.self_attn = Qwen2Attention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + cache_config=cache_config, + quant_config=quant_config, + rope_scaling=rope_scaling) + self.mlp = Qwen2MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class Qwen2Model(nn.Module): + + def __init__( + self, + config: Qwen2Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Qwen2DecoderLayer(config=config, + cache_config=cache_config, + quant_config=quant_config), + prefix=f"{prefix}.layers", + ) + for i, layer in enumerate(self.layers): + layer.self_attn.layer_idx = i + + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.embed_tokens(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, + residual, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class 
Qwen2ForCausalLM(nn.Module, SupportsLoRA): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "gate_up_proj", + "down_proj", + ] + embedding_modules = {} + embedding_padding_modules = [] + + def __init__( + self, + config: Qwen2Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + # TODO (@robertgshaw2): see if this can be moved out + if (cache_config.sliding_window is not None + and hasattr(config, "max_window_layers")): + raise ValueError("Sliding window for some but all layers is not " + "supported. This model uses sliding window " + "but `max_window_layers` = %s is less than " + "`num_hidden_layers` = %s. Please open an issue " + "to discuss this feature." % ( + config.max_window_layers, + config.num_hidden_layers, + )) + + super().__init__() + + self.config = config + self.lora_config = lora_config + + self.quant_config = quant_config + self.model = Qwen2Model(config, cache_config, quant_config) + + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + "residual": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + }) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. 
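+ # `name` has already been remapped to the fused parameter here (q_proj/k_proj/v_proj -> qkv_proj, gate_proj/up_proj -> gate_up_proj); `shard_id` tells the parallel weight_loader which slice of the fused tensor to fill.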
+ if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/mtraining/models/sparse_ops/.gitignore b/mtraining/models/sparse_ops/.gitignore new file mode 100644 index 0000000..f15bda0 --- /dev/null +++ b/mtraining/models/sparse_ops/.gitignore @@ -0,0 +1,7 @@ +__pycache__/ +*.egg-info/ +build/ +*.egg +configs/ +minference_attn.py +minference_sparse_index.py \ No newline at end of file diff --git a/mtraining/models/sparse_ops/mtraining_sparse_ops/__init__.py b/mtraining/models/sparse_ops/mtraining_sparse_ops/__init__.py new file mode 100644 index 0000000..6803d87 --- /dev/null +++ b/mtraining/models/sparse_ops/mtraining_sparse_ops/__init__.py @@ -0,0 +1,2 @@ +from .minference_config import get_minference_config +from .minference_attn import minference_flash_attn_func diff --git a/mtraining/models/sparse_ops/mtraining_sparse_ops/minference_config.py b/mtraining/models/sparse_ops/mtraining_sparse_ops/minference_config.py new file mode 100644 index 0000000..83c1682 --- /dev/null +++ b/mtraining/models/sparse_ops/mtraining_sparse_ops/minference_config.py @@ -0,0 +1,23 @@ +import os +import json + + +CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'configs') +DEFAULT_CONFIG_FILE = "Qwen2.5_3B_kv_out_v32_fit_o_best_pattern.json" +if "MI_CONFIG" in os.environ: + DEFAULT_CONFIG_FILE = os.environ["MI_CONFIG"] + + +def get_minference_config(config_file: str = DEFAULT_CONFIG_FILE): + with open(os.path.join(CONFIG_DIR, config_file)) as f: + data = json.loads(f.read()) + config = [] + for layer_data in data: + v_size_list = [None] * len(layer_data) + s_size_list = [None] * len(layer_data) + for k, v in layer_data.items(): + assert v[0] in ['vertical_and_slash', 'flex_vertical_and_slash'] + v_size_list[int(k)] = v[1] + s_size_list[int(k)] = v[2] + config.append([v_size_list, s_size_list]) + return config diff --git a/mtraining/models/sparse_ops/setup.py b/mtraining/models/sparse_ops/setup.py new file mode 100644 index 0000000..c6b202e --- /dev/null +++ b/mtraining/models/sparse_ops/setup.py @@ -0,0 +1,27 @@ +import os +import shutil +from setuptools import setup, find_packages + + +setup_dir_path = os.path.dirname(__file__) +mtraining_path = os.path.dirname(os.path.dirname(setup_dir_path)) +setup_dir_path = os.path.join(setup_dir_path, "mtraining_sparse_ops") +cfg_dir_path = os.path.join(mtraining_path, "ops", "minfer", "configs") +op_dir_path = os.path.join(mtraining_path, "ops", "ring_attn", "core") + +shutil.copytree(cfg_dir_path, os.path.join(setup_dir_path, "configs"), dirs_exist_ok=True) + +with open(os.path.join(op_dir_path, "minference_sparse_index.py"), "r") as f: + index_code = f.read() +with open(os.path.join(setup_dir_path, "minference_sparse_index.py"), "w") as f: + f.write(index_code) +with open(os.path.join(op_dir_path, "minference_attn.py"), "r") as f: + attn_code = f.read() +with open(os.path.join(setup_dir_path, "minference_attn.py"), "w") as 
f: + f.write(attn_code.replace("MTraining.ops.ring_attn.core", "mtraining_sparse_ops")) + +setup( + name="mtraining_sparse_ops", # Name of your project + version="0.1.0", + packages=find_packages(), # Automatically discover all packages +) diff --git a/mtraining/requirements.txt b/mtraining/requirements.txt new file mode 100644 index 0000000..d852338 --- /dev/null +++ b/mtraining/requirements.txt @@ -0,0 +1,18 @@ +transformers==4.48.0 +datasets==2.20.0 +tensorboard +scikit-learn +matplotlib +seaborn +jieba +rouge +nltk +rouge_score +evaluate +triton==3.0.0 + +mosaicml-cli==0.5.34 +mosaicml-streaming==0.8.1 +sentencepiece==0.1.99 +tiktoken==0.7.0 +zstandard==0.22.0 diff --git a/mtraining/setup.py b/mtraining/setup.py new file mode 100644 index 0000000..19e0346 --- /dev/null +++ b/mtraining/setup.py @@ -0,0 +1,15 @@ +from setuptools import setup, find_packages + +setup( + name="MTraining", # Name of your project + version="0.1.0", + packages=find_packages(), # Automatically discover all packages + install_requires=[], # List dependencies if any (or use requirements.txt) + url="https://github.com/HalberdOfPineapple/MTraining", # Repository URL if applicable + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + ], + python_requires=">=3.10", # Specify the Python version +) diff --git a/mtraining/setup.sh b/mtraining/setup.sh new file mode 100755 index 0000000..863ee5b --- /dev/null +++ b/mtraining/setup.sh @@ -0,0 +1,34 @@ +#!/usr/bin/bash +set -e # Exit on first error +BASE_DIR="$(cd "$(dirname "$0")" && pwd)" +echo $BASE_DIR +PIP="$(which pip)" + +if command -v nvidia-smi +then + # assume base image: amlt-sing/acpt-torch2.3.1-py3.10-cuda12.1-ubuntu22.04 + $PIP install ninja cmake wheel pybind11 + $PIP install --no-cache-dir torch==2.3.1 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 + $PIP install git+https://github.com/Dao-AILab/flash-attention.git@v2.7.4.post1 + $PIP install -r "${BASE_DIR}/requirements.txt" + $PIP install git+https://github.com/microsoft/nnscaler.git@2368540417bc3b77b7e714d3f1a0de8a51bb66e8 + $PIP install "rotary-emb @ git+https://github.com/Dao-AILab/flash-attention.git@9356a1c0389660d7e231ff3163c1ac17d9e3824a#subdirectory=csrc/rotary" + $PIP install "block_sparse_attn @ git+https://github.com/HalberdOfPineapple/flash-attention.git@block-sparse" +elif command -v rocm-smi +then + $PIP install ninja cmake wheel pybind11 + $PIP install --no-cache-dir --pre torch==2.3.1+rocm6.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.0 + $PIP install git+https://github.com/OpenAI/triton.git@e192dba#subdirectory=python + $PIP install git+https://github.com/Dao-AILab/flash-attention.git@v2.7.4.post1 + $PIP install -r "${BASE_DIR}/requirements.txt" + $PIP install git+https://github.com/microsoft/nnscaler.git@2368540417bc3b77b7e714d3f1a0de8a51bb66e8 +else + echo "ERROR: both nvidia-smi and rocm-smi not found" + exit 1 +fi + +# Get the path to nnscaler and write its path to PYTHONPATH in ~/.profile +NNSCALER_HOME=$(python -c "import nnscaler; print(nnscaler.__path__[0])") +echo "export NNSCALER_HOME=${NNSCALER_HOME}" >> ~/.profile +echo "export PYTHONPATH=${NNSCALER_HOME}:\${PYTHONPATH}" >> ~/.profile +source ~/.profile \ No newline at end of file diff --git a/mtraining/train.py b/mtraining/train.py new file mode 100644 index 0000000..8993d04 --- /dev/null +++ b/mtraining/train.py @@ -0,0 +1,625 @@ +# Copyright (c) Microsoft Corporation. 
+# Licensed under the MIT License. +import os +import yaml +import torch +import shutil +import argparse +import numpy as np +import torch.distributed as dist + +from typing import Dict, List, Optional +from datasets import load_from_disk +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling +from transformers.modeling_utils import PreTrainedModel + +from nnscaler.utils import set_default_logger_level +from nnscaler.cli.trainer_args import ( + CheckpointConfig, + DatasetConfig, + HookMapConfig, + ModelConfig, + OptimizerConfig, + DataloaderConfig, + LogConfig, + DatasetSamplerConfig, +) +from nnscaler.parallel import ComputeConfig +from nnscaler.runtime.f16_optimizer import MixedPrecisionAdamW +from nnscaler.cli.loggers.tensorboard import TensorBoardLogger + +from .trainer import CustomTrainer as Trainer, CustomTrainerArgs as TrainerArgs +from .utils import chunk_linear_cross_entropy, get_tokenizer, aggregate_outputs_fn, get_resume_path +from MTraining.ops.minfer import ExprMInferConfig as MInferenceConfig, ExprMInference as MInference +from MTraining.ops import AttnType, overwrite_attn_implementation, load_moba_config, MoBAConfig +from MTraining.models import MODEL_TO_ATTN_FUNC, MODEL_ID_TO_MODEL_CLS, MODEL_ID_TO_PREFIX +from .utils.paths import ( + MINFER_CONFIG_DIR, SPARSE_PATTERN_CONFIG_DIR, SPARSE_HEAD_MAP_DIR, + update_expr_data_save_path +) +from MTraining.utils.train_utils import freeze_model_params, load_comm_profile_data +from MTraining.utils.expr_data import update_expr_data + +import logging +logger = logging.getLogger(__name__) +set_default_logger_level('INFO') + +IGNORE_IDX = -100 + + +def init_by_attn_type(model_id: str, attn_type: AttnType): + attn_dict = MODEL_TO_ATTN_FUNC[model_id] + + if attn_type == AttnType.BASELINE: + print(f"{__name__} | Using Baseline Model...") + elif attn_type == AttnType.FLEX_PREFILL: + print(f"{__name__} | Using FlexPrefill-equipped Model ...") + elif attn_type == AttnType.RING_ATTN: + print(f"{__name__} | Using Ring Attention Zigzag-equipped Model ...") + elif attn_type == AttnType.RING_ATTN_STRIPE: + print(f"{__name__} | Using Ring Attention Stripe-equipped Model ...") + elif attn_type == AttnType.MF_MB: + print(f"{__name__} | Using MInference-equipped Model ...") + elif attn_type == AttnType.MOBA: + print(f"{__name__} | Using MoBA-equipped Model ...") + elif attn_type == AttnType.ZIGZAG_MOBA: + print(f"{__name__} | Using ZigZag MoBA-equipped Model ...") + elif attn_type == AttnType.XATTN: + print(f"{__name__} | Using XAttention-equipped Model ...") + else: + raise ValueError(f"Invalid attn_type: {attn_type}") + + overwrite_attn_implementation(attn_dict, attn_type) + + +class BaselineModel(torch.nn.Module): + def __init__( + self, + model_id, + config_path: str=None, + # merged_ckpt_path: str=None, + active_param_config_name: str=None + ): + super().__init__() + model_cls: PreTrainedModel = MODEL_ID_TO_MODEL_CLS[model_id] + + if not config_path: + self.model = model_cls.from_pretrained( + model_id, + attn_implementation='flash_attention_2' + ) + else: + model_config = AutoConfig.from_pretrained(config_path, trust_remote_code=True) + model_config._attn_implementation = 'flash_attention_2' + self.model = model_cls.from_pretrained( + model_id, + config=model_config, + ) + + if active_param_config_name: + freeze_model_params(self.model, active_param_config_name) + + print(f'{__class__.__name__} Self-Attention Class: {self.model.model.layers[0].self_attn.__class__.__name__}') + + def 
forward(self, samples): + with torch.autocast(device_type="cuda", dtype=self.model.config.torch_dtype): + outputs = self.model.model( + input_ids=samples['net_input']['src_tokens'], + use_cache=False, + return_dict=False, + ) + hidden_states = outputs[0] + losses = chunk_linear_cross_entropy(hidden_states, self.model.lm_head.weight, samples['target'], IGNORE_IDX, 1024) + loss = torch.sum(losses) + + return loss, loss.data, samples['ntokens'], samples['nsentences'] + +class MInferModel(BaselineModel): + def __init__( + self, + model_id, + config_path: str=None, + minfer_config: Dict={}, + **kwargs, + ): + super().__init__( + model_id=model_id, + config_path=config_path, + **kwargs, + ) + + # -------------------------------------------- + # MInference implementation: "fa", "stripe" + minfer_implementation: str = minfer_config.pop('implementation', 'fa') + + # -------------------------------------------- + # Sparse iteratio, layer and head control + start_sparse_iter: int = minfer_config.pop('start_sparse_iter', 0) + start_sparse_layer: int = minfer_config.pop('start_sparse_layer', 0) + adaptive_sparse: bool = minfer_config.pop('adaptive_sparse', False) + sparse_head_map_name: str = minfer_config.pop('sparse_head_map_name', None) + if sparse_head_map_name is not None: + active_sparse_map_path: str = os.path.join( + SPARSE_HEAD_MAP_DIR, + f'{sparse_head_map_name}.npy', + ) + print(f"{__name__} | Active Sparse Head Map Path: {active_sparse_map_path}") + active_sparse_map: np.ndarray = np.load(active_sparse_map_path) + active_sparse_map: List[List[bool]] = active_sparse_map.tolist() + else: + active_sparse_map: List[List[bool]] = None + + # ---------------------------------------------- + # Ring Attention specific + granularity: int = minfer_config.pop('granularity', 128) + + # -------------------------------------------- + # Standard MInference Setup + minfer_attn_type = minfer_config.pop('attn_type', 'minference') + minfer_config['config_path'] = os.path.join( + SPARSE_PATTERN_CONFIG_DIR, + f'{minfer_config.pop("pattern_config_name")}.json', + ) + print(f"{__name__} | MInference Pattern Config Path: {minfer_config['config_path']}") + minfer = MInference( + attn_type=minfer_attn_type, + model_name=model_id, + **minfer_config, + ) + minfer_config: MInferenceConfig = minfer.config + + # -------------------------------------------- + # We still need to attach the function object to the model + # otherwise the states of the function will be lost as nnscaler will only load the model from file + # but not call this procedure again + from ops.minfer_func import MInferAttnFunc + Attention = self.model.model.layers[0].self_attn.__class__ + def update_module(m): + if isinstance(m, Attention): + m.minfer_attn_func = MInferAttnFunc() + m.minfer_attn_func.init_minfer_params( + config_path=minfer_config.config_path, + minfer_implementation=minfer_implementation, + + start_sparse_iter=start_sparse_iter, + start_sparse_layer=start_sparse_layer, + adaptive_sparse=adaptive_sparse, + active_sparse_map=active_sparse_map, + + granularity=granularity, + ) + self.model.apply(update_module) + +class FlexPrefillModel(BaselineModel): + def __init__( + self, + model_id, + config_path: str=None, + attn_config: Dict={}, + **kwargs, + ): + super().__init__( + model_id=model_id, + config_path=config_path, + **kwargs, + ) + + from ops.flex_prefill_func import FlexPrefillFunc + Attention = self.model.model.layers[0].self_attn.__class__ + def update_module(m): + if isinstance(m, Attention): + m.flex_prefill_attn_func = 
FlexPrefillFunc(attn_config) + self.model.apply(update_module) + +class XAttnModel(BaselineModel): + def __init__( + self, + model_id, + config_path: str=None, + xattn_params: Dict={}, + **kwargs, + ): + super().__init__( + model_id=model_id, + config_path=config_path, + **kwargs, + ) + + # -------------------------------------------- + implementation: str = xattn_params.pop('implementation', 'fa') + granularity: int = xattn_params.pop('granularity', 128) + + # -------------------------------------------- + Attention = self.model.model.layers[0].self_attn.__class__ + def update_module(m): + if isinstance(m, Attention): + m.granularity = granularity + m.xattn_params = xattn_params + m.implementation = implementation + self.model.apply(update_module) + + +class MoBAModel(BaselineModel): + def __init__( + self, + model_id, + config_path: str=None, + moba_config_dict: Dict={}, + **kwargs, + ): + super().__init__( + model_id=model_id, + config_path=config_path, + **kwargs, + ) + + # -------------------------------------------- + print(f"MoBAConfig: {moba_config_dict}") + moba_config = MoBAConfig(**moba_config_dict) + moba_topk, moba_chunk_size = moba_config.moba_topk, moba_config.moba_chunk_size + + # -------------------------------------------- + # We still need to attach the function object to the model + # otherwise the states of the function will be lost as nnscaler will only load the model from file + # but not call this procedure again + Attention = self.model.model.layers[0].self_attn.__class__ + def update_module(m): + if isinstance(m, Attention): + m.moba_topk = moba_topk + m.moba_chunk_size = moba_chunk_size + self.model.apply(update_module) + +ATTN_TO_MODEL = { + AttnType.BASELINE: BaselineModel, + AttnType.FLEX_PREFILL: FlexPrefillModel, + AttnType.MF_MB: MInferModel, + AttnType.RING_ATTN: BaselineModel, + AttnType.RING_ATTN_STRIPE: BaselineModel, + AttnType.MOBA: MoBAModel, + AttnType.ZIGZAG_MOBA: MoBAModel, + AttnType.XATTN: XAttnModel, +} + + +def load_minfer_config(minfer_config_name: str) -> MInferenceConfig: + minfer_config_path = os.path.join(MINFER_CONFIG_DIR, f'{minfer_config_name}.yaml') + if not os.path.exists(minfer_config_path): + print(f"{__name__} | MInference config {minfer_config_name} not found in {minfer_config_path}. 
Use empty minfer config") + minfer_config = {} + else: + print(f"{__name__} | MInference config {minfer_config_name} found in {minfer_config_path}") + with open(minfer_config_path, 'r') as f: + minfer_config = yaml.safe_load(f) + print('-' * 20) + print("MInference Config:") + print(minfer_config) + print('-' * 20) + + return minfer_config + +def build_model_args(args, minfer_config: MInferenceConfig) -> Dict: + model_args = { + 'model_id': args.model_id, + 'config_path': args.model_config_path, + "active_param_config_name": args.active_param_config_name, + } + if args.attn_type == AttnType.MF_MB: + model_args['minfer_config'] = minfer_config + elif args.attn_type == AttnType.FLEX_PREFILL: + model_args['attn_config'] = minfer_config + elif args.attn_type == AttnType.XATTN: + model_args['xattn_params'] = minfer_config + elif args.attn_type == AttnType.MOBA or args.attn_type == AttnType.ZIGZAG_MOBA: + model_args['moba_config_dict'] = minfer_config + + return model_args + + +def main(args): + update_expr_data_save_path(args.attn_save_path, args.ckpt_save_dir, args.compile_save_path) + update_expr_data(args) + + local_rank = int(os.environ["LOCAL_RANK"]) + if local_rank == 0: + load_comm_profile_data(args) + + init_by_attn_type(args.model_id, args.attn_type) + minfer_config = load_minfer_config(args.minfer_config_name) + + # broadcast_strategy = 'all' if args.run_mode == 'run' else 'none' + broadcast_strategy = 'all' + # --------------------------------- + # Compute config + if args.run_mode == 'compile': + if args.runtime_ngpus is None: + raise ValueError('runtime_ngpus must be specified in compile mode') + runtime_ngpus = args.runtime_ngpus + elif args.run_mode == 'run': + world_size = int(os.getenv('WORLD_SIZE')) + if args.runtime_ngpus is None: + runtime_ngpus = world_size + else: + if args.runtime_ngpus != world_size: + raise ValueError(f'runtime_ngpus ({args.runtime_ngpus}) must match the number of GPUs in run mode ({world_size})') + runtime_ngpus = args.runtime_ngpus + + if runtime_ngpus % args.plan_ngpus != 0: + raise ValueError('runtime_ngpus must be a multiple of plan_ngpus') + + scaling_factor: int = runtime_ngpus // args.plan_ngpus + grad_accu_step: int = args.global_batch_size // (args.micro_batch_size * scaling_factor) + + model_prefix = MODEL_ID_TO_PREFIX[args.model_id] + pas_config = { + 'recompute_modules': f'{model_prefix}DecoderLayer', + } + if args.mem_constraint > 0: pas_config['mem_constraint'] = args.mem_constraint + compute_config = ComputeConfig( + plan_ngpus=args.plan_ngpus, + trace_strategy=args.trace_strategy, + runtime_ngpus=runtime_ngpus, + constant_folding=True, + use_zero=True, + use_end2end=True, + # autodist config: + pas_config=pas_config, + ) + + # --------------------------------- + ## Setup Dataset ## + dataset = load_from_disk(args.dataset_path) + tokenizer = get_tokenizer(args.model_id) + data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) + + def collate_fn(samples): + if len(samples) == 0: + return {} + + mini_batch = data_collator(samples) + _mini_batch = {} + + src_tokens = mini_batch.pop('input_ids') + seq_len = src_tokens.size(-1) + _mini_batch['src_tokens'] = src_tokens + + shift_labels = mini_batch['labels'][..., 1:] + _mini_batch['labels'] = torch.nn.functional.pad(shift_labels, (0, 1), 'constant', IGNORE_IDX).contiguous() + + return { + "nsentences": len(samples), + "ntokens": len(samples) * seq_len, + "net_input": _mini_batch, + "target": _mini_batch.pop('labels'), + } + dataset_config = DatasetConfig( + 
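# `type` serves as the dataset factory: it is called with the train_args below, and the lambda + # simply ignores the split name and returns the dataset already loaded from disk above, so only + # a 'train' split is defined for this run. +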
type=(lambda split: dataset), + train_args={'split': 'train'}, + ) + dataloader_config = DataloaderConfig( + train_args={ + 'collate_fn': collate_fn, + 'drop_last': True, + }, + ) + sampler_config = DatasetSamplerConfig( + # default class: torch.utils.data.distributed.DistributedSampler + train_args={ + 'shuffle': True, + 'seed': args.seed, + }, + ) + + # --------------------------------- + # Model Config + model_args = build_model_args(args, minfer_config) + model_config = ModelConfig( + type=ATTN_TO_MODEL[args.attn_type], + args=model_args, + ) + + # --------------------------------- + # optimizer hyperparameters are from YaRN + optimizer_config = OptimizerConfig( + type=MixedPrecisionAdamW, + args={ + 'lr': 2e-5, + 'betas': (0.9, 0.95), + 'weight_decay': 0.0, + 'fused': True + }, + clip_gnorm=1.0, + loss_reduction='sum', + grad_reduction='per-token-mean', + aggregate_outputs_fn=aggregate_outputs_fn, + ) + + + # --------------------------------- + # Checkpoint Config + checkpoint_config = CheckpointConfig( + save_dir=args.ckpt_save_dir if args.ckpt_save_dir else f'./checkpoints_{args.name}', + every_n_epochs=args.ckpt_n_epoch, + every_n_train_steps=args.ckpt_n_step, + save_type='deduped', + # resume_from=(args.resume_from or 'last') if args.check_resume else None, + resume_from=args.resume_from, + ) + + # --------------------------------- + # Log Config + log_config = LogConfig( + type=TensorBoardLogger, + args={ + 'name': args.name, + 'root_dir': args.tf_log_dir or f'./runs_{args.name}', + }, + ) + + # --------------------------------- + trainer_args = TrainerArgs( + global_batch_size=args.global_batch_size, + micro_batch_size=args.micro_batch_size, + grad_accumulation_steps=grad_accu_step, + + pas_policy='autodist', + precision='bf16', + seed=args.seed, + gen_reuse=args.reuse_type, + + gen_savedir=args.compile_save_path, + instance_name=args.name, + run_mode=args.run_mode, + max_epochs=args.n_epochs, + max_train_steps=args.n_iter, + enable_progress_bar=not args.disable_progressbar, + + compute_config=compute_config, + model=model_config, + optimizer=optimizer_config, + dataset=dataset_config, + dataloader=dataloader_config, + checkpoint=checkpoint_config, + log=[log_config], + + broadcast_strategy=broadcast_strategy, + dataset_sampler=sampler_config, + + transfer_config={ + "transfer_config_dir": args.transfer_config_dir, + "transfer_force": args.transfer_force, + }, + merged_ckpt_path=args.resume_merged_ckpt, + ) + + trainer = Trainer( + train_args=trainer_args, + save_data_steps=args.attn_save_step, + enable_prof=args.enable_prof, + ) + trainer.run() + +def print_args(args: argparse.Namespace): + print("=" * 80) + print(f"Start Experiment:\t{args.name}") + print(f"Seed:\t{args.seed}") + print(f"Reuse Type:\t{args.reuse_type}") + print(f"Run Mode:\t{args.run_mode}") + print(f"Total number of GPUs:\t{args.runtime_ngpus}") + print(f"GPU unit size:\t{args.plan_ngpus}") + print(f"Model ID:\t{args.model_id}") + + print('-' * 40) + if args.n_iter: + print(f"Number of Iterations:\t{args.n_iter} (number of tokens: {args.n_iter * args.global_batch_size * args.seq_len})") + else: + print(f"Number of Epochs:\t{args.n_epochs}") + + + print(f'Global Batch Size:\t{args.global_batch_size}') + print(f'Micro Batch Size:\t{args.micro_batch_size}') + + scaling_factor = args.runtime_ngpus // args.plan_ngpus + grad_accu_step = args.global_batch_size // (args.micro_batch_size * scaling_factor) + print(f"Scaling Factor (INFERRED):\t{scaling_factor}") + print(f"Gradient Accumulation Steps 
(INFERRED):\t{grad_accu_step}") + print(f"Save Attention Data Every {args.attn_save_step} Steps") + + print('-' * 40) + print(f"Model Config Path:\t{args.model_config_path}") + print(f"Dataset path:\t{args.dataset_path}") + print(f'MInferenece Config Name:\t{args.minfer_config_name}') + print(f"Compile Save Path:\t{args.compile_save_path}") + print(f"Attention Save Path:\t{args.attn_save_path}") + print(f"Tensorboard Log Path:\t{args.tf_log_dir}") + print(f"Checkpoint Save Path:\t{args.ckpt_save_dir}") + print(f"Resume from Checkpoint:\t{args.check_resume}") + print(f"Path to the checkpoint to resume from:\t{args.resume_from}") + print(f"Path to the merged checkpoint to resume from:\t{args.resume_merged_ckpt}") + + print(f"Enable profiling: {args.enable_prof}") + print(f"Trace Strategy:\t{args.trace_strategy}") + if args.transfer_config_dir: + print(f"Transfer Configs from another experiment:\t{args.transfer_config_dir}") + print(f"Force Transfer Configs:\t{args.transfer_force}") + + if args.active_param_config_name: + print(f"Active Param Config Name:\t{args.active_param_config_name}") + + if args.ckpt_n_step: + print(f"Checkpoint Save Every {args.ckpt_n_step} Steps") + else: + print(f"Checkpoint Save Every {args.ckpt_n_epoch} Epochs") + print("=" * 80, flush=True) + +if __name__ == '__main__': + ## Parse Args ## + parser = argparse.ArgumentParser() + parser.add_argument('--seed', type=int, default=0, help='random seed') + parser.add_argument('--name', type=str, default='phi-grad', help='name of the experiment') + parser.add_argument('--seq_len', type=int, default=131072, help='sequence length') + parser.add_argument('--attn_type', type=str, default=AttnType.BASELINE, choices=AttnType.__dict__.values(), help='minference type') + parser.add_argument('--reuse_type', type=str, default='match', choices=['match', 'override', 'moo', 'graph'], help='reuse type') + parser.add_argument('--run_mode', type=str, default='run', choices=['run', 'compile'], help='run or compile') + parser.add_argument('--trace_strategy', type=str, default='cuda_run_cpu_offload', + choices=['cpu', 'cuda', 'meta', 'cuda_run_cpu_offload', 'reuse_cache'], + help='trace strategy') + parser.add_argument('--plan_ngpus', type=int, required=True, help='specify the scale unit size') + parser.add_argument('--runtime_ngpus', type=int, required=True, help='specify the number of GPUs to use') + + parser.add_argument('--n_iter', type=int, default=0, help='Number of iterations') + parser.add_argument('--n_epochs', type=int, default=0, help='Number of epochs') + parser.add_argument('--nB_tokens', type=int, default=0, help='Number of tokens (in B) to process') + parser.add_argument('--global_batch_size', type=int, default=4, help='global batch size') + parser.add_argument('--micro_batch_size', type=int, default=1, help='micro batch size') + parser.add_argument('--mem_constraint', type=int, default=0, help='memory constraint') + + parser.add_argument('--model_id', type=str, default='microsoft/Phi-3-mini-4k-instruct', help='transformers model id') + parser.add_argument('--model_config_path', type=str, default=None, help='path to the model config') + parser.add_argument('-s', '--attn_save_step', type=int, default=1, help='Save attention data every n steps') + + parser.add_argument('--minfer_config_name', type=str, default=None, help='Name of Minference config file') + parser.add_argument('--compile_save_path', type=str, default='./.nnscaler', help='path to save compiled code') + parser.add_argument('--attn_save_path', type=str, 
default=None, help='path to save attention data') + parser.add_argument('--tf_log_dir', type=str, default=None, help='path to save tensorboard logs') + parser.add_argument('--dataset_path', type=str, default=None, help='path to the dataset') + parser.add_argument('--check_resume', action='store_true', help='whether to resume from checkpoint') + parser.add_argument('--resume_from', type=str, default=None, help='path to the checkpoint to resume from') + parser.add_argument('--resume_merged_ckpt', type=str, default=None, help='path (dir) to the merged checkpoint to resume from') + + parser.add_argument('--enable_prof', action='store_true', help='enable profiling') + + parser.add_argument('--ckpt_save_dir', type=str, default=None, help='path to save checkpoints') + parser.add_argument('--ckpt_n_epoch', type=int, default=1, help='save checkpoint every n epochs') + parser.add_argument('--ckpt_n_step', type=int, default=0, help='save checkpoint every n steps') + parser.add_argument('--transfer_config_dir', type=str, default="none", help='path to transfer configs from another experiment') + parser.add_argument('--transfer_force', action='store_true', help='force transfer configs') + parser.add_argument('--active_param_config_name', type=str, default=None, help='path to the active param list') + + parser.add_argument('-p', '--disable_progressbar', action='store_true', help='transformers model id',) + + args = parser.parse_args() + + # ------------------------------------------------- + # Preprocessing args + if args.ckpt_n_epoch <= 0: args.ckpt_n_epoch = None + if args.ckpt_n_step <= 0: args.ckpt_n_step = None + + if args.nB_tokens > 0: + args.n_iter = args.nB_tokens * 1e9 // args.global_batch_size // args.seq_len + 1 + args.n_epochs = 0 + if args.n_iter <= 0: args.n_iter = None + if args.n_epochs <= 0: args.n_epochs = None + + if args.minfer_config_name is None or args.minfer_config_name.lower() == 'none': args.minfer_config_name = None + if args.transfer_config_dir.lower() == 'none': args.transfer_config_dir = None + if args.active_param_config_name.lower() == 'none': args.active_param_config_name = None + + # set a new field of args 'args.orig_resume_from' to store the original resume_from value + args.orig_resume_from = args.resume_from + args.resume_from = get_resume_path( + args.check_resume, args.resume_from, args.ckpt_save_dir, args.runtime_ngpus + ) + + + print_args(args) + main(args) \ No newline at end of file diff --git a/mtraining/trainer.py b/mtraining/trainer.py new file mode 100644 index 0000000..9fa5dd2 --- /dev/null +++ b/mtraining/trainer.py @@ -0,0 +1,678 @@ +import os +import time +import copy +import torch +import logging +import pandas as pd +from tqdm import tqdm +from dataclasses import dataclass +from datetime import timedelta +from collections import defaultdict +from dataclasses import asdict +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Dict, List, Optional, Union, Optional, Callable + +import torch.distributed +from torch.profiler import profile, ProfilerActivity + +import nnscaler +from nnscaler.utils import accum_mode +from nnscaler.runtime.utils import microbatches +from nnscaler.runtime.module import ParallelModule +from nnscaler.utils import is_running_distributed +from nnscaler.cli.trainer import ( + Trainer, _StepStat, TrainerArgs, TrainStatus, AggregatedTrainHook, TrainHook +) + +from MTraining.utils.custom_parallel import parallelize as custom_parallelize +from MTraining.utils.val_utils import fix_model_state_dict +from 
MTraining.utils.paths import EXPR_DATA_SAVE_PATH + +logger = logging.getLogger(__name__) + +def save_qkv_start_iter_idx(): + if os.getenv("COLLECT_QKV_DATA", "0") == "1": + # check QKV save dir and return the maximum iter idex corresponding to which the complete set of QKV shards are saved + qkv_save_dir = os.path.join(os.getenv("QKV_STORE_DIR"), str(torch.distributed.get_world_size())) + print(f"Rank {torch.distributed.get_rank()} | {__name__} | qkv_save_dir={qkv_save_dir}") + if not os.path.exists(qkv_save_dir): + print(f"Rank {torch.distributed.get_rank()} | {__name__} | qkv_save_dir does not exist") + return -1 + + # check subdirectories in this save dir (each subdir is named as `sample_{iter_idx}`) + num_gpus, num_layers = int(os.getenv("ORIG_GPU_SET").split('_')[-1]), int(os.getenv("NUM_LAYERS")) + subdirs = [d for d in os.listdir(qkv_save_dir) if os.path.isdir(os.path.join(qkv_save_dir, d)) and d.startswith("sample_")] + print(f"Rank {torch.distributed.get_rank()} | {__name__} | subdirs={subdirs}") + for iter_idx in range(len(subdirs) - 1, -1, -1): + subdir = os.path.join(qkv_save_dir, f"sample_{iter_idx}") + + # check if all layers are saved in this subdir + layer_dirs = [os.path.join(subdir, d) for d in os.listdir(subdir) if d.startswith('layer_') and os.path.isdir(os.path.join(subdir, d))] + if len(layer_dirs) == num_layers: + # check if all GPUs are saved in this subdir + print(f"Rank {torch.distributed.get_rank()} | {__name__} | layer_dirs={layer_dirs}") + for layer_dir in layer_dirs: + shard_paths = [sp for sp in os.listdir(layer_dir) + if (sp.startswith('q_') or sp.startswith('k_') or sp.startswith('v_') or sp.startswith('dout_')) \ + and sp.endswith('.pt') \ + and os.path.isfile(os.path.join(layer_dir, sp))] + if len(shard_paths) == 4 * num_gpus: + # all shards are saved + return iter_idx + + return -1 + +@dataclass +class CustomTrainerArgs(TrainerArgs): + transfer_config: Optional[Dict[str, Any]] = None + merged_ckpt_path: Optional[str] = None + + + +EXPR_NAME: str +PT_LOG_SAVE_DIR: str + +ITERATOR_COUNTER = defaultdict(int) +def get_iter_cnt(rank: int): + global ITERATOR_COUNTER + return ITERATOR_COUNTER.get(rank, 0) + +ITER_BATCH_IDX_DICT = {} +def get_iter_batch_idx(rank: int, iter_cnt: int): + global ITER_BATCH_IDX_DICT + return ITER_BATCH_IDX_DICT.get(rank, {}).get(iter_cnt, 0) + +SAVE_ITERVAL = -1 +def need_save_data(rank: int): + global SAVE_ITERVAL + if SAVE_ITERVAL <= 0: return False + return get_iter_cnt(rank) % SAVE_ITERVAL == 0 + +EXECUTOR = ThreadPoolExecutor(max_workers=4) # Adjust max_workers as needed +def save_iter_losses(epoch_idx: int, iter_idx: int, losses: List[Any], latencies: Optional[List[Any]]): + if torch.distributed.get_rank() != 0: return + + loss_save_dir = os.path.join(EXPR_DATA_SAVE_PATH['base_path'], 'losses', f"epoch_{epoch_idx}") + os.makedirs(loss_save_dir, exist_ok=True) + loss_save_path = os.path.join(loss_save_dir, f'iter_{iter_idx}.csv') + print(f"Rank {torch.distributed.get_rank()} | {__name__} | Saving iter losses to {loss_save_path} ...") + + loss_dict = {} + for sample_idx in range(len(losses)): + loss_dict[sample_idx] = { + 'loss': losses[sample_idx][1].item(), + 'num_tokens': losses[sample_idx][2], + } + if latencies is not None: + loss_dict[sample_idx]['latency'] = latencies[sample_idx] + loss_df = pd.DataFrame.from_dict(loss_dict, orient='index') + loss_df.index.name = "Sample" + + loss_df.to_csv(loss_save_path) + +def prof_train_step( + model: ParallelModule, + rank: int, iter_idx: int, + samples: List[Any], + is_dummy_batch: 
Optional[List[bool]] = None, + scale_fn: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, +) -> List[Any]: + global ITER_BATCH_IDX_DICT + model._warn_uninitialized_non_persistent_buffers(raise_error=True) + + if not model.compute_config.use_end2end: + raise RuntimeError("train_step() is only supported in end2end mode") + if is_dummy_batch and len(samples) != len(is_dummy_batch): + raise ValueError("The length of samples and is_dummy_batch should be the same") + + model._scale_loss(is_dummy_batch, scale_fn) + + # sync_grad will be done in _train_step + # so we never need to call it manually + model._sync_grad_required = False + sample_count = len(samples) + dataloader = microbatches(samples, cycle=False) + + outputs = [] + trace_path = os.path.join(PT_LOG_SAVE_DIR, f'iter_{iter_idx}', f'trace_{rank}.log') + os.makedirs(os.path.dirname(trace_path), exist_ok=True) + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + schedule=torch.profiler.schedule(wait=1, warmup=1, active=2, repeat=1), + on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_path), + record_shapes=True, + with_stack=True + ) as prof: + for idx in range(sample_count): + ITER_BATCH_IDX_DICT[rank][iter_idx] = idx + + sample_start_time = time.perf_counter() + with accum_mode(begin=(idx==0), end=(idx==sample_count-1)): + output = model._train_step(dataloader) + sample_time = time.perf_counter() - sample_start_time + + if rank == 0: + print(f"| {__name__} | rank={rank} | iter_idx={iter_idx}, batch_idx={idx}, loss={output[1]}, latency={sample_time:.4f}s") + + outputs.append(output) + prof.step() + return outputs + +def custom_train_step( + model: ParallelModule, + rank: int, iter_idx: int, + samples: List[Any], + is_dummy_batch: Optional[List[bool]] = None, + scale_fn: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, +) -> List[Any]: + """ + The training step function. It should be called in the training loop. + Please note: + 1. This function is only supported in end2end mode. + 2. Gradient accumulation is done inside this function. + You shouldn't do gradient accumulation outside this function, + because the gradients will be cleared in the beginning of this function + Args: + samples (List[Any]): a list of samples. 
+ if pipeline is used, it must have the same length as configured to pas policy + is_dummy_batch (Optional[List[bool]]): indicates whether the each micro-batch is dummya + scale_fn (Optional[Callable[[torch.Tensor], torch.Tensor]]): the function to scale the loss + Results: + List[Any]: a list of outputs for each sample + """ + global ITER_BATCH_IDX_DICT + model._warn_uninitialized_non_persistent_buffers(raise_error=True) + + if not model.compute_config.use_end2end: + raise RuntimeError("train_step() is only supported in end2end mode") + if is_dummy_batch and len(samples) != len(is_dummy_batch): + raise ValueError("The length of samples and is_dummy_batch should be the same") + + model._scale_loss(is_dummy_batch, scale_fn) + + # sync_grad will be done in _train_step + # so we never need to call it manually + model._sync_grad_required = False + sample_count = len(samples) + dataloader = microbatches(samples, cycle=False) + + qkv_start_iter = save_qkv_start_iter_idx() + if model.use_scheduler: + if len(samples) != model.nmicros_per_scheduler_step: + raise ValueError(f"Expected {model.nmicros_per_scheduler_step} samples, but got {sample_count}") + # only one step, so begin/end are both True + with accum_mode(begin=True, end=True): + return model._train_step(dataloader), None + else: + outputs = [] + latencies = [] + for idx in range(sample_count): + ITER_BATCH_IDX_DICT[rank][iter_idx] = idx + if idx <= qkv_start_iter: continue + + sample_start_time = time.perf_counter() + with accum_mode(begin=(idx==0), end=(idx==sample_count-1)): + # loss, loss.data, samples['ntokens'], samples['nsentences'] + output = model._train_step(dataloader) + sample_time = time.perf_counter() - sample_start_time + latencies.append(sample_time) + + num_tokens = output[2] + if rank == 0: + print( + f"| {__name__} | rank={rank} | iter_idx={iter_idx}, batch_idx={idx}, loss={output[1] / num_tokens:.4f}," + f" num_tokens={num_tokens}, latency={sample_time:.4f}s" + ) + + outputs.append(output) + return outputs, latencies + +class CustomTrainer(Trainer): + def __init__( + self, + argv: Optional[List[str]] = None, + *, + train_args: Optional[Union[Dict[str, Any], CustomTrainerArgs]] = None, + save_data_steps: int = 1, + enable_prof: bool = False, + ): + """ + Custom trainer with an additional parameter. + + Args: + argv (Optional[List[str]]): Command line arguments. If not specified, sys.argv[1:] will be used. + train_args: A dict used to construct TrainerArgs or a TrainerArgs object itself. + additional_param (Optional[Any]): Additional parameter for custom functionality. 
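+ save_data_steps (int): save attention/loss data every n training steps (passed from --attn_save_step in train.py). + enable_prof (bool): when True, run each training step under torch.profiler and dump traces for TensorBoard (see prof_train_step).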
+ """ + # Call the parent class's initializer with the existing parameters + super().__init__(argv=argv, train_args=train_args) + self.train_args: CustomTrainerArgs + + if int(os.getenv('NCCL_DEBUG_MODE', '0')) == 1 and int(os.getenv("CODE_GEN", '0')) != 1: + torch.distributed.init_process_group( + backend='nccl', + timeout=timedelta(seconds=30), + ) + print(f"Rank {self.rank} | {__name__} | nccl timeout is set to 30s for debugging") + else: + torch.distributed.init_process_group( + backend='nccl', + timeout=timedelta(hours=2), + ) + + global SAVE_ITERVAL, EXPR_NAME, PT_LOG_SAVE_DIR + SAVE_ITERVAL = save_data_steps + self.save_data_steps = save_data_steps + + EXPR_NAME = train_args.instance_name + PT_LOG_SAVE_DIR = os.path.join(train_args.log[0].args['root_dir'].replace('tf_logs', 'pt_logs'), EXPR_NAME) + os.makedirs(PT_LOG_SAVE_DIR, exist_ok=True) + + self.enable_prof = enable_prof + if self.enable_prof: + self.train_step_func = prof_train_step + else: + self.train_step_func = custom_train_step + + def _train_epoch(self, epoch): + VAL_STATUS_NO = 0 # not validated or saved + VAL_STATUS_VAL = 1 # validated but not saved + VAL_STATUS_SAVE = 2 # validated and saved + has_validated = VAL_STATUS_NO # 3 states + + resume_from_idx = self.train_status.finished_train_steps % self.total_train_steps_per_epoch + data_iter = enumerate(self._global_batch_iterator(num_skip_first=resume_from_idx)) + + max_epoch = self.max_train_steps // self.total_train_steps_per_epoch + if self.max_train_steps % self.total_train_steps_per_epoch != 0: + max_epoch += 1 + ndigits = len(str(max_epoch)) + epoch_format = f"0{ndigits}d" + epoch_desc = f'Epoch {format(epoch, epoch_format)}' + + if self.rank == 0: + progress = tqdm( + None, + total=self.total_train_steps_per_epoch, + initial=resume_from_idx, + desc=epoch_desc, + disable=not self.train_args.enable_progress_bar, + ) + else: + progress = None + + + # --------------------------------------------------------------------------------- + train_info_save_path = os.path.join(EXPR_DATA_SAVE_PATH['base_path'], 'train_info', f"epoch_{epoch}.log") + os.makedirs(os.path.dirname(train_info_save_path), exist_ok=True) + if self.rank == 0: + # Check whether the file already exists + # If it exists, assume existing log file has name 'epoch__.log' ('epoch_.log` is assumed to have num 0) + # Find the greatest for the current epoch and increment it to build the new file name + existing_files = [f for f in os.listdir(os.path.dirname(train_info_save_path)) \ + if f.startswith(f'epoch_{epoch}_') or f.startswith(f'epoch_{epoch}.log')] + if existing_files: + # Extract the numbers from the filenames + existing_nums = [int(f.split('_')[-1].split('.')[0]) for f in existing_files if f.startswith(f'epoch_{epoch}_')] + if not existing_nums: + existing_nums = [0] + new_num = max(existing_nums) + 1 + train_info_save_path = os.path.join(os.path.dirname(train_info_save_path), f'epoch_{epoch}_{new_num}.log') + else: + # If no existing files, use the original path + train_info_save_path = os.path.join(os.path.dirname(train_info_save_path), f'epoch_{epoch}.log') + with open(train_info_save_path, 'w') as f: f.write('') + + step_stat: Optional[_StepStat] = None + num_tokens_trained = 0 + for i, batches in data_iter: + idx = i + resume_from_idx + + global ITERATOR_COUNTER, ITER_BATCH_IDX_DICT + ITERATOR_COUNTER[self.rank] = idx + ITER_BATCH_IDX_DICT[self.rank] = {idx: 0} + # print(f"|{__name__}| rank={self.rank}, ITERATOR_COUNTER[self.rank]={ITERATOR_COUNTER[self.rank]}") + + if self.rank == 0: + # looks 
manually update progress bar is easier + # than using tqdm directly + # the difference is we update progress bar at the beginning of the loop + # instead of the end of the loop + progress.update(1) + step_start_at = time.perf_counter() + step_stat = _StepStat() + step_metrics = {} + has_validated = VAL_STATUS_NO + num_batches = len(batches) + batches, is_dummy_batch = self._fix_batches(batches) + + self.model.train() + + self.hook.before_zero_grad(self) + self.optimizer.zero_grad() + self.hook.after_zero_grad(self) + + self.hook.on_train_step_start(self, batches[:num_batches], idx) + # losses = self.model.train_step(batches, is_dummy_batch) + losses, latencies = self.train_step_func(self.model, self.rank, idx, batches, is_dummy_batch) + # EXECUTOR.submit( + # save_iter_losses, + # idx, losses, latencies + # ) + # save_iter_losses(idx, losses, latencies) + self.hook.on_train_step_end(self, losses[:num_batches], batches[:num_batches], idx) + + aggregate_outputs = self.train_args.resolved_aggregate_outputs_fn or self.aggregate_outputs + aggregated_outputs = aggregate_outputs(losses[:num_batches], self.sync_group) + if self.train_args.optimizer.loss_reduction == 'mean': + loss = aggregated_outputs.loss_sum / aggregated_outputs.num_batches + else: + loss = aggregated_outputs.loss_sum + step_stat.train_loss = loss + num_tokens_trained += aggregated_outputs.num_tokens + self.hook.after_aggregate_train_step_outputs(self, aggregated_outputs, loss, idx) + + self.hook.before_sync_grad(self) + # actually `sync_shard_grad` is no-op here + # because trainer only supports end2end model + # and syncing grad in end2end model is done in `_train_step`. + self.optimizer.sync_shard_grad() + self.hook.after_sync_grad(self) + + # scale gradients + multiplier = self.train_args.scaling_factor + if self.train_args.optimizer.grad_reduction == 'sum': + # do nothing. 
`multiplier` is already correct + pass + elif self.train_args.optimizer.grad_reduction == 'mean': + if not aggregated_outputs.num_batches: + raise RuntimeError("`aggregate_outputs` doesn't set `num_batches` field") + multiplier /= aggregated_outputs.num_batches + else: + assert self.train_args.optimizer.grad_reduction == 'per-token-mean' + if not aggregated_outputs.num_tokens: + raise RuntimeError("`aggregate_outputs` doesn't set `num_tokens` field") + multiplier /= aggregated_outputs.num_tokens + self.optimizer.scale_grads(multiplier) + + # clip gradients + self.hook.before_gnorm_clip(self) + if self.train_args.optimizer.clip_gnorm: + step_stat.gnorm = self.optimizer.clip_gnorm(self.train_args.optimizer.clip_gnorm) + else: + step_stat.gnorm = self.optimizer.clip_gnorm() + self.hook.after_gnorm_clip(self, step_stat.gnorm) + step_stat.gnorm = step_stat.gnorm.item() + + # update parameters + step_stat.lr = self.optimizer.param_groups[0]['lr'] + self.hook.before_optimizer_step(self) + self.optimizer.step() + self.hook.after_optimizer_step(self) + if self.lr_scheduler and self.train_args.lr_scheduler.interval == 'step': + self.lr_scheduler.step() + + self.train_status.finished_train_steps += 1 + self._log_mem_stats(tag='train') + step_metrics = {k:v for k, v in asdict(step_stat).items() if v is not None} + step_metrics['train_wall'] = time.perf_counter() - step_start_at + step_metrics['num_tokens_processed'] = num_tokens_trained + self.log_metrics(step_metrics, tag='train') + if self.rank == 0: + progress.set_postfix(step_metrics) + formatted_metrics = self._format_metrics(epoch_desc, idx + 1, step_metrics) + with open(train_info_save_path, 'a') as f: + f.write(f"{formatted_metrics}\n") + + if self.train_args.enable_log_progress \ + and self.train_status.finished_train_steps % self.train_args.log_progress_every_n_train_steps == 0: + + logger.info(formatted_metrics) + step_metrics = {} + + # validate and save checkpoint + if self.train_args.checkpoint.every_n_train_steps and \ + self.train_status.finished_train_steps % self.train_args.checkpoint.every_n_train_steps == 0: + self._validate_and_save(step_stat) + has_validated = VAL_STATUS_SAVE + + # max_train_steps is reached + if self.train_status.finished_train_steps >= self.max_train_steps: + if step_metrics and self.train_args.enable_log_progress: + logger.info(self._format_metrics(epoch_desc, idx + 1, step_metrics)) + step_metrics = {} + if not has_validated: + self._validate_and_save(step_stat) + has_validated = VAL_STATUS_SAVE + if self.rank == 0: + # disable refresh the progress bar to avoid redundant progress bar + progress.leave = False + progress.close() + break + + if not has_validated and self.train_args.val_every_n_train_steps and \ + self.train_status.finished_train_steps % self.train_args.val_every_n_train_steps == 0: + self._validate(step_stat) + has_validated = VAL_STATUS_VAL + + # time.sleep(1) + else: + # Do per-epoch operations here. + # if the loop exits with `break` (max_train_steps is reached) + # those operations have done in the loop + if step_stat is None: + return # no train step runs. Nothing to do. 
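+ # has_validated < VAL_STATUS_SAVE covers both VAL_STATUS_NO and VAL_STATUS_VAL, so the epoch-boundary + # checkpoint below can still be taken even if a mid-epoch validation ran without saving.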
+ if has_validated < VAL_STATUS_SAVE \ + and self.train_args.checkpoint.every_n_epochs \ + and (epoch + 1) % self.train_args.checkpoint.every_n_epochs == 0: + self._validate_and_save(step_stat) + has_validated = VAL_STATUS_SAVE + if not has_validated and self.train_args.val_every_n_epochs \ + and (epoch + 1) % self.train_args.val_every_n_epochs == 0: + self._validate(step_stat) + has_validated = VAL_STATUS_VAL + + def _setup(self): + self.train_args.init_env(self) + compile_only = self.train_args.compile_mode + + if is_running_distributed(): + nnscaler.init() + if torch.distributed.get_rank() == 0: + logging.getLogger().setLevel(logging.INFO) + else: + logging.getLogger().setLevel(logging.WARNING) + + def _create_model(): + model = self.train_args.create_model() + if self.train_args.param_dtype == self.train_args.buffer_dtype: + if self.train_args.param_dtype is not None: + model = model.to(self.train_args.param_dtype) + else: + # separate param and buffer dtype + # TODO: a little hacky. A better way? + # 3 kinds of tensors are converted in Module._apply: + # model parameters, its grad, and buffer + # param_dtype controls the first two, (but grad is `None` here) + # and buffer_dtype controls the last one + buf_ids = { id(buf) for buf in model.buffers(recurse=True) } + if self.train_args.param_dtype is not None: + model._apply( + lambda t: t.to(self.train_args.param_dtype) + if t.is_floating_point() and id(t) not in buf_ids + else t) + if self.train_args.buffer_dtype is not None: + model._apply( + lambda t: t.to(self.train_args.buffer_dtype) + if t.is_floating_point() and id(t) in buf_ids + else t) + if self.train_args.tracing_from_weights: + model.load_state_dict(torch.load(self.train_args.tracing_from_weights)) + return model + + # create dataset and dataloader + for stage in ['train', 'val', 'test']: + self.dataset[stage] = self.train_args.create_dataset(stage) + + # load a dummy input from training dataset + self.dummy_input = self._load_dummy_input() + self.dummy_input = self._fix_input(self.dummy_input) + + for stage in ['train', 'val', 'test']: + self.dataloader[stage] = self.train_args.create_dataloader(stage, self.dataset[stage]) + if self.dataloader[stage] is not None \ + and not self.dataloader[stage].drop_last \ + and len(self.dataset[stage]) % (self.train_args.micro_batch_size * self.train_args.scaling_factor) != 0: + warnings.warn( + f"Length of {stage} dataset ({len(self.dataset[stage])}) " + f"is not multiple of micro_batch_size * scale_factor ({self.train_args.micro_batch_size * self.train_args.scaling_factor}). " + f"In this case, the train_step for the last batch of samples can fail! " + f"You can specify `drop_last=True` in DataLoader to fix this problem." 
+ ) + + # setup compute config + compute_config = copy.deepcopy(self.train_args.compute_config) + compute_config.pas_config['__pas_name'] = self.train_args.pas_policy + # autodist configs + compute_config.pas_config['update_freq'] = self.train_args.update_freq + compute_config.pas_config['use_bf16'] = self.train_args.param_dtype == torch.bfloat16 + compute_config.pas_config['use_fp16'] = self.train_args.param_dtype == torch.float16 + + compute_config.user_config['__from_trainer_args'] = { + 'mbs': self.train_args.micro_batch_size, + 'gbs': self.train_args.global_batch_size, + 'precision': self.train_args.precision, + 'model_args': self.train_args.model.args, + } + + # parallalize model + pmodel_class = custom_parallelize( + self.train_args.model_type, + self._create_dummy_forward_args(), + self.train_args.resolved_pas_policy, + compute_config, + module_fn=_create_model, + gen_savedir=self.train_args.gen_savedir, + reuse=self.train_args.gen_reuse, + instance_name=self.train_args.instance_name, + broadcast_strategy=self.train_args.broadcast_strategy, + load_module=not compile_only, + transfer_config=self.train_args.transfer_config, + ) + if compile_only: + return + + torch.distributed.barrier() + self.rank = torch.distributed.get_rank() + + self.total_train_steps_per_epoch = len(self.dataloader['train']) // self.train_args.update_freq + if len(self.dataloader['train']) % self.train_args.update_freq != 0: + self.total_train_steps_per_epoch += 1 # will add extra dummy batches + + if self.train_args.max_epochs and self.train_args.max_train_steps: + self.max_train_steps = min( + self.total_train_steps_per_epoch * self.train_args.max_epochs, + self.train_args.max_train_steps + ) + elif self.train_args.max_train_steps: + self.max_train_steps = self.train_args.max_train_steps + else: + assert self.train_args.max_epochs, "max_epochs or max_train_steps should be specified" + self.max_train_steps = self.total_train_steps_per_epoch * self.train_args.max_epochs + + _, self.sync_group = self.train_args.compute_config.get_sync_group() + self.model = pmodel_class() + self.model.cuda() + self.optimizer = self.train_args.create_parallel_optimizer(self.model) + # Here we carefully scale down the gradient locally with 1/scale_factor before reduce, + # (the reduce op is `sum` by default, follow torch's c10d, grad is divided by scaling_factor before allreduce) + # and scale up the gradient after reduce + # (see `train_args.optimizer.grad_reduction`` handling in `train_epoch`). + # This is useful to avoid overflow when the gradients are large. 
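+ # Rough arithmetic (derived from the reducer hook below and the multiplier logic in _train_epoch): + # with scaling_factor S and grad_reduction='per-token-mean', each rank's gradient is divided by S + # before the allreduce(sum), and _train_epoch then calls optimizer.scale_grads(S / num_tokens), + # so the effective update is grad_sum / num_tokens while the values that travel through the + # allreduce stay a factor of S smaller.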
+ def reducer_pre_hook(reducer, grad): + grad.div_(self.train_args.scaling_factor) + self.optimizer.register_reducer_pre_hook(reducer_pre_hook) + self.lr_scheduler = self.train_args.create_lr_scheduler(self.optimizer) + self.loggers = self.train_args.create_loggers() + + supported_hook_components = [ + self.model, + self.optimizer, + self.lr_scheduler, + ] + self.hook = AggregatedTrainHook( + [x for x in supported_hook_components if isinstance(x, TrainHook)] + + [self.train_args.create_hook()] + ) + + self._log_config(self.train_args.to_dict()) + self._load_checkpoint() + + if self.train_args.merged_ckpt_path is not None: + print(f"Rank {self.rank} | {__name__} | loading merged checkpoint from {self.train_args.merged_ckpt_path}") + merged_ckpt_path = os.path.join(self.train_args.merged_ckpt_path, "pytorch_model.bin") + model_state_dict = torch.load(merged_ckpt_path, map_location='cpu') + + first_key = list(model_state_dict.keys())[0] + if len(first_key.split('.')) == 1: + # For Ring-Attention models, the merged checkpoint is directly copied from one of the shards and has different key names. + model_state_dict = fix_model_state_dict(self.model, model_state_dict) + + first_key = list(model_state_dict.keys())[0] + if 'model.model' not in first_key: + # Our merging logic also removes the prefix `model.` from the state dict keys when saving + model_state_dict = {'model.' + k: v for k, v in model_state_dict.items()} + if self.rank % int(os.getenv("GPU_PER_NODE", "8")) == 0: + print(f"Rank {self.rank} | {__name__} | loaded model state dict.keys(): {model_state_dict.keys()}") + + # in our merge program, `model` is poped out and we directly pass the model_state_dict instead of model_state_dict['model'] + nnscaler.load_merged_state_dict( + self.model, model_state_dict, + self.optimizer, None, + ) + + self.hook.after_setup(self) + + def _load_checkpoint(self): + resume_from = self.train_args.checkpoint.get_resume_checkpoint_dir() + if not resume_from: + return + logger.info(f"Resuming from {resume_from}") + if resume_from.is_file(): + resume_from = resume_from # when we load from merged checkpoint + else: + resume_from = resume_from / f'{self.rank}.ckpt' + state_dict = torch.load(resume_from, map_location='cpu') + self.hook.on_load_checkpoint(self, state_dict) + ckpt_save_type = state_dict['train_args']['checkpoint']['save_type'] + + if ckpt_save_type == 'merged': # it is a merged state dict + nnscaler.load_merged_state_dict( + self.model, state_dict['model'], + self.optimizer, state_dict['optimizer'], + ) + elif ckpt_save_type == 'sharded': + nnscaler.load_sharded_state_dict( + self.model, state_dict['model'], + self.optimizer, state_dict['optimizer'], + ) + elif ckpt_save_type == 'deduped': + nnscaler.load_deduped_state_dict( + self.model, state_dict['model'], + self.optimizer, state_dict['optimizer'], + ) + else: + raise ValueError(f"Unknown checkpoint type: {ckpt_save_type}") + + if 'lr_scheduler' in state_dict: + if state_dict['lr_scheduler'] and not self.lr_scheduler: + raise ValueError("lr_scheduler is not set in the current trainer") + if self.lr_scheduler: + self.lr_scheduler.load_state_dict(state_dict['lr_scheduler']) + self.train_status = TrainStatus(**state_dict['train_status']) + + # Assume in efficiency measuring mode, no checkpoint is saved and we only need to load the originally trained checkpoint + if int(os.getenv("E2E_MEASURE", "0")) == 1 or int(os.getenv("COLLECT_RING_COMP_DATA", "0")) == 1: + self.train_status.finished_train_steps = 0 + + self.rng_states_from_resume = 
state_dict.get('rng_states')  # resumed in _global_batch_iterator()
\ No newline at end of file
diff --git a/mtraining/utils/__init__.py b/mtraining/utils/__init__.py
new file mode 100644
index 0000000..fc13228
--- /dev/null
+++ b/mtraining/utils/__init__.py
@@ -0,0 +1,2 @@
+from .loss import chunk_linear_cross_entropy, linear_cross_entropy
+from .general import *
\ No newline at end of file
diff --git a/mtraining/utils/custom_parallel.py b/mtraining/utils/custom_parallel.py
new file mode 100644
index 0000000..cd3d4bf
--- /dev/null
+++ b/mtraining/utils/custom_parallel.py
@@ -0,0 +1,467 @@
+import os
+import torch
+import shutil
+import inspect
+import torch.distributed as dist
+
+from pathlib import Path
+from typing import Callable, Any, Dict, Optional, Tuple, Type, Union, TypeVar, List, Set, Literal
+
+from nnscaler.graph import IRGraph
+from nnscaler.graph.parser import FxModuleParser
+from nnscaler.runtime.device import DeviceGroup
+from nnscaler.runtime.module import AttrMeta, CubeModule, ParallelModule, OriginModuleMetadata, ExtraState
+
+from nnscaler.parallel import (
+    ComputeConfig, ReuseType, BroadcastGenFilesStrategy, RegenStatus,
+    _prepare_namespace, _compile_flags, _clean_files, _is_any_gencode_loaded, _gencode,
+    _broadcast_gen_files, _load_parallel_module_class,
+    _GENCODE_FILE_TEMPLATE, _GRAPH_DUMP_FILE, _FORWARD_ARGS_DUMP_FILE, _PREDEFINED_POLICIES
+)
+
+import logging
+logger = logging.getLogger(__name__)
+
+def compute_config_safe_equals(a: Optional['ComputeConfig'], b: Optional['ComputeConfig']) -> bool:
+    """
+    Compare two ComputeConfig objects field by field, ignoring differences in `user_config`.
+    Return False if they differ or come from incompatible versions of ComputeConfig.
+    This is only for backward compatibility and will be removed in the future:
+    once the dict version of ComputeConfig is saved to file, `==` can be used instead.
+    """
+    res = True
+    try:
+        for key in a.__dataclass_fields__:
+            if getattr(a, key) != getattr(b, key):
+                print(f"compute_config_safe_equals | {key} not equal: {getattr(a, key)} (old_config) != {getattr(b, key)} (current_config)")
+                if key != "user_config":
+                    res = False
+        return res
+    except AttributeError:
+        logger.warning("Failed to compare ComputeConfig. They are incompatible.")
+        return False
+
+GRAPH_CONFIG_FIELDS = ['constant_folding', 'user_config', 'inference_only', 'end2end_mode', 'trace_strategy']
+def graph_config_equals(a: Dict[str, Any], b: Dict[str, Any]) -> bool:
+    """
+    Compare only the graph-related fields (GRAPH_CONFIG_FIELDS) of two configs, ignoring
+    differences in `user_config`. Return False if they differ or come from incompatible versions.
+    This is only for backward compatibility and will be removed in the future:
+    once the dict version of ComputeConfig is saved to file, `==` can be used instead.
+    """
+    res = True
+    try:
+        for key in GRAPH_CONFIG_FIELDS:
+            if getattr(a, key) != getattr(b, key):
+                print(f"graph_config_equals | {key} not equal: {getattr(a, key)} (old_config) != {getattr(b, key)} (current_config)")
+                if key != "user_config":
+                    res = False
+        return res
+    except AttributeError:
+        logger.warning("Failed to compare GraphConfig.
They are incompatible.") + return False + + +TRACE_FILE_EXTENSIONS = [ + FxModuleParser.ATTR_CONTENT_FILE_0, # init weights file(fullmodel.pt.*), + FxModuleParser.ATTR_MAP_FILE, # param name mapping (dist_param_map.pt)\ + _GRAPH_DUMP_FILE, # graph dump (graph.ckp), + _FORWARD_ARGS_DUMP_FILE, # forward args dump(forward_args.pkl), + ParallelModule.ORIGIN_MODULE_METADATA_FILE # origin module metadata (origin_module_metadata.pt), +] +def transfer_metadata(out_dir, transfer_config: Dict[str, Any]): + transfer_config_dir, transfer_force = transfer_config['transfer_config_dir'], transfer_config['transfer_force'] + if not os.path.exists(transfer_config_dir): + # if transfer_config_dir is not set, use the default directory + transfer_config_dir = transfer_config_dir.replace("compile_config/", "compile_config/rank_0/") + assert os.path.exists(transfer_config_dir), f"Source directory {transfer_config_dir} for transferring does not exist" + + # transfer files in src_dir with postfix not being .py by executing `cp` + print(f"{__name__} | Transfering files from {transfer_config_dir} to {out_dir} (local_rank={os.getenv('LOCAL_RANK')})") + for file in os.listdir(transfer_config_dir): + if file in TRACE_FILE_EXTENSIONS or file.startswith(FxModuleParser.ATTR_CONTENT_FILE_STEM): + src_file = os.path.join(transfer_config_dir, file) + dst_file = os.path.join(out_dir, file) + + print(f"{__name__} | Copying {src_file} to {dst_file} (local_rank={os.getenv('LOCAL_RANK')})" ) + if not os.path.exists(dst_file) or transfer_force: + shutil.copyfile(src_file, dst_file) + + if not os.path.exists(dst_file): + raise FileNotFoundError(f"{__name__} | Copy failed ({dst_file} does not exist after copying)") + + # Create a file 'transferred.sign' to indicate that the transfer is done + with open(os.path.join(out_dir, "transferred.sign"), 'w') as f: + f.write("Transferred from " + transfer_config_dir) + +def _prepare_and_check_reusable( + gen_savedir: str, + module_or_module_class: Union[Type[torch.nn.Module], torch.nn.Module], + compute_config: ComputeConfig, + instance_name: Optional[str] = None, + reuse: ReuseType = ReuseType.MATCH, + transfer_config: Dict[str, Any] = None, + ) -> Tuple[str, bool, bool]: + """ + Prepare the output directory for code generation, and also check if the existing code is reusable. + + Args: + gen_savedir (str): the directory to save generated code + module_or_module_class (Union[Type[torch.nn.Module], torch.nn.Module]): the original module or module class + compute_config (ComputeConfig): the environment resource + instance_name (Optional[str]): the instance name of the generated module. If it is None, will use the default name. + reuse (ReuseType): specify which part can be reused. + + Returns: + Tuple[str, bool]: the output directory and whether the existing code is reusable. + + Raises: + RuntimeError: if the existing code is not reusable, + will raise RuntimeError if the code is not reusable but the module is already loaded. + """ + namespace, outdir = _prepare_namespace(gen_savedir, module_or_module_class, instance_name) + reusable = False + transferred = False + + config_file = outdir / ParallelModule.COMPUTE_CONFIG_FILE + + # Empty + Transfer -> config match, graph match, tracing file present -> generate code by MATCH or MOO + # Empty w.o. 
Transfer -> Empty -> generate code by MATCH or MOO + has_transferred = os.path.exists(os.path.join(outdir, "transferred.sign")) + if transfer_config is not None and transfer_config.get("transfer_config_dir", None) is not None \ + and (not has_transferred or transfer_config['transfer_force']): + # transfer_config_dir: Optional[str] = None, + transfer_metadata(outdir, transfer_config) + ComputeConfig.safe_dump_to_file(compute_config, config_file) + transferred = True + + # decision matrix for code generation + # reuse flag | dir condition(imported, empty, match, unmatched) | action + # --------------------------------------------------------- + # OVERRIDE | empty | generate + # OVERRIDE | imported | raise error + # OVERRIDE | whatever match | generate + # OVERRIDE | unmatch | generate + # GRAPH | empty | generate + # GRAPH | imported | raise error + # GRAPH | graph match | reuse graph, and regenerate code + # GRAPH | all match | reuse graph, and regenerate code + # GRAPH | unmatch | generate + # MATCH | empty | generate + # MATCH | match | reuse(do nothing) + # MATCH* | whatever unmatch| raise error (except when there's no python source code, see below) + # MATCH | imported | doesn't matter + # MOO | empty | generate + # MOO | match | reuse(do nothing) + # MOO | match graph | reuse graph, and regenerate code + # MOO | imported | raise error if whatever unmatch + # *: The precondition for `except` part is the compute config should match. + # you can take it as a continous operation after a failed generation. + old_config: Optional[ComputeConfig] = ComputeConfig.safe_load_from_file(config_file) + is_config_match = compute_config_safe_equals(old_config, compute_config) + # is_graph_config_match = old_config is not None and old_config.graph_config == compute_config.graph_config + is_graph_config_match = old_config is not None and graph_config_equals(old_config.graph_config, compute_config.graph_config) + trace_meta_files = [ + outdir / FxModuleParser.ATTR_CONTENT_FILE_0, # init weights file(fullmodel.pt.*), + outdir / FxModuleParser.ATTR_MAP_FILE, # param name mapping (dist_param_map.pt) + ] + + if reuse == ReuseType.MATCH or reuse == ReuseType.MOO: + # check if the module is already generated + expected_output_files = [outdir / _GENCODE_FILE_TEMPLATE.format(rank) for rank in range(compute_config.runtime_ngpus)] + expected_output_files.extend(trace_meta_files) + expected_output_files.append(config_file) + expected_output_files.append(outdir / _GRAPH_DUMP_FILE) # graph dump (graph.ckp), + expected_output_files.append(outdir / _FORWARD_ARGS_DUMP_FILE) # forward args dump(forward_args.pkl), + expected_output_files.append(outdir / ParallelModule.ORIGIN_MODULE_METADATA_FILE) # origin module metadata (origin_module_metadata.pt), + existing_output_files = [ + f for f in outdir.glob('*') + if f.is_file() and ( # just take fullmodel.pt.0 to compare + not f.name.startswith(FxModuleParser.ATTR_CONTENT_FILE_STEM) + or f.name == FxModuleParser.ATTR_CONTENT_FILE_0 + ) and not f.name.endswith('.sign') + ] + + print(f"{__name__} | compute config match: {is_config_match}") + print(f"{__name__} | graph config match: {is_graph_config_match}") + print(f"{__name__} | existing output files: {existing_output_files}") + print(f"{__name__} | expected output files: {expected_output_files}") + + if existing_output_files: # if the directory is not empty + if is_config_match \ + and all([output_file.exists() for output_file in expected_output_files]) \ + and len(existing_output_files) == len(expected_output_files): + + 
print(f"{__name__} | Reuse existing files in {outdir}") + reusable = True # everything is matched. + elif is_config_match \ + and all(f.suffix != '.py' for f in existing_output_files): + # No python source code is generated. + # which means its last generation failed. + # in this case, we can reuse the same directory safely. + logger.info(f'Output directory {outdir} is not empty. ' + f'But no python source code is present. ' + f'Will reuse the directory and the graph dump if present.') + # we have to trace the graph again if not all meta files are present. + print(f"{__name__} | compute config match but no python code exists in {outdir}") + if not all([meta_file.exists() for meta_file in trace_meta_files]): + print(f"{__name__} | compute config match but no python code exists in {outdir} and not all meta files are present") + _clean_files(outdir) + elif reuse == ReuseType.MATCH: + raise RuntimeError(f'Output directory {outdir} is not empty. ' + f'And the existing files do not match with current config. ' + f'You can remove the directory and try again, ' + f'or set reuse to ReuseType.NONE/ReuseType.OVERRIDE to regenerate the code.') + else: + assert reuse == ReuseType.MOO + if _is_any_gencode_loaded(namespace): + raise RuntimeError(f'Output directory {outdir} is already loaded. ' + f'You can not override a loaded module.') + elif is_graph_config_match: + # reuse the graph dump + print(f"{__name__} | MOO | graph match -> reuse graph but clean the current code") + _clean_files(outdir, '*.py') + else: + _clean_files(outdir) + else: + # check if the module is already loaded + if _is_any_gencode_loaded(namespace): + raise RuntimeError(f'Output directory {outdir} is already loaded. ' + f'You can not override a loaded module.') + # clear existing generated files + if reuse == ReuseType.OVERRIDE \ + or not is_graph_config_match \ + or not all([meta_file.exists() for meta_file in trace_meta_files]): + # we have to trace the graph again if not all meta files are present even when reuse=graph. + print(f"{__name__} | OVERRIDE | Override existing files in {outdir}") + glob_pattern = '*' + else: + print(f"{__name__} | GRAPH | keep the graph dump in {outdir} and regenerate the code") + glob_pattern = '*.py' # so we can keep graph dumps. + _clean_files(outdir, glob_pattern) + + return outdir, reusable, transferred + + +def parallelize( + module_or_module_class: Union[torch.nn.Module, Type[torch.nn.Module]], + dummy_forward_args: Dict[str, Any], + pas_policy: Union[str, Callable[[IRGraph, ComputeConfig], IRGraph]], + compute_config: ComputeConfig, + *, + gen_savedir: Union[str, Path] = './.nnscaler', + reuse: Union[ReuseType, str] = ReuseType.MATCH, + instance_name: Optional[str] = None, + load_module: bool = True, + module_dtype: Optional[torch.dtype] = None, + module_fn: Optional[Callable[[], torch.nn.Module]] = None, + init_module_params: bool = True, + broadcast_strategy: Union[str, BroadcastGenFilesStrategy] = 'none', + transfer_config: Optional[Dict[str, Any]] = None, +) -> Union[None, ParallelModule, Type[ParallelModule]]: + """ + Convert a torch.nn.Module object or class to ParallelModule object or class. + + If you want to save multiple instances of the same module, + you can specify the instance_name to distinguish them. + + Currently you must use a shared file system to share the generated files (like mounted Azure Blob) + Or you can unset load_module flag, and manually copy the generated files to other nodes. 
+ After all nodes have the generated files, you can call parallelize() again with load_module flag set. + + Note: if reuse is not set to ReuseType.MATCH, + the generated code in outdir will be removed EVEN IF the code generation fails in this call. + + if the input is a module object. + * The module object will be copied to cpu to handle possible insufficient gpu memory. + * The training flag will be the same as the original module + + This function can be used to convert both module object and module class to parallel module or parallel module class. + Among key-value arguments, + module_fn and module_dtype control how to create the module object. + whereas init_module_params controls how to load parallel module object after conversion is done. + + 1. If the input is a module object, it will return a ParallelModule object if load_module is True. + This is useful when the module is created by a factory function. + + a. module_fn is ignored. + b. module_dtype is used to control the dtype of the input module. + c. init_module_params is used to control whether to initialize the parallel module parameters when load it. + + 2. If the input is a module class, it will return a ParallelModule sub class if load_module is True. + + a. module_fn is used to create the module object, or module's__init__ if not prent. + b. module_dtype is used to control the dtype of the created module (by constructor or module_fn). + Of course, it can be merged into module_fn. + c. init_module_params is ignored. + + After the module is converted, you can use it to create module object by calling it like a module class. + The module class is defined like: + + :: + + class GenModule(nnscaler.runtime.module.ParallelModule): + def __init__(self, init_params=True): + super().__init__() + ... + ... + + So you can use `init_params` in `__init__` to control whether to initialize the module parameters. + For example, if you don't want to initialize module params: + + :: + + module = GenModule(init_params=False) + + Args: + module_or_module_class (Union[torch.nn.Module, Type[torch.nn.Module]]): the module or module class to be compiled + dummy_forward_args (Dict[str, Any]): the dummy input for the module forward + pas_policy (Union[str, Callable[[IRGraph, ComputeConfig], IRGraph]]): the pas policy, + it can be a name of builtin policies, or a custom policy function. + compute_config (ComputeConfig): the environment resource + reuse (ReuseType): specify which part can be reused. + gen_savedir (Union[str, Path]): the directory to save generated code + instance_name (Optional[str]): the instance name of the generated module. If it is None, will use the default name. + load_module (bool): whether to load the generated module or module class after conversion is done. + init_module_params (bool): If true, when we construct the module, all its parameters are initialized with the same value with when we traced. + Otherwise, they will be empty tensor. + This parameter will be passed to the module constructor, + so it is only used when module_or_module_class is a module object, and load_module is true. + module_dtype (Optional[torch.dtype]): the dtype of the module. Keep the module as it is if it is None. + module_fn (Optional[Callable[[], torch.nn.Module]]): the function to create the module. Will use __init__ if it is None. + broadcast_strategy (Union[str, BroadcastGenFilesStrategy]): the broadcast strategy for generated files. 
+ Please note that the broadcasting will only be done in torchrun environment, + and will throw an error if dist is not initialized and broadcast_strategy is not NONE. + Returns: + Union[ParallelModule, Type[ParallelModule], None]: + if load_module flag is set, return the converted ParallelModule object or class + if load_module flag is not set, return None + """ + if ( + isinstance(module_or_module_class, ParallelModule) or + (inspect.isclass(module_or_module_class) and issubclass(module_or_module_class, ParallelModule)) + ): + # already done + return module_or_module_class if load_module else None + + if ( + isinstance(module_or_module_class, CubeModule) or + (inspect.isclass(module_or_module_class) and issubclass(module_or_module_class, CubeModule)) + ): + raise RuntimeError("Old style CubeModule is not supported") + + if isinstance(pas_policy, str): + if not pas_policy in _PREDEFINED_POLICIES: + raise ValueError(f"Invalid pas_policy: {pas_policy}") + pas_policy = _PREDEFINED_POLICIES[pas_policy] + + is_module_class = inspect.isclass(module_or_module_class) + module_class = module_or_module_class if is_module_class else module_or_module_class.__class__ + reuse = ReuseType(reuse) if isinstance(reuse, str) else reuse + broadcast_strategy = BroadcastGenFilesStrategy(broadcast_strategy) if isinstance(broadcast_strategy, str) else broadcast_strategy + + # Call it here just to ensure the device group is initialized. + # If the user initializes dist + # and doesn't call `nnscaler.init()` before calling this function, this is necessary. + if dist.is_initialized(): + _ = DeviceGroup() + + # generate code only in node0 + # if it is not in a torchrun environment, just generate. + if not dist.is_initialized() or dist.get_rank() == 0: + outdir, reusable, transferred = _prepare_and_check_reusable( + gen_savedir, module_class, compute_config, instance_name, reuse, + transfer_config + ) + if not reusable: + config_file = outdir / ParallelModule.COMPUTE_CONFIG_FILE + ComputeConfig.safe_dump_to_file(compute_config, config_file) # always refresh compute config + with _compile_flags(compute_config): + regen_status = _gencode( + module_or_module_class, + dummy_forward_args, + pas_policy, + compute_config, + outdir, + module_dtype=module_dtype, + module_fn=module_fn, + ) + else: + regen_status = RegenStatus.NONE + logger.info(f"Reuse generated code in {outdir}") + + if regen_status == RegenStatus.CODE and transferred: + regen_status = RegenStatus.ALL + + if dist.is_initialized(): + # code generation can take very long time (for example, over 1 hour) + # It is not always OK to use dist.barrier() directly. 
+ # because the default timeout for nccl is 30 minutes + # (we can't control the timeout setting if dist is not initialized by us) + DeviceGroup().long_barrier() + + if broadcast_strategy != BroadcastGenFilesStrategy.NONE: + if not dist.is_initialized(): # we only support loading in torchrun environment + raise RuntimeError("Broadcast generated files failed: dist is not initialized.") + dist.barrier() + # sync regen_status + curr_rank = dist.get_rank() + if curr_rank == 0: + sent_obj = [regen_status] + else: + sent_obj = [None] + dist.broadcast_object_list( + sent_obj, + src=0, + ) + if curr_rank != 0: + regen_status = sent_obj[0] + + # narrow down broadcast_strategy according to regen_status + if regen_status == RegenStatus.NONE: + # we don't need to broadcast anything + broadcast_strategy = BroadcastGenFilesStrategy.NONE + elif regen_status == RegenStatus.CODE: + # narrow ALL/NO_WEIGHTS down to code + broadcast_strategy = BroadcastGenFilesStrategy.CODE + else: + # we don't need to narrow broadcast_strategy in this case + # keep the original broadcast_strategy + assert regen_status == RegenStatus.ALL + + # broadcast generated files according to regen_status + if broadcast_strategy != BroadcastGenFilesStrategy.NONE: + _broadcast_gen_files( + module_class, + gen_savedir=gen_savedir, + instance_name=instance_name, + broadcast_strategy=broadcast_strategy, + ) + elif os.getenv("FORCE_BROADCAST") == "1": + # force broadcast generated files + print(f"Force broadcast generated files in {gen_savedir}") + _broadcast_gen_files( + module_class, + gen_savedir=gen_savedir, + instance_name=instance_name, + broadcast_strategy=BroadcastGenFilesStrategy.ALL, + ) + + if load_module: + if not dist.is_initialized(): # we only support loading in torchrun environment + raise RuntimeError("Load ParallelModule failed: dist is not initialized.") + dist.barrier() + parallel_module_class = _load_parallel_module_class( + module_class, + gen_savedir=gen_savedir, + instance_name=instance_name, + ) + if is_module_class: + return parallel_module_class + else: + parallel_module = parallel_module_class(init_module_params) + parallel_module.train(module_or_module_class.training) # set training state to the same as original module + return parallel_module diff --git a/mtraining/utils/data_utils/__init__.py b/mtraining/utils/data_utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mtraining/utils/data_utils/bookcorpus.py b/mtraining/utils/data_utils/bookcorpus.py new file mode 100644 index 0000000..69429a5 --- /dev/null +++ b/mtraining/utils/data_utils/bookcorpus.py @@ -0,0 +1,70 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
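# Illustration (not part of this patch) of the pack-and-split logic that concate_split()
# below implements: token lists are concatenated into one stream and cut into fixed-length
# samples, with any trailing remainder dropped. A minimal pure-Python sketch of the idea,
# using a hypothetical helper name:

def _pack_example(token_lists, sample_len):
    buffer, out = [], []
    for ids in token_lists:
        buffer.extend(ids)
        while len(buffer) >= sample_len:
            out.append(buffer[:sample_len])   # emit one fixed-length sample
            buffer = buffer[sample_len:]      # keep the remainder for the next sample
    return out

# _pack_example([[1, 2, 3], [4, 5], [6, 7, 8, 9]], sample_len=4)
# -> [[1, 2, 3, 4], [5, 6, 7, 8]]   (token 9 stays in the buffer and is dropped)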
+import numpy
+import torch
+import argparse
+from typing import List, Dict
+from datasets import load_dataset, Dataset
+from transformers import AutoTokenizer, PreTrainedTokenizer
+
+def get_tokenizer(model_path):
+    return AutoTokenizer.from_pretrained(model_path)
+
+BOS_TOKEN = ''
+
+def tokenize(sample: Dict[str, str], tokenizer: PreTrainedTokenizer, text_key: str):
+    input_ids = tokenizer.encode(BOS_TOKEN + sample[text_key] + tokenizer.eos_token, add_special_tokens=False)
+    return {"input_ids": input_ids}
+
+def concate_split(samples: Dict[str, List[List[int]]], sample_len: int, text_key: str):
+    buffer = []  # start empty so the first sample is not counted twice and the input batch is not mutated
+    resized_ids = []
+    length = []
+    for in_ids in samples[text_key]:
+        buffer.extend(in_ids)
+        while len(buffer) >= sample_len:
+            resized_ids.append(buffer[:sample_len])
+            length.append(sample_len)
+            buffer = buffer[sample_len:]
+    return {"input_ids": resized_ids, "length": length}
+
+def create_dataset(tokenizer: PreTrainedTokenizer, raw_dataset: Dataset, text_key: str, sample_len: int = 8 * 1024, batch_size=10000):
+    tokenized_dataset = raw_dataset.map(
+        tokenize, remove_columns=raw_dataset.column_names, num_proc=32,
+        fn_kwargs={'tokenizer': tokenizer, 'text_key': text_key}
+    )
+    return tokenized_dataset.map(
+        concate_split, remove_columns=tokenized_dataset.column_names,
+        num_proc=32, batched=True,
+        batch_size=batch_size, fn_kwargs={'sample_len': sample_len, 'text_key': 'input_ids'}
+    )
+
+
+def modify_bos_token(tokenizer: PreTrainedTokenizer):
+    # https://huggingface.co/Qwen/Qwen2-7B-Instruct/discussions/15
+    global BOS_TOKEN
+    if tokenizer.bos_token is None:
+        BOS_TOKEN = "<|endoftext|>"
+    else:
+        BOS_TOKEN = tokenizer.bos_token
+
+if __name__ == '__main__':
+    # python bookcorpus.py --data_path_or_name "bookcorpus/bookcorpus" --tokenizer_path_or_name "meta-llama/Llama-2-7b-hf" --save_path "bookcorpus-llama2-2k-hf" --sequence_length 2048
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--data_path_or_name', help='the path or name of the raw dataset, for example, "bookcorpus/bookcorpus"', type=str, required=True)
+    parser.add_argument('--tokenizer_path_or_name', help='the tokenizer path or name, for example, "meta-llama/Llama-2-7b-hf"', type=str, required=True)
+    parser.add_argument('--save_path', help='the path to save the tokenized dataset', type=str, required=True)
+    parser.add_argument('--sequence_length', help='the length of each sample in the tokenized dataset, usually set to the max sequence length', type=int, required=True)
+    args = parser.parse_args()
+
+    data_path_or_name = args.data_path_or_name
+    tokenizer_path_or_name = args.tokenizer_path_or_name
+    save_path = args.save_path
+    sequence_length = args.sequence_length
+
+    raw_dataset = load_dataset(data_path_or_name)["train"]
+    tokenizer = get_tokenizer(tokenizer_path_or_name)
+    modify_bos_token(tokenizer)
+
+    dataset = create_dataset(tokenizer, raw_dataset, "text", sequence_length)
+    dataset.save_to_disk(save_path)
diff --git a/mtraining/utils/general.py b/mtraining/utils/general.py
new file mode 100644
index 0000000..e5e1869
--- /dev/null
+++ b/mtraining/utils/general.py
@@ -0,0 +1,137 @@
+import os
+import torch
+import torch.distributed as dist
+
+from typing import List
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from nnscaler.cli.trainer_args import AggregatedOutputs
+from nnscaler.runtime.module import ParallelModule
+from .paths import ACTIVE_PARAM_CONFIG_DIR, BASE_DIR
+
+import logging
+logger = logging.getLogger(__name__)
+
+def
get_tokenizer(tokenizer_name_or_path, + model_max_length=None, + default_bos_token="", + default_eos_token="", + default_pad_token="[PAD]", + default_unk_token=""): + + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, trust_remote_code=True) + special_tokens_dict = dict() + if tokenizer.pad_token is None: + special_tokens_dict["pad_token"] = default_pad_token + if tokenizer.eos_token is None: + special_tokens_dict["eos_token"] = default_eos_token + if tokenizer.bos_token is None: + special_tokens_dict["bos_token"] = default_bos_token + if tokenizer.unk_token is None: + special_tokens_dict["unk_token"] = default_unk_token + + tokenizer.add_special_tokens(special_tokens_dict) + if model_max_length: + tokenizer.model_max_length = model_max_length + return tokenizer + +def get_module_path(model_id: str): + model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True) + module_path = str(model.__class__.__module__) + del model + + return module_path + +def aggregate_outputs_fn(loss_outputs, sync_group) -> AggregatedOutputs: + losses, ntokens_info = [], [] + for _, loss, ntokens, _ in loss_outputs: + losses.append(loss) + ntokens_info.append(ntokens) + + + loss_sum = torch.sum(torch.stack(losses), dtype=torch.float64) + dist.all_reduce(loss_sum, group=sync_group) + + ntokens_sum = torch.sum(torch.tensor(ntokens_info, dtype=torch.float64, device=torch.cuda.current_device())) + dist.all_reduce(ntokens_sum, group=sync_group) + + num_batches = torch.tensor(len(losses), device=torch.cuda.current_device()) + dist.all_reduce(num_batches, group=sync_group) + + return AggregatedOutputs( + loss_sum=loss_sum.item() / ntokens_sum.item(), + num_batches=num_batches.item(), + num_tokens=ntokens_sum.item(), + ) + + +def load_comm_profile_data(args): + if args.plan_ngpus in [2, 4, 8, 16]: + logger.info(f"Use nnscaler's built-in communication profiling data for {args.plan_ngpus} GPUs") + return + + from nnscaler.autodist.util import get_default_profile_path + profile_dir = os.path.join(get_default_profile_path(), 'comm') + profile_path = os.path.join(profile_dir, f"intra_{args.plan_ngpus}.json") + + if not os.path.exists(profile_path): + import shutil + logger.info(f"Communication profiling data not found in {profile_dir} for {args.plan_ngpus} GPUs. 
Use built-in communication profiling data (collected on A100-SXM4-40GB)") + src_file_path = os.path.join(BASE_DIR, "utils/comm_prof/NVIDIA_A100-SXM4-40GB", f"intra_{args.plan_ngpus}.json") + if not os.path.exists(src_file_path): + raise FileNotFoundError(f"Communication profiling data not found in {src_file_path} nor in nnscaler's built-in library for {args.plan_ngpus} GPUs") + os.makedirs(profile_dir, exist_ok=True) + + num_dev = 2 + while num_dev <= args.plan_ngpus: + src_file_path = os.path.join(BASE_DIR, "utils/comm_prof/NVIDIA_A100-SXM4-40GB", f"intra_{num_dev}.json") + profile_path = os.path.join(profile_dir, f"intra_{num_dev}.json") + if os.path.exists(profile_path): + logger.info(f"Communication profiling data already exists in {profile_path} for {num_dev} GPUs") + num_dev *= 2 + continue + else: + logger.info(f"Copying {src_file_path} to {profile_path}") + shutil.copy(src_file_path, profile_path) + num_dev *= 2 + + + +def is_active(module_name: str, keep_active: List[str]): + for active_module_subname in keep_active: + if active_module_subname.lower() in module_name.lower(): + return True + return False + +def read_active_param_list(active_param_config_name: str): + print(f"Reading active param list from {active_param_config_name}...") + with open(os.path.join(ACTIVE_PARAM_CONFIG_DIR, f'{active_param_config_name}.txt'), "r") as f: + return f.read().splitlines() + +def freeze_model_params_(model, keep_active: List[str], prefix=""): + if dist.get_rank() == 0: + print("-" * 80) + print(f"Only keeping parameters with substring in {keep_active} active...") + + for name, module in model._modules.items(): + if len(list(module.children())) > 0: + freeze_model_params_(module, keep_active, prefix + name + ".") + else: + param_name = prefix + name + if not is_active(param_name, keep_active): + print(f"Freezing {param_name}...") + for param in module.parameters(): + param.requires_grad = False + else: + print(f"Keeping {param_name} active...") + + if dist.get_rank() == 0: + print("-" * 80) + + +def freeze_model_params(model, active_param_config_name: str, prefix=""): + print(f"active param config name: {active_param_config_name}") + keep_active = read_active_param_list(active_param_config_name) + print(f"keep active: {keep_active}") + + freeze_model_params_(model, keep_active, prefix) diff --git a/mtraining/utils/loss.py b/mtraining/utils/loss.py new file mode 100644 index 0000000..5fd3a54 --- /dev/null +++ b/mtraining/utils/loss.py @@ -0,0 +1,68 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +import torch.utils.checkpoint as ckpt + +from nnscaler.graph.parser.register import register_op + + +def linear_cross_entropy(x: torch.Tensor, w: torch.Tensor, y: torch.Tensor, padding_idx: int = 0) -> torch.Tensor: + """ + Compute the cross entropy loss of a linear layer. 
+ + Args: + + x: [token_num, hidden_size], the last hidden state of the model + w: [dict_size, hidden_size], the weight matrix of the last linear layer + y: [token_num], the target token index + padding_idx: int, the index of padding token + + Returns: + + losses: [token_num], the cross entropy loss of each token + """ + logits = torch.nn.functional.linear(x, w) + normalized_logits = torch.nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32) + losses = torch.nn.functional.nll_loss(normalized_logits, y, reduction='none', ignore_index=padding_idx) + return losses + + +def chunk_linear_cross_entropy(x: torch.Tensor, w: torch.Tensor, y: torch.Tensor, padding_idx: int, chunk_size: int) -> torch.Tensor: + """ + In order to reduce the memory usage when the sequence length and dictionary size are large, we can split the input + tensor into chunks and compute the cross entropy loss of each chunk separately. + You can register this function with annotation 'b l d^, n^ d^, b l -> b l'. + + Args: + + x: [bsz, seq_len, hidden_size], the last hidden state of the model + w: [dict_size, hidden_size], the weight matrix of the last linear layer + y: [bsz, seq_len], the target token index + padding_idx: int, the index of padding token + chunk_size: int, the size of each chunk + + Returns: + + losses: [bsz, seq_len], the cross entropy loss of each token + """ + bsz, seq_len, hidden_size = x.size() + token_num = bsz * seq_len + x = x.view(token_num, hidden_size) + y = y.view(token_num) + + if token_num % chunk_size != 0: + raise ValueError(f"token_num {token_num} is not divisible by chunk_size {chunk_size}") + + chunk_num = token_num // chunk_size + xs = x.view(chunk_num, chunk_size, hidden_size) + ys = y.view(chunk_num, chunk_size) + losses = [] + for i in range(chunk_num): + loss = ckpt.checkpoint(linear_cross_entropy, xs[i], w, ys[i], padding_idx, use_reentrant=False) + losses.append(loss) + losses = torch.stack(losses).view(bsz, seq_len) + return losses + + +register_op('b l d^, n^ d^, b l -> b l')(chunk_linear_cross_entropy) diff --git a/mtraining/utils/paths.py b/mtraining/utils/paths.py new file mode 100644 index 0000000..ba24d7e --- /dev/null +++ b/mtraining/utils/paths.py @@ -0,0 +1,33 @@ +import os + +BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + +ACTIVE_PARAM_CONFIG_DIR = os.path.join(BASE_DIR, "models", "active_param_configs") + +MINFER_CONFIG_DIR = os.path.join(BASE_DIR, 'sparse_configs') +SPARSE_PATTERN_CONFIG_DIR = os.path.join(BASE_DIR, 'ops', 'minfer', 'configs') +SPARSE_HEAD_MAP_DIR = os.path.join(BASE_DIR, 'ops', 'minfer', 'sparse_head_maps') + + +EXPR_DATA_SAVE_PATH = { + 'base_path': None, + "ckpt_save_path": None, + "compile_save_path": None, +} + +def update_expr_data_by_base_path(base_path): + EXPR_DATA_SAVE_PATH['base_path'] = base_path + +def update_expr_data_save_path( + ckpt_save_path, + compile_save_path, + ): + if ckpt_save_path is None: + EXPR_DATA_SAVE_PATH['base_path'] = os.getenv("EFFI_EXPR_STORE_DIR") + else: + EXPR_DATA_SAVE_PATH['base_path'] = os.path.dirname(ckpt_save_path) + if "rank_" in ckpt_save_path: + EXPR_DATA_SAVE_PATH['base_path'] = os.path.dirname(os.path.dirname(ckpt_save_path)) + + EXPR_DATA_SAVE_PATH['ckpt_save_path'] = ckpt_save_path + EXPR_DATA_SAVE_PATH['compile_save_path'] = compile_save_path \ No newline at end of file From 35ad04d5da245ac90d340575ddc87d160823b108 Mon Sep 17 00:00:00 2001 From: Wenxuan Li Date: Tue, 10 Jun 2025 11:52:19 +0000 Subject: [PATCH 02/12] put away essentials --- .gitignore | 1 + 
minference/dist_ops/__init__.py | 9 + minference/dist_ops/minfer_striped.py | 97 +-- minference/dist_ops/minfer_striped_triton.py | 22 +- minference/dist_ops/moba_zigzag.py | 44 +- minference/dist_ops/op_utils/moba_utils.py | 283 +-------- minference/dist_ops/xattn_zigzag.py | 49 -- minference/ops/moba.py | 593 ++++++++++++++++++ minference/ops/utils.py | 62 ++ minference/ops/xattention_fa.py | 96 +++ mtraining/attn_funcs/__init__.py | 36 ++ mtraining/attn_funcs/dense_func.py | 278 ++++++++ mtraining/attn_funcs/minfer_func.py | 412 ++++++++++++ mtraining/attn_funcs/moba_func.py | 200 ++++++ mtraining/attn_funcs/utils.py | 49 ++ mtraining/attn_funcs/xattn_func.py | 217 +++++++ mtraining/train.py | 158 ++--- .../train_attn_configs/moba_256k_s95.yaml | 2 + .../train_attn_configs/moba_512k_s95.yaml | 2 + .../train_attn_configs/qwen_flex_090.yaml | 2 + .../train_attn_configs/qwen_flex_095.yaml | 2 + .../train_attn_configs/qwen_mf_dr_stripe.yaml | 2 + .../train_attn_configs/qwen_mf_stripe.yaml | 2 + .../train_attn_configs/qwen_mf_zigzag.yaml | 2 + .../train_attn_configs/xattn_default.yaml | 12 + .../train_attn_configs/xattn_zigzag_s16.yaml | 12 + .../xattn_zigzag_s16_t85.yaml | 12 + mtraining/trainer.py | 189 +----- mtraining/utils/expr_data.py | 15 + mtraining/utils/general.py | 67 ++ mtraining/utils/paths.py | 6 +- 31 files changed, 2213 insertions(+), 720 deletions(-) create mode 100644 minference/ops/moba.py create mode 100644 mtraining/attn_funcs/__init__.py create mode 100644 mtraining/attn_funcs/dense_func.py create mode 100644 mtraining/attn_funcs/minfer_func.py create mode 100644 mtraining/attn_funcs/moba_func.py create mode 100644 mtraining/attn_funcs/utils.py create mode 100644 mtraining/attn_funcs/xattn_func.py create mode 100644 mtraining/train_attn_configs/moba_256k_s95.yaml create mode 100644 mtraining/train_attn_configs/moba_512k_s95.yaml create mode 100644 mtraining/train_attn_configs/qwen_flex_090.yaml create mode 100644 mtraining/train_attn_configs/qwen_flex_095.yaml create mode 100644 mtraining/train_attn_configs/qwen_mf_dr_stripe.yaml create mode 100644 mtraining/train_attn_configs/qwen_mf_stripe.yaml create mode 100644 mtraining/train_attn_configs/qwen_mf_zigzag.yaml create mode 100644 mtraining/train_attn_configs/xattn_default.yaml create mode 100644 mtraining/train_attn_configs/xattn_zigzag_s16.yaml create mode 100644 mtraining/train_attn_configs/xattn_zigzag_s16_t85.yaml create mode 100644 mtraining/utils/expr_data.py diff --git a/.gitignore b/.gitignore index 16033cd..ae5d0fc 100644 --- a/.gitignore +++ b/.gitignore @@ -415,3 +415,4 @@ build/ *.egg-info/ *.so dist +*.eggs/ \ No newline at end of file diff --git a/minference/dist_ops/__init__.py b/minference/dist_ops/__init__.py index e69de29..cc82224 100644 --- a/minference/dist_ops/__init__.py +++ b/minference/dist_ops/__init__.py @@ -0,0 +1,9 @@ +from .minfer_striped import minfer_stripe_func +from .minfer_zigzag import minfer_zigzag_func +from .minfer_dr_striped import minfer_dr_stripe_func + +from .minfer_striped_triton import minfer_stripe_triton_func +from .minfer_dr_stripe_triton import minfer_dr_stripe_triton_func + +from .moba_zigzag import moba_zigzag_func +from .xattn_zigzag import xattn_zigzag_func diff --git a/minference/dist_ops/minfer_striped.py b/minference/dist_ops/minfer_striped.py index 6fb6fcd..ffd7cb1 100644 --- a/minference/dist_ops/minfer_striped.py +++ b/minference/dist_ops/minfer_striped.py @@ -25,81 +25,8 @@ sys.setdlopenflags(original_flags) # NOTE: 
Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_blockmask.h: add head_idx to blockmask_ptr -def compute_sr_flops( - block_mask_offset: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, num_blocks] - bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] - step: int, - granularity: int, - q_len: int, - head_dim: int, - shift: bool, - fwd: bool=True, -): - num_blocks = triton.cdiv(q_len, granularity) - bh = block_mask_offset.shape[0] * block_mask_offset.shape[1] - bar_cnt_step = (bar_cnt[..., step + 1] - bar_cnt[..., step]).sum(dtype=torch.float32).item() - - total_num_blocks = bh * num_blocks * (num_blocks - 1) / 2 - if step == 0: - total_num_blocks += bh * num_blocks / 2 - elif not shift: - total_num_blocks += bh * num_blocks - - if step == 0: - num_active_blocks = block_mask_offset.sum(dim=-1).sum(dtype=torch.float32).item() - bh * num_blocks / 2 - elif not shift: - num_active_blocks = block_mask_offset.sum(dtype=torch.float32).item() - else: - num_active_blocks = block_mask_offset[..., 1:, :-1].sum(dtype=torch.float32).item() - block_ratio = num_active_blocks / total_num_blocks - bar_ratio = bar_cnt_step / (granularity * total_num_blocks) - sparsity_ratio = 1 - block_ratio - bar_ratio - - block_flops = num_active_blocks * (granularity * granularity) * head_dim * 2 * 2 - bar_flops = bar_cnt_step * granularity * head_dim * 2 * 2 - flops = block_flops + bar_flops - - if not fwd: - flops, block_flops, bar_flops = 2.5 * flops, 2.5 * block_flops, 2.5 * bar_flops - # STEP_DATA_FIELDS = ["block_ratio", "bar_ratio", "sparsity_ratio", "blk_flops", "bar_flops", "flops"] - return block_ratio, bar_ratio, sparsity_ratio, block_flops, bar_flops, flops - - -def compute_sr_by_heads( - block_mask_offset: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, num_blocks] - bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] - step: int, - granularity: int, - q_len: int, - head_dim: int, - shift: bool, - fwd: bool=True, -): - batch_size, num_heads = block_mask_offset.shape[0], block_mask_offset.shape[1] - num_blocks = triton.cdiv(q_len, granularity) - bar_cnt_step = (bar_cnt[..., step + 1] - bar_cnt[..., step]).sum(dim=-1).sum(dim=-1).sum(0, dtype=torch.float32) # [num_qo_heads] - - total_num_blocks = num_blocks * (num_blocks - 1) / 2 - if step == 0: - total_num_blocks += num_blocks / 2 - elif not shift: - total_num_blocks += num_blocks - total_num_blocks_by_heads = torch.tensor([total_num_blocks for _ in range(num_heads)], dtype=torch.float32).to(block_mask_offset.device) - - if step == 0: - num_active_blocks = block_mask_offset.sum(dim=-1).sum(dim=-1).sum(0, dtype=torch.float32) - batch_size * num_blocks / 2 - elif not shift: - num_active_blocks = block_mask_offset.sum(dim=-1).sum(dim=-1).sum(0, dtype=torch.float32) - else: - num_active_blocks = block_mask_offset[..., 1:, :-1].sum(dim=-1).sum(dim=-1).sum(0, dtype=torch.float32) - block_ratio_by_heads = num_active_blocks / total_num_blocks_by_heads - bar_ratio_by_heads = bar_cnt_step / total_num_blocks_by_heads / granularity - sparsity_ratio_by_heads = 1 - block_ratio_by_heads - bar_ratio_by_heads - - return sparsity_ratio_by_heads.detach().cpu().numpy().tolist() - - -def sparse_stripe_flash_attn_forward( + +def minfer_stripe_forward( process_group: dist.ProcessGroup, q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] @@ -146,7 +73,7 @@ def sparse_stripe_flash_attn_forward( # lse = 
lse.squeeze(dim=-1).transpose(1, 2) return out, lse, bar_k, bar_v -def sparse_stripe_flash_attn_backward( +def minfer_stripe_backward( process_group: dist.ProcessGroup, dout: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] @@ -227,7 +154,7 @@ def sparse_stripe_flash_attn_backward( return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) -class SparseStripeFlashAttnFunc(torch.autograd.Function): +class MInferStripeFunc(torch.autograd.Function): @staticmethod def forward( ctx, @@ -257,7 +184,7 @@ def forward( v = shuffle_striped_input(to_send=v, dim=1, granularity=granularity, process_group=group) # Compute - out, softmax_lse, bar_k, bar_v = sparse_stripe_flash_attn_forward( + out, softmax_lse, bar_k, bar_v = minfer_stripe_forward( group, q, k, v, layer_idx, softmax_scale, block_mask, bar_idx, bar_cnt, v_idx, v_cnt, @@ -293,7 +220,7 @@ def backward(ctx, dout, *args): dout = shuffle_striped_input(to_send=dout, dim=1, granularity=granularity, process_group=group) # Compute - dq, dk, dv = sparse_stripe_flash_attn_backward( + dq, dk, dv = minfer_stripe_backward( group, dout, q, k, v, out, softmax_lse, layer_idx, softmax_scale, block_mask, bar_pos, bar_cnt, v_idx, v_cnt, bar_k, bar_v, @@ -306,7 +233,7 @@ def backward(ctx, dout, *args): dv = recover_striped_output(dv, dim=1, granularity=granularity, process_group=group) return dq, dk, dv, None, None, None, None, None, None, None -def sparse_stripe_flash_attn_qkvpacked_func( +def minfer_stripe_qkvpacked_func( qkv: torch.Tensor, # [batch_size, num_tokens, 3, num_heads, head_dim] v_size: List[int], # [num_heads] s_size: List[int], # [num_heads] @@ -326,7 +253,7 @@ def sparse_stripe_flash_attn_qkvpacked_func( assert window_size == (-1, -1) assert alibi_slopes is None assert not deterministic - return SparseStripeFlashAttnFunc.apply( + return MInferStripeFunc.apply( qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], @@ -340,7 +267,7 @@ def sparse_stripe_flash_attn_qkvpacked_func( ) -def sparse_stripe_flash_attn_kvpacked_func( +def minfer_stripe_kvpacked_func( q: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] kv: torch.Tensor, # [batch_size, num_tokens, 2, num_heads, head_dim] v_size: List[int], # [num_heads] @@ -361,7 +288,7 @@ def sparse_stripe_flash_attn_kvpacked_func( assert window_size == (-1, -1) assert alibi_slopes is None assert not deterministic - return SparseStripeFlashAttnFunc.apply( + return MInferStripeFunc.apply( q, kv[:, :, 0], kv[:, :, 1], @@ -374,7 +301,7 @@ def sparse_stripe_flash_attn_kvpacked_func( group, ) -def sparse_stripe_flash_attn_func( # the one used for nnscaler training +def minfer_stripe_func( # the one used for nnscaler training q: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] k: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] v: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] @@ -397,7 +324,7 @@ def sparse_stripe_flash_attn_func( # the one used for nnscaler training assert alibi_slopes is None assert not deterministic - return SparseStripeFlashAttnFunc.apply( + return MInferStripeFunc.apply( q, k, v, diff --git a/minference/dist_ops/minfer_striped_triton.py b/minference/dist_ops/minfer_striped_triton.py index e3b5eb2..77dda4a 100644 --- a/minference/dist_ops/minfer_striped_triton.py +++ b/minference/dist_ops/minfer_striped_triton.py @@ -11,7 +11,7 @@ from minference.ops.minference_attn_triton import block_bar_attn_fwd, block_bar_attn_bwd -def 
sparse_stripe_flash_attn_triton_forward( +def minfer_stripe_triton_forward( process_group: dist.ProcessGroup, q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] @@ -50,7 +50,7 @@ def sparse_stripe_flash_attn_triton_forward( return out, lse -def sparse_stripe_flash_attn_triton_backward( +def minfer_stripe_triton_backward( process_group: dist.ProcessGroup, dout: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] @@ -114,7 +114,7 @@ def sparse_stripe_flash_attn_triton_backward( return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) -class SparseStripeFlashAttnTritonFunc(torch.autograd.Function): +class MInferStripeTritonFunc(torch.autograd.Function): @staticmethod def forward( ctx, @@ -142,7 +142,7 @@ def forward( v = shuffle_striped_input(to_send=v, dim=1, granularity=granularity, process_group=group) # slash attn - out, softmax_lse = sparse_stripe_flash_attn_triton_forward( + out, softmax_lse = minfer_stripe_triton_forward( group, q, k, v, layer_idx, softmax_scale, block_idx, block_cnt, bar_idx, bar_cnt, @@ -168,7 +168,7 @@ def backward(ctx, dout, *args): dout = shuffle_striped_input(to_send=dout, dim=1, granularity=ctx.granularity, process_group=ctx.group) q, k, v, out, softmax_lse, block_idx, block_cnt, bar_idx, bar_cnt = ctx.saved_tensors - dq, dk, dv = sparse_stripe_flash_attn_triton_backward( + dq, dk, dv = minfer_stripe_triton_backward( ctx.group, dout, q, k, v, out, softmax_lse, layer_idx, ctx.softmax_scale, block_idx, block_cnt, bar_idx, bar_cnt, @@ -182,7 +182,7 @@ def backward(ctx, dout, *args): return dq, dk, dv, None, None, None, None, None, None, None -def sparse_stripe_flash_attn_triton_qkvpacked_func( +def minfer_stripe_triton_qkvpacked_func( qkv: torch.Tensor, # [batch_size, num_tokens, 3, num_heads, head_dim] v_size: List[int], # [num_heads] s_size: List[int], # [num_heads] @@ -201,7 +201,7 @@ def sparse_stripe_flash_attn_triton_qkvpacked_func( assert window_size == (-1, -1) assert alibi_slopes is None assert not deterministic - return SparseStripeFlashAttnTritonFunc.apply( + return MInferStripeTritonFunc.apply( qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], @@ -214,7 +214,7 @@ def sparse_stripe_flash_attn_triton_qkvpacked_func( ) -def sparse_stripe_flash_attn_triton_kvpacked_func( +def minfer_stripe_triton_kvpacked_func( q: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] kv: torch.Tensor, # [batch_size, num_tokens, 2, num_heads, head_dim] v_size: List[int], # [num_heads] @@ -234,7 +234,7 @@ def sparse_stripe_flash_attn_triton_kvpacked_func( assert window_size == (-1, -1) assert alibi_slopes is None assert not deterministic - return SparseStripeFlashAttnTritonFunc.apply( + return MInferStripeTritonFunc.apply( q, kv[:, :, 0], kv[:, :, 1], @@ -247,7 +247,7 @@ def sparse_stripe_flash_attn_triton_kvpacked_func( ) -def sparse_stripe_flash_attn_triton_func( +def minfer_stripe_triton_func( q: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] k: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] v: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] @@ -270,7 +270,7 @@ def sparse_stripe_flash_attn_triton_func( assert alibi_slopes is None assert not deterministic - return SparseStripeFlashAttnTritonFunc.apply( + return MInferStripeTritonFunc.apply( q, k, v, diff --git a/minference/dist_ops/moba_zigzag.py b/minference/dist_ops/moba_zigzag.py index e8834b0..c60b981 
100644 --- a/minference/dist_ops/moba_zigzag.py +++ b/minference/dist_ops/moba_zigzag.py @@ -985,8 +985,13 @@ def backward(ctx, dout, *args): return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None -def zigzag_ring_flash_attn_qkvpacked_func( +def moba_zigzag_qkvpacked_func( qkv, + seq_offset: torch.Tensor, + layer_idx: int, + cu_seqlens, + moba_chunk_size, + moba_topk, dropout_p=0.0, softmax_scale=None, causal=False, @@ -1000,8 +1005,13 @@ def zigzag_ring_flash_attn_qkvpacked_func( qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], + seq_offset, + layer_idx, dropout_p, softmax_scale, + cu_seqlens, + moba_chunk_size, + moba_topk, causal, window_size, alibi_slopes, @@ -1011,9 +1021,14 @@ def zigzag_ring_flash_attn_qkvpacked_func( ) -def zigzag_ring_flash_attn_kvpacked_func( +def moba_zigzag_kvpacked_func( q, kv, + seq_offset: torch.Tensor, + layer_idx: int, + cu_seqlens, + moba_chunk_size, + moba_topk, dropout_p=0.0, softmax_scale=None, causal=False, @@ -1027,8 +1042,13 @@ def zigzag_ring_flash_attn_kvpacked_func( q, kv[:, :, 0], kv[:, :, 1], + seq_offset, + layer_idx, dropout_p, softmax_scale, + cu_seqlens, + moba_chunk_size, + moba_topk, causal, window_size, alibi_slopes, @@ -1038,10 +1058,13 @@ def zigzag_ring_flash_attn_kvpacked_func( ) -def zigzag_ring_flash_attn_func( - q, - k, - v, +def moba_zigzag_func( + q, k, v, + seq_offset: torch.Tensor, + layer_idx: int, + cu_seqlens, + moba_chunk_size, + moba_topk, dropout_p=0.0, softmax_scale=None, causal=False, @@ -1052,11 +1075,14 @@ def zigzag_ring_flash_attn_func( group=None, ): return MoBAZigzagRingFlashAttnFunc.apply( - q, - k, - v, + q, k, v, + seq_offset, + layer_idx, dropout_p, softmax_scale, + cu_seqlens, + moba_chunk_size, + moba_topk, causal, window_size, alibi_slopes, diff --git a/minference/dist_ops/op_utils/moba_utils.py b/minference/dist_ops/op_utils/moba_utils.py index 7425646..8c981f1 100644 --- a/minference/dist_ops/op_utils/moba_utils.py +++ b/minference/dist_ops/op_utils/moba_utils.py @@ -11,184 +11,15 @@ import pandas as pd import torch.nn.functional as F import torch.distributed as dist +from dataclasses import dataclass from functools import reduce, cache, lru_cache from typing import Optional, Tuple, List, Dict -@cache -def _get_default_args(func): - spec = inspect.getfullargspec(func) - defaults = spec.defaults if spec.defaults is not None else () - padded_defaults = (None,) * (len(spec.args) - len(defaults)) + defaults - args = dict(zip(spec.args, padded_defaults)) - if "softcap" in args: - args["softcap"] = 0.0 - return args - - -def get_default_args(func): - if inspect.isfunction(func): - return _get_default_args(func) - else: - # Use the origin _init_fn in CustomOpDef - return _get_default_args(func._init_fn) - - -# copy from megatron/core/utils.py -class GlobalMemoryBuffer: - """Global buffer to avoid dynamic memory allocations. 
- Caller should ensure that buffers of the same name - are not used concurrently.""" - - def __init__(self): - self.buffer = {} - - def get_tensor(self, tensor_shape, dtype, name): - required_len = reduce(operator.mul, tensor_shape, 1) - if ( - self.buffer.get((name, dtype), None) is None - or self.buffer[(name, dtype)].numel() < required_len - ): - self.buffer[(name, dtype)] = torch.empty( - required_len, dtype=dtype, device=torch.cuda.current_device(), requires_grad=False - ) - - return self.buffer[(name, dtype)][0:required_len].view(*tensor_shape) - -def check_nan_inf( - out, lse, block_out, block_lse, phase_prefix: str, postfix: str -): - if (not torch.isnan(block_out).any()) and (not torch.isnan(block_lse).any()): - if torch.isnan(out).any(): - print(f"{phase_prefix}nan in out ({postfix})") - if torch.isinf(out).any(): - print(f"{phase_prefix}inf in out ({postfix})") - - if torch.isnan(lse).any(): - print(f"{phase_prefix}nan in lse ({postfix})") - if torch.isinf(lse).any(): - print(f"{phase_prefix}inf in lse ({postfix})") - -@torch.jit.script -def _update_out_and_lse( - out: torch.Tensor, - lse: torch.Tensor, - block_out: torch.Tensor, - block_lse: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - - block_out = block_out.to(torch.float32) - block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) - - # new_lse = lse + torch.log(1 + torch.exp(block_lse - lse)) - # torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out - # For additional context and discussion, please refer to: - # https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795 - out = out - F.sigmoid(block_lse - lse) * (out - block_out) - lse = lse - F.logsigmoid(lse - block_lse) - - return out, lse - - -def update_out_and_lse( - out: Optional[torch.Tensor], - lse: Optional[torch.Tensor], - block_out: torch.Tensor, - block_lse: torch.Tensor, - slice_=None, -) -> Tuple[torch.Tensor, torch.Tensor]: - if out is None: - if slice_ is not None: - raise RuntimeError("first update_out_and_lse should not pass slice_ args") - out = block_out.to(torch.float32) - lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) - elif slice_ is not None: - slice_out, slice_lse = out[slice_], lse[slice_] - - slice_out, slice_lse = _update_out_and_lse( - slice_out, slice_lse, block_out, block_lse - ) - out[slice_], lse[slice_] = slice_out, slice_lse - else: - out, lse = _update_out_and_lse(out, lse, block_out, block_lse) - return out, lse - -class RingComm: - def __init__(self, process_group: dist.ProcessGroup): - self._process_group = process_group - self._ops = [] - self.rank = dist.get_rank(self._process_group) - self.world_size = dist.get_world_size(self._process_group) - self._reqs = None - - parts = self.world_size // 2 - self.ring_list = [] - for i in range(parts): - self.ring_list.extend([i, self.world_size - i - 1]) - - self.revert_rank = self.ring_list.index(self.rank) - - offset = ((dist.get_rank() // self.world_size) * self.world_size) - self.send_rank = self.ring_list[(self.revert_rank + 1) % self.world_size] + offset - self.recv_rank = self.ring_list[(self.revert_rank - 1) % self.world_size] + offset - - def send_recv( - self, to_send: torch.Tensor, recv_tensor: Optional[torch.Tensor] = None - ) -> torch.Tensor: - if recv_tensor is None: - res = torch.empty_like(to_send) - else: - res = recv_tensor - - send_op = dist.P2POp( - dist.isend, to_send, self.send_rank, group=self._process_group - ) - recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group) - 
self._ops.append(send_op) - self._ops.append(recv_op) - return res - - def commit(self): - if self._reqs is not None: - raise RuntimeError("commit called twice") - self._reqs = dist.batch_isend_irecv(self._ops) - - def wait(self): - if self._reqs is None: - raise RuntimeError("wait called before commit") - - for req in self._reqs: - req.wait() - - self._reqs = None - self._ops = [] - - def send_recv_kv( - self, - k: torch.Tensor, - v: torch.Tensor, - k_buffer: Optional[torch.Tensor] = None, - v_buffer: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - next_k, next_v = self.send_recv(k, k_buffer), self.send_recv(v, v_buffer) - self.commit() - return next_k, next_v - - def send_recv_kv_offsets( - self, - k: torch.Tensor, - v: torch.Tensor, - kv_seq_offsets: torch.Tensor, - k_buffer: Optional[torch.Tensor] = None, - v_buffer: Optional[torch.Tensor] = None, - kv_seq_offsets_buffer: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - next_k, next_v = self.send_recv(k, k_buffer), self.send_recv(v, v_buffer) - next_kv_seq_offsets = self.send_recv(kv_seq_offsets, kv_seq_offsets_buffer) - - self.commit() - return next_k, next_v, next_kv_seq_offsets - +@dataclass +class MoBAConfig: + moba_chunk_size: int + moba_topk: int def shuffle_input_all( to_send: torch.Tensor, # [S, H, D] @@ -349,110 +180,6 @@ def shuffle_input_only( to_send_f[:, block_seq_len:, ...] = res return to_send_f if orig_ndim != 3 else to_send_f.squeeze(0) - -def recover_output( - to_send: torch.Tensor, # [S, H, D] - process_group: dist.ProcessGroup = None - ): - orig_ndim = to_send.ndim - if orig_ndim == 3: to_send = to_send.unsqueeze(0) - - if not to_send.is_contiguous(): - to_send = to_send.contiguous() - - to_send_f = torch.zeros_like(to_send) - - block_seq_len = to_send.shape[1] // 2 - - rank = dist.get_rank(process_group) - world_size = dist.get_world_size(process_group) - - if rank >= world_size // 2: - to_send_slice = to_send[:, :block_seq_len, ...].contiguous() - else: - to_send_slice = to_send[:, block_seq_len:, ...].contiguous() - res = torch.zeros_like(to_send_slice) - - assert to_send_slice.is_contiguous() - assert res.is_contiguous() - - _ops = [] - offset = ((dist.get_rank() // world_size) * world_size) - src_rank = (world_size - rank - 1) % world_size + offset - send_op = dist.P2POp( - dist.isend, to_send_slice, src_rank, group=process_group - ) - recv_op = dist.P2POp( - dist.irecv, res, src_rank, group=process_group) - - _ops.append(send_op) - _ops.append(recv_op) - - response = dist.batch_isend_irecv(_ops) - for resp in response: - resp.wait() - - if rank >= world_size // 2: - to_send_f[:, :block_seq_len] = to_send[:, block_seq_len:, ...] - to_send_f[:, block_seq_len:] = res - else: - to_send_f[:, :block_seq_len] = to_send[:, :block_seq_len, ...] 
- to_send_f[:, block_seq_len:] = res - - return to_send_f.contiguous() if orig_ndim != 3 else to_send_f.squeeze(0).contiguous() - - - -def recover_lse( - to_send_lse: torch.Tensor, # [H, S] - process_group: dist.ProcessGroup = None - ): - - if not to_send_lse.is_contiguous(): - to_send_lse = to_send_lse.contiguous() - - to_send_f = torch.zeros_like(to_send_lse) - - block_seq_len = to_send_lse.shape[1] // 2 - - rank = dist.get_rank(process_group) - world_size = dist.get_world_size(process_group) - - if rank >= world_size // 2: - to_send_slice = to_send_lse[:, :block_seq_len].contiguous() - else: - to_send_slice = to_send_lse[:, block_seq_len:].contiguous() - res = torch.zeros_like(to_send_slice) - - assert to_send_slice.is_contiguous() - assert res.is_contiguous() - - _ops = [] - offset = ((dist.get_rank() // world_size) * world_size) - src_rank = (world_size - rank - 1) % world_size + offset - send_op = dist.P2POp( - dist.isend, to_send_slice, src_rank, group=process_group - ) - recv_op = dist.P2POp( - dist.irecv, res, src_rank, group=process_group) - - _ops.append(send_op) - _ops.append(recv_op) - - response = dist.batch_isend_irecv(_ops) - for resp in response: - resp.wait() - - if rank >= world_size // 2: - to_send_f[:, :block_seq_len] = to_send_lse[:, block_seq_len:] - to_send_f[:, block_seq_len:] = res - else: - to_send_f[:, :block_seq_len] = to_send_lse[:, :block_seq_len] - to_send_f[:, block_seq_len:] = res - - return to_send_f.contiguous() - - @lru_cache(maxsize=16) def calc_chunks(cu_seqlen, moba_chunk_size): """calc chunks that needs moba attention""" diff --git a/minference/dist_ops/xattn_zigzag.py b/minference/dist_ops/xattn_zigzag.py index 7fe7929..b92d011 100644 --- a/minference/dist_ops/xattn_zigzag.py +++ b/minference/dist_ops/xattn_zigzag.py @@ -161,55 +161,6 @@ def xattn_zigzag_estimate( simple_masks = torch.cat(simple_mask_list, dim=-2) # (batch_size, head_num, q_local_block_num, k_global_block_num) return attn_sums, simple_masks - -def compute_sr_flops( - block_mask_offset: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, num_blocks] - step: int, - granularity: int, - q_len: int, - head_dim: int, - fwd: bool=True, -): - num_blocks = triton.cdiv(q_len, granularity) - bh = block_mask_offset.shape[0] * block_mask_offset.shape[1] - - total_num_blocks = bh * num_blocks * num_blocks / 2 - - num_active_blocks = block_mask_offset.sum(dtype=torch.float32).item() - if step == 0: - num_active_blocks -= bh * num_blocks / 2 - - block_ratio = num_active_blocks / total_num_blocks - sparsity_ratio = 1 - block_ratio - - block_flops = num_active_blocks * (granularity * granularity) * head_dim * 2 * 2 - - if not fwd: block_flops *= 2.5 - return sparsity_ratio, block_flops - - -def compute_sr_by_heads( - block_mask_offset: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, num_blocks] - step: int, - granularity: int, - q_len: int, -): - batch_size, num_heads = block_mask_offset.shape[0], block_mask_offset.shape[1] - num_blocks = triton.cdiv(q_len, granularity) - - total_num_blocks = batch_size * num_blocks * num_blocks / 2 - total_num_blocks_by_heads = torch.tensor([total_num_blocks for _ in range(num_heads)], dtype=torch.float32).to(block_mask_offset.device) - - - num_active_blocks = block_mask_offset.sum(-1).sum(-1).sum(0, dtype=torch.float32) # [num_qo_heads] - if step == 0: - num_active_blocks -= batch_size * num_blocks / 2 - - block_ratio_by_heads = num_active_blocks / total_num_blocks_by_heads - sparsity_ratio_by_heads = 1 - block_ratio_by_heads - - return 
sparsity_ratio_by_heads.detach().cpu().numpy().tolist() - def use_triton(): return torch.version.hip is not None or os.getenv("FORCE_TRITON", "0") == "1" diff --git a/minference/ops/moba.py b/minference/ops/moba.py new file mode 100644 index 0000000..7743eeb --- /dev/null +++ b/minference/ops/moba.py @@ -0,0 +1,593 @@ +"""A clean version of moba implementation for educational purposes""" +import math +import torch + +from einops import rearrange +from typing import Union, Tuple, Callable, Optional +from flash_attn import flash_attn_varlen_func, flash_attn_func +from flash_attn.flash_attn_interface import ( + _flash_attn_varlen_forward, + _flash_attn_varlen_backward, +) + +from .utils import calc_chunks + +def hf_to_fa(x: torch.Tensor): + """ + Args: + x (torch.Tensor): [batch, heads, seqlen, head_dim] + + Returns: + torch.Tensor: [batch * seqlen, heads, head_dim] + """ + return x.permute(0, 2, 1, 3).reshape(-1, x.shape[1], x.shape[3]) + + +def moba_attn_varlen_naive( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + moba_chunk_size: int, + moba_topk: int, +) -> torch.Tensor: + """Implement the moba brute-force setting for reference + + Args: + q (torch.Tensor): [seqlen, head, head_dim] + k (torch.Tensor): [seqlen, head, head_dim] + v (torch.Tensor): [seqlen, head, head_dim] + cu_seqlens (torch.Tensor): the cumulative sequence length tensor, same definition in flash attn + max_seqlen (int): the max sequence length of the batch, same definition in flash attn + + Returns: + attn_output (torch.Tensor): [seqlen, head, head_dim] + """ + + # qkv shape = [ S, H, D ] + batch = cu_seqlens.numel() - 1 + softmax_scale = q.shape[-1] ** (-0.5) + + o = torch.zeros_like(q) + for batch_idx in range(batch): + batch_start = cu_seqlens[batch_idx].item() + batch_end = cu_seqlens[batch_idx + 1].item() + # get qkv of this batch + q_ = q[batch_start:batch_end] + k_ = k[batch_start:batch_end] + v_ = v[batch_start:batch_end] + o_ = o[batch_start:batch_end] + # calc key gate weight + key_gate_weight = [] + batch_size = batch_end - batch_start + num_block = math.ceil(batch_size / moba_chunk_size) + for block_idx in range(0, num_block): + block_start = block_idx * moba_chunk_size + block_end = min(batch_size, block_start + moba_chunk_size) + key_gate_weight.append(k_[block_start:block_end].mean(dim=0, keepdim=True)) + key_gate_weight = torch.cat(key_gate_weight, dim=0) # [ N, H, D ] + # calc & mask gate + # use fp32 to avoid precision issue in bf16 + q_ = q_.type(torch.float32) + key_gate_weight = key_gate_weight.type(torch.float32) + gate = torch.einsum("shd,nhd->hsn", q_, key_gate_weight) # [ H, S, N ] + key_gate_weight = key_gate_weight.type_as(k) + q_ = q_.type_as(k) + for i in range(num_block): + # select the future Qs that can attend to KV chunk i + gate[:, : (i + 1) * moba_chunk_size, i] = float("-inf") + gate[:, i * moba_chunk_size : (i + 1) * moba_chunk_size, i] = float("inf") + # gate_top_k_idx = gate_top_k_val = [ H S K ] + gate_top_k_val, gate_top_k_idx = torch.topk( + gate, k=min(moba_topk, num_block), dim=-1, largest=True, sorted=False + ) + gate_top_k_val, _ = gate_top_k_val.min(dim=-1) # [ H, S ] + need_attend = gate >= gate_top_k_val.unsqueeze(-1) + # add gate_idx_mask in case of there is cornercases of same topk val been selected + gate_idx_mask = torch.zeros( + need_attend.shape, dtype=torch.bool, device=q.device + ) + gate_idx_mask = gate_idx_mask.scatter_(dim=-1, index=gate_top_k_idx, value=True) + need_attend = torch.logical_and(need_attend, 
gate_idx_mask) + gate[need_attend] = 0 + gate[~need_attend] = -float("inf") + gate = gate.repeat_interleave(moba_chunk_size, dim=-1)[ + :, :, :batch_size + ] # [ H, S, S ] + gate.masked_fill_( + torch.ones_like(gate, dtype=torch.bool).tril().logical_not(), -float("inf") + ) + # print(f"moba_naive | gate ({gate.shape}): {gate}") + + # calc qk = qk^t + q_ = q_.type(torch.float32) + k_ = k_.type(torch.float32) + v_ = v_.type(torch.float32) + qk = torch.einsum("xhd,yhd->hxy", q_, k_) + # mask + qk += gate + qk *= softmax_scale + # calc o + p = qk.softmax(dim=-1) + o_ += torch.einsum("hxy,yhd->xhd", p, v_) + o = o.type_as(q) + + return o + + + + +class MixedAttention(torch.autograd.Function): + + @staticmethod + def forward( + ctx, + q, + k, + v, + self_attn_cu_seqlen, + moba_q, + moba_kv, + moba_cu_seqlen_q, + moba_cu_seqlen_kv, + max_seqlen, + moba_chunk_size, + moba_q_sh_indices, + return_lse, + ): + ctx.max_seqlen = max_seqlen + ctx.moba_chunk_size = moba_chunk_size + ctx.softmax_scale = softmax_scale = q.shape[-1] ** (-0.5) + + # self attn + _, _, _, _, self_attn_out_sh, self_attn_lse_hs, _, _ = ( + _flash_attn_varlen_forward( + q=q, + k=k, + v=v, + cu_seqlens_q=self_attn_cu_seqlen, + cu_seqlens_k=self_attn_cu_seqlen, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + softmax_scale=softmax_scale, + causal=True, + dropout_p=0.0, + ) + ) + + _, _, _, _, moba_attn_out, moba_attn_lse_hs, _, _ = _flash_attn_varlen_forward( + q=moba_q, + k=moba_kv[:, 0], + v=moba_kv[:, 1], + cu_seqlens_q=moba_cu_seqlen_q, + cu_seqlens_k=moba_cu_seqlen_kv, + max_seqlen_q=max_seqlen, + max_seqlen_k=moba_chunk_size, + softmax_scale=softmax_scale, + causal=False, + dropout_p=0.0, + ) + + # convert lse shape hs -> sh ( follow the legacy mix attn logic ) + self_attn_lse_sh = self_attn_lse_hs.t().contiguous() + moba_attn_lse = moba_attn_lse_hs.t().contiguous() + + # output buffer [S, H, D], same shape as q + output = torch.zeros( + (q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32 + ) + + # flatten vS & H for index ops + output_2d = output.view(-1, q.shape[2]) + + # calc mixed_lse + # minus max lse to avoid exp explosion + max_lse_1d = self_attn_lse_sh.view(-1) + max_lse_1d = max_lse_1d.index_reduce( + 0, moba_q_sh_indices, moba_attn_lse.view(-1), "amax" + ) + self_attn_lse_sh = self_attn_lse_sh - max_lse_1d.view_as(self_attn_lse_sh) + moba_attn_lse = ( + moba_attn_lse.view(-1) + .sub(max_lse_1d.index_select(0, moba_q_sh_indices)) + .reshape_as(moba_attn_lse) + ) + + mixed_attn_se_sh = self_attn_lse_sh.exp() + moba_attn_se = moba_attn_lse.exp() + + mixed_attn_se_sh.view(-1).index_add_( + 0, moba_q_sh_indices, moba_attn_se.view(-1) + ) + mixed_attn_lse_sh = mixed_attn_se_sh.log() + + # add attn output + factor = (self_attn_lse_sh - mixed_attn_lse_sh).exp() # [ vS, H ] + self_attn_out_sh = self_attn_out_sh * factor.unsqueeze(-1) + output_2d += self_attn_out_sh.reshape_as(output_2d) + + # add moba output + mixed_attn_lse = ( + mixed_attn_lse_sh.view(-1) + .index_select(0, moba_q_sh_indices) + .view_as(moba_attn_lse) + ) + factor = (moba_attn_lse - mixed_attn_lse).exp() # [ vS, H ] + moba_attn_out = moba_attn_out * factor.unsqueeze(-1) + raw_attn_out = moba_attn_out.view(-1, moba_attn_out.shape[-1]) + output_2d.index_add_(0, moba_q_sh_indices, raw_attn_out) + output = output.to(q.dtype) + + + # add back max lse + mixed_attn_lse_sh = mixed_attn_lse_sh + max_lse_1d.view_as(mixed_attn_se_sh) + + + ctx.save_for_backward( + output, + mixed_attn_lse_sh, + q, + k, + v, + self_attn_cu_seqlen, + 
moba_q, + moba_kv, + moba_cu_seqlen_q, + moba_cu_seqlen_kv, + moba_q_sh_indices, + ) + ctx.return_lse = return_lse + + if return_lse: + return output, mixed_attn_lse_sh + else: + return output + + @staticmethod + def backward(ctx, d_output, *args): + + max_seqlen = ctx.max_seqlen + moba_chunk_size = ctx.moba_chunk_size + softmax_scale = ctx.softmax_scale + + ( + output, + mixed_attn_vlse_sh, + q, + k, + v, + self_attn_cu_seqlen, + moba_q, + moba_kv, + moba_cu_seqlen_q, + moba_cu_seqlen_kv, + moba_q_sh_indices, + ) = ctx.saved_tensors + + d_output = d_output.contiguous() + + dq, dk, dv, _ = _flash_attn_varlen_backward( + dout=d_output, + q=q, + k=k, + v=v, + out=output, + softmax_lse=mixed_attn_vlse_sh.t().contiguous(), + dq=None, + dk=None, + dv=None, + cu_seqlens_q=self_attn_cu_seqlen, + cu_seqlens_k=self_attn_cu_seqlen, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + softmax_scale=softmax_scale, + causal=True, + dropout_p=0.0, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=True, + ) + + headdim = q.shape[-1] + d_moba_output = ( + d_output.view(-1, headdim).index_select(0, moba_q_sh_indices).unsqueeze(1) + ) + moba_output = ( + output.view(-1, headdim).index_select(0, moba_q_sh_indices).unsqueeze(1) + ) + + mixed_attn_vlse = ( + mixed_attn_vlse_sh.view(-1).index_select(0, moba_q_sh_indices).view(1, -1) + ) + + dmq, dmk, dmv, _ = _flash_attn_varlen_backward( + dout=d_moba_output, + q=moba_q, + k=moba_kv[:, 0], + v=moba_kv[:, 1], + out=moba_output, + softmax_lse=mixed_attn_vlse, + dq=None, + dk=None, + dv=None, + cu_seqlens_q=moba_cu_seqlen_q, + cu_seqlens_k=moba_cu_seqlen_kv, + max_seqlen_q=max_seqlen, + max_seqlen_k=moba_chunk_size, + softmax_scale=softmax_scale, + causal=False, + dropout_p=0.0, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=True, + ) + + dmkv = torch.stack((dmk, dmv), dim=1) + return dq, dk, dv, None, dmq, dmkv, None, None, None, None, None, None + + +def moba_attn_varlen( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + moba_chunk_size: int, + moba_topk: int, + return_lse: bool=False, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """An efficient version of moba implementation with triton kernels and flash-attn, the core logic: + 1. Calculate the chunks and the number of chunks, n = floor(data_size / chunk_size) + - tokens in the tail chunk are reserved for self attn + - tokens in other chunks will be processed in later steps + 2. K in each chunk will calculate mean value as the representative k, and Q will attend to these representative + k to get the gate logit, which will be used to select topk chunks + 3. Select the topk chunks and get the dense q for each kv chunk pair and do the varlen attention + 4. 
Combine the varlen attn and self attn results via online softmax to get the final result + + Args: + q (torch.Tensor): [seqlen, head, head_dim] + k (torch.Tensor): [seqlen, head, head_dim] + v (torch.Tensor): [seqlen, head, head_dim] + cu_seqlens (torch.Tensor): the cumulative sequence length tensor, same definition in flash attn + max_seqlen (int): the max sequence length of the batch, same definition in flash attn + + Returns: + attn_output (torch.Tensor): [seqlen, head, head_dim] + """ + print(f"moba_attn_varlen | cu_seqlens: {cu_seqlens}, max_seqlen: {max_seqlen}, moba_chunk_size: {moba_chunk_size}, moba_topk: {moba_topk}, return_lse: {return_lse}") + # --------------------------------------------------------------------------------------------- + kv = torch.stack((k, v), dim=1) # stack along a new dimension -> [S, 2, H, D] + + """ some basic variables """ + # qkv shape = [ S, H, D ] + seqlen, num_head, head_dim = q.shape + + """ prepare chunk meta """ + ( + cu_chunk, + filtered_chunk_indices, + num_filtered_chunk, + filtered_chunk_indices, + chunk_to_batch, + ) = calc_chunks(cu_seqlens, moba_chunk_size) + + # we will adjust selective topk to moba_topk - 1, as the last chunk is always chosen + moba_topk = min(moba_topk - 1, num_filtered_chunk) + need_moba_attn = moba_topk > 0 + + # corner case: if no moba attn needed, just return self attn + if not need_moba_attn: + return flash_attn_varlen_func( + q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, causal=True + ) + + self_attn_cu_seqlen = cu_chunk + + # filtered_kv is a dense matrix that only contains filtered chunk of kv + filtered_kv_indices = torch.arange( + 0, moba_chunk_size, dtype=torch.int32, device=q.device + )[None, :].repeat(num_filtered_chunk, 1) + filtered_kv_indices += cu_chunk[filtered_chunk_indices][:, None] + + # select the elements of KV corresponding to all chunks that are not filtered out + filtered_kv = kv.index_select(0, filtered_kv_indices.view(-1)) + + """ calc key_gate_weight and gate """ + # key_gate_weight [ F_N_CHUNK, HEAD, HEAD_DIM ] + key_gate_weight = ( + filtered_kv[:, 0] # K + .view(num_filtered_chunk, moba_chunk_size, num_head, head_dim) + .mean(dim=1) # mean pooling along chunk size + .float() + ) + q = q.type(torch.float32) # float logit on the fly for better gate logit perception + key_gate_weight = key_gate_weight.type( + torch.float32 + ) # float logit for better gate logit perception + gate = torch.einsum( + "nhd,shd->nhs", key_gate_weight, q + ) # gate [ F_N_CHUNK, HEAD, SEQ ] + key_gate_weight = key_gate_weight.type_as(k) + q = q.type_as(k) + + # pose process gate, masking unchosen batch and apply causal mask to current chunk + gate_seq_idx = torch.arange(0, seqlen, device=q.device, dtype=torch.int32)[ + None, : + ].repeat(num_filtered_chunk, 1) + chunk_end = cu_chunk[filtered_chunk_indices + 1] + batch_end = cu_seqlens[chunk_to_batch[filtered_chunk_indices] + 1] + gate_chunk_end_mask = gate_seq_idx < chunk_end[:, None] + gate_batch_end_mask = gate_seq_idx >= batch_end[:, None] + gate_inf_mask = gate_chunk_end_mask | gate_batch_end_mask + gate.masked_fill_(gate_inf_mask.unsqueeze(1), -float("inf")) + + """ find moba q that needs moba attn """ + # find topk chunks + # gate_mask [ N_CHUNK, HEAD, SEQ ], true indicates that needs attention + _, gate_top_k_idx = torch.topk(gate, k=moba_topk, dim=0, largest=True, sorted=False) + + # apply causal mask + gate_mask = torch.logical_not(gate.isinf()) + + # select topk chunks + gate_idx_mask = torch.zeros(gate_mask.shape, dtype=torch.bool, 
device=q.device) + gate_idx_mask = gate_idx_mask.scatter_(dim=0, index=gate_top_k_idx, value=True) + + # gate_mask has the shape [ N_CHUNK, HEAD, SEQ ]. + # For each chunk, the sequence-dimension indices will be True if it belongs to the top-K chunks + gate_mask = torch.logical_and(gate_mask, gate_idx_mask) + # --------------------------------------------------------------------------------------------- + + + # varlen trick: combining all q index that needs moba attn + # the result will be like [ C0H0 ][ C0H1 ][ C0H2 ][ ... ][ CnHm ] + # torch.nonzero (as_tuple=True): Returns a tuple of 1-D tensors, one for each dimension in input, each containing the indices (in that dimension) of all non-zero elements of input . + # if input has n-dimension, the resulting tuple will have n tensors of size z, where z is the number of non-zero elements in input. + # (i-th values of all n tuple elements represent the indices of the i-th non-zero element in each dimension) + # using index [-1] => indices of HS (combined sequence) dimension that contains non-zero elements + moba_q_indices = gate_mask.reshape(gate_mask.shape[0], -1) # [ N, HS ] + + moba_q_indices = moba_q_indices.nonzero(as_tuple=True)[-1] # [HS indices] * N (total size: all non-zero elements in HS dimension) + + + # moba_seqlen_q indicates that how many q chunks are selected for each kv chunk - head + moba_seqlen_q = gate_mask.sum(dim=-1).flatten() + + # select all q that needs moba attn based on the moba_q_indices + moba_q = rearrange(q, "s h d -> ( h s ) d").index_select( + 0, moba_q_indices + ) # [ selected_S, D ] + moba_q = moba_q.unsqueeze(1) + + # moba_q_sh_indices represents the position in the origin q tensor of each q token inside moba_q + moba_q_sh_indices = moba_q_indices % seqlen * num_head + moba_q_indices // seqlen + + """ prepare moba kv """ + # Since moba_q is organized as HS * N, we need to reorganize kv to adapt to q + + # cut off zero experts + q_zero_mask = moba_seqlen_q == 0 + valid_expert_mask = ~q_zero_mask + zero_expert_count = q_zero_mask.sum() + + # only keep the kv that has q select > 0 + if zero_expert_count > 0: + moba_seqlen_q = moba_seqlen_q[valid_expert_mask] + + + # moba cu_seqlen for flash attn + moba_cu_seqlen_q = torch.cat( + ( + torch.tensor([0], device=q.device, dtype=moba_seqlen_q.dtype), + moba_seqlen_q.cumsum(dim=0), + ), + dim=0, + ).to(torch.int32) + + # ----------------------------------------------- + print(f"filtered_kv shape: {filtered_kv.shape}") + moba_kv = rearrange(filtered_kv, "s x h d -> h s x d") # here `x` only stands for a dimension (stack dimension for KV) + + moba_kv = moba_kv.split(moba_chunk_size, dim=1) + moba_kv = torch.cat(moba_kv, dim=0) # [num_selected_chunks, H x S // moba_chunk_size, D] + + if zero_expert_count > 0: + assert valid_expert_mask.sum() == moba_kv.shape[0] - zero_expert_count + moba_kv = moba_kv[ + valid_expert_mask + ] # cut off zero Q expert from kv , or the grad may be nan + + moba_kv = moba_kv.flatten(start_dim=0, end_dim=1).unsqueeze(2) + + moba_cu_seqlen_kv = ( + torch.arange( + 0, + num_filtered_chunk * num_head + 1 - zero_expert_count, + dtype=torch.int32, + device=q.device, + ) + * moba_chunk_size + ) + + # Shape check + assert ( + moba_cu_seqlen_kv.shape == moba_cu_seqlen_q.shape + ), f"moba_cu_seqlen_kv.shape != moba_cu_seqlen_q.shape {moba_cu_seqlen_kv.shape} != {moba_cu_seqlen_q.shape}" + + # Wrapping up the flash attn call and online softmax dlse inside MixedAttention class + return MixedAttention.apply( + q, k, v, + self_attn_cu_seqlen, + moba_q, + 
moba_kv, + moba_cu_seqlen_q, + moba_cu_seqlen_kv, + max_seqlen, + moba_chunk_size, + moba_q_sh_indices, + return_lse + ) + + + + +def moba_layer( + moba_impl: Callable, + moba_chunk_size: int, + moba_topk: int, + module: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + *args, + dropout: float = 0.0, + scaling: Optional[float] = None, + **kwargs, +) -> Tuple[torch.Tensor, None]: + """ + Args: + query (torch.Tensor): [batch, q_heads, q_len, head_dim] + key (torch.Tensor): [batch, kv_heads, kv_len, head_dim] + value (torch.Tensor): [batch, kv_heads, kv_len, head_dim] + + Returns: + attn_output (torch.Tensor): [batch, q_len, q_heads, head_dim] + attn_weights (None): not needed + """ + assert module.is_causal + batch, q_heads, q_len, head_dim = query.shape + _, kv_heads, kv_len, _ = key.shape + if q_len == kv_len: + # prefill phase + query = hf_to_fa(query) + key = hf_to_fa(key) + value = hf_to_fa(value) + kv_replicas = q_heads // kv_heads + key = torch.repeat_interleave(key, kv_replicas, dim=1) + value = torch.repeat_interleave(value, kv_replicas, dim=1) + cu_seqlens_k = torch.cumsum( + torch.tensor([0] + [kv_len] * batch, device=query.device), + dim=0, + dtype=torch.int32, + ) + out = moba_impl( + q=query, + k=key, + v=value, + cu_seqlens=cu_seqlens_k, + max_seqlen=kv_len, + moba_chunk_size=moba_chunk_size, + moba_topk=moba_topk, + ) + else: + # decode phase + # TODO release paged attn implementation + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + out = flash_attn_func(query, key, value, dropout, scaling, True) + return out, None diff --git a/minference/ops/utils.py b/minference/ops/utils.py index 8e64722..68503ef 100644 --- a/minference/ops/utils.py +++ b/minference/ops/utils.py @@ -1,6 +1,7 @@ import os import numpy as np from typing import List +from functools import lru_cache import torch import torch.distributed as dist @@ -930,3 +931,64 @@ def get_compute_sparsity( compute_sparsity = 1 - block_ratio - bar_ratio return compute_sparsity + + + +@lru_cache(maxsize=16) +def calc_chunks(cu_seqlen, moba_chunk_size): + """calc chunks that needs moba attention""" + + # batch_sizes[batch_idx] = batch size ( seqlen ) of batch idx + batch_sizes = cu_seqlen[1:] - cu_seqlen[:-1] + + # batch_num_chunk[batch_idx] = how many chunk in batch idx + batch_num_chunk = (batch_sizes + (moba_chunk_size - 1)) // moba_chunk_size + + # cu_num_chunk[batch_idx] = first chunk id of this batch + cu_num_chunk = torch.ones( + batch_num_chunk.numel() + 1, + device=cu_seqlen.device, + dtype=batch_num_chunk.dtype, + ) + cu_num_chunk[1:] = batch_num_chunk.cumsum(dim=0) + + # total chunk ( for all batch ) + num_chunk = cu_num_chunk[-1] + + # chunk_sizes[chunk_idx] = chunk_size of chunk idx + chunk_sizes = torch.full( + (num_chunk + 1,), moba_chunk_size, dtype=torch.int32, device=cu_seqlen.device + ) + chunk_sizes[0] = 0 # for calc cu chunk + batch_last_chunk_size = batch_sizes - (batch_num_chunk - 1) * moba_chunk_size + chunk_sizes[cu_num_chunk[1:]] = batch_last_chunk_size + + # cu_chunk[chunk_idx] = the start chunk offset of chunk idx + cu_chunk = chunk_sizes.cumsum(dim=-1, dtype=torch.int32) + + # chunk_to_batch[chunk_idx] = batch idx of the chunk idx + chunk_to_batch = torch.zeros( + (num_chunk,), dtype=torch.int32, device=cu_seqlen.device + ) + chunk_to_batch[cu_num_chunk[1:-1]] = 1 + chunk_to_batch = chunk_to_batch.cumsum(dim=0, dtype=torch.int32) + + """ filter chunks that need moba attn """ + + # filter chunks ( remove last chunk of 
each batch ) + # filtered_chunk_indices: chunk index list that excludes the last chunk of each batch + chunk_to_remove = cu_num_chunk[1:] - 1 + chunk_to_remain = torch.ones( + (num_chunk, ), dtype=torch.bool, device=cu_seqlen.device + ) + chunk_to_remain[chunk_to_remove] = False + filtered_chunk_indices = chunk_to_remain.nonzero(as_tuple=True)[0] + num_filtered_chunk = len(filtered_chunk_indices) + + return ( + cu_chunk, + filtered_chunk_indices, + num_filtered_chunk, + filtered_chunk_indices, + chunk_to_batch, + ) \ No newline at end of file diff --git a/minference/ops/xattention_fa.py b/minference/ops/xattention_fa.py index 549d3fd..b24f829 100644 --- a/minference/ops/xattention_fa.py +++ b/minference/ops/xattention_fa.py @@ -5,7 +5,10 @@ import torch import triton import triton.language as tl +from typing import List, Tuple, Dict, Any +from minference.dist_ops.op_utils.xattn_utils import xattn_estimate +from minference.ops.minference_attn import block_attn_fwd, block_attn_bwd @triton.jit def softmax_fuse_block_sum_kernel_causal( @@ -366,3 +369,96 @@ def flat_group_gemm_fuse_reshape(query_states, key_states, stride, chunk_start, ) return output + +class XAttnFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + head_indices, + xattn_params, # Dict[str, Any] + granularity, + causal, + softmax_scale, + return_softmax, + deterministic, + ): + batch_size, num_tokens, num_qo_heads, head_dim = q.shape + if softmax_scale is None: + softmax_scale = head_dim ** (-0.5) + + q_block_num = (q.shape[1] + granularity - 1) // granularity + # (batch_size, head_num, q_block_num, q_block_num) + _, block_mask = xattn_estimate( + q.transpose(1, 2), k.transpose(1, 2), + granularity, + **xattn_params + ) + block_mask = block_mask[:, :, -q_block_num:, -q_block_num:].contiguous() + + # Block Mask + out, softmax_lse = block_attn_fwd( + q, k, v, softmax_scale, + block_mask, + granularity=granularity, + causal=causal, + ) + + ctx.save_for_backward(q, k, v, out, softmax_lse, block_mask) + ctx.granularity = granularity + ctx.deterministic = deterministic + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.head_indices = head_indices + + # print(f"{__name__} | out shape: {out.shape}") + return (out, softmax_lse, None) if return_softmax else out + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse, block_mask = ctx.saved_tensors + causal = ctx.causal + + # Block Mask + dq, dk, dv = block_attn_bwd( + dout, q, k, v, out, + softmax_lse, ctx.softmax_scale, + block_mask, + granularity=ctx.granularity, + deterministic=ctx.deterministic, + causal=causal, + ) + return dq, dk, dv, None, None, None, None, None, None, None + +def xattn_flash_attn_func( + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + head_indices: List[int], # [num_qo_heads] + xattn_params: Dict[str, Any], + granularity: int = 128, + dropout_p: int = 0.0, + softmax_scale: float = None, + causal: bool = True, + window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window + alibi_slopes: Tuple[float, float] = None, + deterministic: bool = False, + return_attn_probs: bool = False, +): + assert dropout_p == 0 + assert causal + assert window_size == (-1, -1) + assert alibi_slopes is None + + return XAttnFunc.apply( + q, k, v, + head_indices, + xattn_params, + granularity, + causal, 
+ softmax_scale, + return_attn_probs, + deterministic, + ) diff --git a/mtraining/attn_funcs/__init__.py b/mtraining/attn_funcs/__init__.py new file mode 100644 index 0000000..60ce388 --- /dev/null +++ b/mtraining/attn_funcs/__init__.py @@ -0,0 +1,36 @@ +from typing import Dict, Callable + +from .dense_func import fa_attn_forward, stripe_ring_attention_forward, zigzag_ring_attention_forward +from .minfer_func import minfer_attention_forward +from .moba_func import moba_attention_forward +from .xattn_func import xattn_attention_forward + + +class AttnType: + BASELINE: str = "baseline" + ZIGZAG_RING: str = "zigzag_ring" + STRIPE_RING: str = "stripe_ring" + + MINFER: str = "minfer" + MOBA: str = "moba" + XATTN: str = "xattn" + +ATTN_TO_FUNC = { + AttnType.BASELINE: fa_attn_forward, + AttnType.ZIGZAG_RING: zigzag_ring_attention_forward, + AttnType.STRIPE_RING: stripe_ring_attention_forward, + + AttnType.MINFER: minfer_attention_forward, + AttnType.MOBA: moba_attention_forward, + AttnType.XATTN: xattn_attention_forward, +} + +def overwrite_attn_implementation( + attn_dict: Dict[str, Callable], + attn_type: AttnType, +): + attn_func: Callable = ATTN_TO_FUNC[attn_type] + print(f"Overwriting attention implementation to {attn_type} ({attn_func.__name__})") + + for attn_name in attn_dict: + attn_dict[attn_name] = attn_func \ No newline at end of file diff --git a/mtraining/attn_funcs/dense_func.py b/mtraining/attn_funcs/dense_func.py new file mode 100644 index 0000000..9c8df79 --- /dev/null +++ b/mtraining/attn_funcs/dense_func.py @@ -0,0 +1,278 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# This file modifies the official modeling_llama.py file at runtime to +# 1. register the flash attention function to nnscaler and update related code +# 2. 
replace the un-fused RMSNorm with apex's fused version +import torch +from torch import Tensor +from transformers.utils import is_flash_attn_2_available +from typing import List, Optional, Tuple, Union, Any, Dict +if is_flash_attn_2_available(): + from flash_attn.bert_padding import pad_input + from flash_attn import flash_attn_func, flash_attn_varlen_func + +from nnscaler.ir import IRTensor +from nnscaler.ir.operator import IRFwOperation +from nnscaler.runtime.device import DeviceGroup +from nnscaler.graph.parser.register import register_op + +from minference.dist_ops.zigzag_attention import zigzag_ring_flash_attn_func +from minference.dist_ops.striped_attention import stripe_flash_attn_func + +from .utils import nnscaler_upad_input + + +def fa_attn_forward( + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None, + causal=True, + ): + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = nnscaler_upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + return attn_output + +# --------------------------------------------------------------------------- +def zigzag_ring_attention_forward( + module: torch.nn.Module, + query: Tensor, # [B, H, N, D] + key: Tensor, + value: Tensor, + attention_mask: Optional[Tensor], + dropout: float = 0.0, + scaling: Optional[float] = None, + sliding_window: Optional[int] = None, + softcap: Optional[float] = None, + **kwargs, +) -> Tuple[Tensor, None]: + return wrap_zigzag_attn_func( + query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), + layer_idx=module.layer_idx, + softmax_scale=scaling, + dropout_p=dropout, + causal=True, + ), None + + +def wrap_zigzag_attn_func( + q: Tensor, k: Tensor, v: Tensor, + layer_idx: int, + softmax_scale: Tensor=None, + dropout_p: float=0.0, + causal: bool=True, + window_size: Tuple[int]=(-1, -1), + alibi_slopes: Tensor=None, deterministic: bool=False, + return_attn_probs: bool=False, + process_group: Tuple[int]=None +) -> Tensor: + if process_group is None or len(process_group) == 1: + # there is an additional checker for the `softmax_scale`, which is equivalent + # to the behavior of the original flash_attn_func. 
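+        # Descriptive note (added for clarity): with a single-rank or absent process group
+        # there are no peers to ring over, so zigzag ring attention degenerates to ordinary
+        # local attention; we fill in the default 1/sqrt(head_dim) scale and call
+        # flash_attn_func directly.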
+ if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + output = flash_attn_func(q, k, v, 0.0, softmax_scale, causal) + return output + + assert causal == True, "zigzag_ring is meaningless for causal=False" + assert len(q.shape) == 4, "q must have shape [bs, ql, qh, dim]" + assert len(k.shape) == 4, "k must have shape [bs, kl, kh, dim]" + assert len(v.shape) == 4, "v must have shape [bs, vl, vh, dim]" + qbsz, qlen, qheads, qdim = q.shape + kbsz, klen, kheads, kdim = k.shape + vbsz, vlen, vheads, vdim = v.shape + assert qbsz == kbsz == vbsz, "batch size must be the same" + assert qlen == klen == vlen, "sequence length must be the same" + assert kheads == vheads, "number of k and v heads must be the same" + assert qheads % kheads == 0, "number of q heads must be a multiple of k heads" + assert qdim == kdim == vdim, "dimension must be the same" + + local_process_group = DeviceGroup().get_group(process_group) + output = zigzag_ring_flash_attn_func( + q, k, v, + layer_idx, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + local_process_group, + ).contiguous() + return output + +# --------------------------------------------------------------------------- +def stripe_ring_attention_forward( + module: torch.nn.Module, + query: Tensor, # [B, H, N, D] + key: Tensor, + value: Tensor, + attention_mask: Optional[Tensor], + dropout: float = 0.0, + scaling: Optional[float] = None, + sliding_window: Optional[int] = None, + softcap: Optional[float] = None, + **kwargs, +) -> Tuple[Tensor, None]: + return wrap_striped_attn_func( + query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), + layer_idx=module.layer_idx, + softmax_scale=scaling, + dropout_p=dropout, + causal=True, + ), None + + +def wrap_striped_attn_func( + q: Tensor, k: Tensor, v: Tensor, layer_idx: int, + granularity: int=1, + softmax_scale: Tensor=None, + dropout_p: float=0.0, causal: bool=True, window_size: Tuple[int]=(-1, -1), + alibi_slopes: Tensor=None, deterministic: bool=False, + return_attn_probs: bool=False, + process_group: Tuple[int]=None + ) -> Tensor: + if process_group is None or len(process_group) == 1: + # there is an additional checker for the `softmax_scale`, which is equivalent + # to the behavior of the original flash_attn_func. 
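+        # Same single-rank shortcut as the zigzag wrapper: striping over one device is a
+        # no-op, so fall back to plain flash attention. In the distributed path below,
+        # `granularity` is passed through to stripe_flash_attn_func; it is assumed here to
+        # denote the stripe width, i.e. how many consecutive tokens each rank owns per
+        # round-robin stripe.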
+ if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + output = flash_attn_func(q, k, v, 0.0, softmax_scale, causal) + return output + + assert len(q.shape) == 4, "q must have shape [bs, ql, qh, dim]" + assert len(k.shape) == 4, "k must have shape [bs, kl, kh, dim]" + assert len(v.shape) == 4, "v must have shape [bs, vl, vh, dim]" + qbsz, qlen, qheads, qdim = q.shape + kbsz, klen, kheads, kdim = k.shape + vbsz, vlen, vheads, vdim = v.shape + assert qbsz == kbsz == vbsz, "batch size must be the same" + assert qlen == klen == vlen, "sequence length must be the same" + assert kheads == vheads, "number of k and v heads must be the same" + assert qheads % kheads == 0, "number of q heads must be a multiple of k heads" + assert qdim == kdim == vdim, "dimension must be the same" + + local_process_group = DeviceGroup().get_group(process_group) + output = stripe_flash_attn_func( + q, k, v, + layer_idx, + dropout_p, + softmax_scale, + granularity, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + local_process_group, + ).contiguous() + return output + + + +# --------------------------------------------------------------------------- +def flash_attention_anno(query_states, key_states, value_states, attention_mask, *args, **kwargs) -> str: + if query_states.shape[2] != key_states.shape[2]: + assert query_states.shape[2] % key_states.shape[2] == 0 + group_size = query_states.shape[2] // key_states.shape[2] + assert query_states.shape[2] == value_states.shape[2] * group_size + q_anno = f'(group_num {group_size})' + kv_anno = 'group_num' + else: + q_anno = kv_anno = 'num_heads' + + if isinstance(attention_mask, IRTensor): + return f'b l^ {q_anno} hd^, b s^ {kv_anno} hd^, b s^ {kv_anno} vd^, b l^ -> b l^ {q_anno} vd^' + else: + return f'b l^ {q_anno} hd^, b s^ {kv_anno} hd^, b s^ {kv_anno} vd^ -> b l^ {q_anno} vd^' + + +def emit_ring(node: IRFwOperation, args: List[str], kwargs: Dict[str, str], runtime_devid: int, plan_ndevs: int, runtime_ndevs: int) -> str: + """Special rule to generate zigzag_attn node""" + + signature = node.signature + + offset = (runtime_devid // plan_ndevs) * plan_ndevs + scale_unit_dev_ids = [local_rank + offset for local_rank in range(plan_ndevs)] + + kw_pairs = list() + for key, val in kwargs.items(): + code = f'{key}={val}' + kw_pairs.append(code) + + sub_input = node.inputs()[0] + full_input = sub_input.parent + partition_dims = [i for i, (s, f) in enumerate(zip(sub_input.shape, full_input.shape)) if s != f] + assert len(partition_dims) <= 1, f"support no more than one partition dim, but got {partition_dims}" + if not partition_dims: + kw_pairs.append("process_group=None") + else: + # if the 'process_group' is None, we will use the local attention (flash_attn_func) + if partition_dims[0] == 0: # partition on batch dim + # partition the bsz dim, use local flash_attn_func + kw_pairs.append("process_group=None") + elif partition_dims[0] == 1: # partition on sequence dim + # the synchronization should occur across scaleunits + kw_pairs.append(f"process_group={scale_unit_dev_ids}") + elif partition_dims[0] == 2: + # partition on num_head dim + kw_pairs.append("process_group=None") + else: + raise ValueError(f'unsupported partition dim: {partition_dims[0]}') + + args = ", ".join(list(args) + kw_pairs) + return f"{signature}({args})" + +def ring_attn_anno(query_states, key_states, value_states, *args, **kwargs) -> str: + if query_states.shape[2] != key_states.shape[2]: + assert query_states.shape[2] % key_states.shape[2] == 0 + group_size = 
query_states.shape[2] // key_states.shape[2] + assert query_states.shape[2] == value_states.shape[2] * group_size + q_anno = f'(group_num {group_size})' + kv_anno = 'group_num' + else: + q_anno = kv_anno = 'num_heads' + + return f'b l {q_anno} hd^, b l {kv_anno} hd^, b l {kv_anno} vd^ -> b l {q_anno} vd^' + + +register_op(flash_attention_anno)(fa_attn_forward) +register_op(ring_attn_anno, emit_fn=emit_ring)(wrap_zigzag_attn_func) +register_op(ring_attn_anno, emit_fn=emit_ring)(wrap_striped_attn_func) diff --git a/mtraining/attn_funcs/minfer_func.py b/mtraining/attn_funcs/minfer_func.py new file mode 100644 index 0000000..451c239 --- /dev/null +++ b/mtraining/attn_funcs/minfer_func.py @@ -0,0 +1,412 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# This file modifies the official modeling_llama.py file at runtime to +# 1. register the flash attention function to nnscaler and update related code +# 2. replace the un-fused RMSNorm with apex's fused version +import json +import torch +import logging +logger = logging.getLogger(__name__) + +from typing import List, Optional, Tuple, Dict, Callable +from transformers.utils import logging, is_flash_attn_2_available +if is_flash_attn_2_available(): from flash_attn import flash_attn_func + +from nnscaler.runtime.device import DeviceGroup +from nnscaler.graph.parser.register import register_op +from nnscaler.ir import IRTensor +from nnscaler.ir.operator import IRFwOperation + +from minference.ops.minference_attn import minference_flash_attn_func +from minference.ops.minference_attn_triton import minference_flash_attn_triton_func +from minference.dist_ops import ( + minfer_stripe_func, minfer_stripe_triton_func, + minfer_zigzag_func, minfer_dr_stripe_func, minfer_dr_stripe_triton_func, +) + + +# ======================================================= +def minfer_op( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + head_indices: torch.Tensor, + + bsz: int, + q_len: int, + head_dim: int, + layer_idx: int, + + pattern_dict: Dict[int, Tuple[str, int, int, int]], + attn_dropout: float=0., + granularity: int = 128, + group: Optional[torch.distributed.ProcessGroup] = None, +): + v_sizes = [pattern_dict[head_indices[idx].item()][1] for idx in range(query_states.size(1))] + s_sizes = [pattern_dict[head_indices[idx].item()][2] for idx in range(query_states.size(1))] + if torch.version.hip is None: + attn_output = minference_flash_attn_func( + query_states.transpose(1, 2).contiguous(), + key_states.transpose(1, 2).contiguous(), + value_states.transpose(1, 2).contiguous(), + v_sizes, s_sizes, + attn_dropout, + softmax_scale=None, + granularity=granularity, + causal=True, + window_size=(-1, -1), + deterministic=False, + return_attn_probs=False, + group=group, + ) + else: + attn_output = minference_flash_attn_triton_func( + query_states.transpose(1, 2).contiguous(), + key_states.transpose(1, 2).contiguous(), + value_states.transpose(1, 2).contiguous(), + v_sizes, s_sizes, + attn_dropout, + softmax_scale=None, + granularity=granularity, + causal=True, + window_size=(-1, -1), + deterministic=False, + return_attn_probs=False, + group=group, + ) + return attn_output.contiguous() + + +def minfer_stripe_op( + query_states: torch.Tensor, # [batch_size, num_heads, num_tokens, head_dim] + key_states: torch.Tensor, + value_states: torch.Tensor, + head_indices: torch.Tensor, + bsz: int, + q_len: int, + head_dim: int, + layer_idx: int, + + + pattern_dict: Dict[int, Tuple[str, int, int, int]], + 
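+        # pattern_dict: per-head sparse-pattern table mapping a global head index to a
+        # (pattern_type: str, vertical_size, slash_size, ...) tuple; entries [1] and [2]
+        # are read below as v_sizes / s_sizes (presumably the per-head vertical and slash
+        # sizes consumed by the sparse kernel).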
attn_dropout: float=0., + granularity: int = 128, + process_group: Optional[torch.distributed.ProcessGroup] = None, +): + if (process_group is None or len(process_group) == 1): + softmax_scale = query_states.shape[-1] ** (-0.5) + + output = flash_attn_func( + query_states.transpose(1, 2), + key_states.transpose(1, 2), + value_states.transpose(1, 2), + attn_dropout, softmax_scale, causal=True) + return output + group = DeviceGroup().get_group(process_group) + + v_sizes = [pattern_dict[head_indices[idx].item()][1] for idx in range(query_states.size(1))] + s_sizes = [pattern_dict[head_indices[idx].item()][2] for idx in range(query_states.size(1))] + if torch.version.hip is None: + attn_output = minfer_stripe_func( + query_states.transpose(1, 2).contiguous(), + key_states.transpose(1, 2).contiguous(), + value_states.transpose(1, 2).contiguous(), + v_sizes, s_sizes, + layer_idx, + attn_dropout, + softmax_scale=None, + granularity=granularity, + causal=True, + window_size=(-1, -1), + deterministic=False, + return_attn_probs=False, + group=group, + ) # expect: b {q_anno} l^ vd^' + else: + attn_output = minfer_stripe_triton_func( + query_states.transpose(1, 2).contiguous(), + key_states.transpose(1, 2).contiguous(), + value_states.transpose(1, 2).contiguous(), + v_sizes, s_sizes, + layer_idx, + attn_dropout, + softmax_scale=None, + granularity=granularity, + causal=True, + window_size=(-1, -1), + deterministic=False, + return_attn_probs=False, + group=group, + ) + return attn_output.contiguous() + + +def minfer_zigzag_op( + query_states: torch.Tensor, # [batch_size, num_heads, num_tokens, head_dim] + key_states: torch.Tensor, + value_states: torch.Tensor, + head_indices: torch.Tensor, + + bsz: int, + q_len: int, + head_dim: int, + layer_idx: int, + + pattern_dict: Dict[int, Tuple[str, int, int, int]], + attn_dropout: float=0., + granularity: int = 128, + process_group: Optional[torch.distributed.ProcessGroup] = None, +): + if process_group is None or len(process_group) == 1: + # there is an additional checker for the `softmax_scale`, which is equivalent + # to the behavior of the original flash_attn_func. + softmax_scale = query_states.shape[-1] ** (-0.5) + output = flash_attn_func( + query_states.transpose(1, 2), + key_states.transpose(1, 2), + value_states.transpose(1, 2), + attn_dropout, softmax_scale, causal=True) + return output + group = DeviceGroup().get_group(process_group) + + v_sizes = [pattern_dict[head_indices[idx].item()][1] for idx in range(query_states.size(1))] + s_sizes = [pattern_dict[head_indices[idx].item()][2] for idx in range(query_states.size(1))] + if torch.version.hip is None: + attn_output = minfer_zigzag_func( + query_states.transpose(1, 2).contiguous(), + key_states.transpose(1, 2).contiguous(), + value_states.transpose(1, 2).contiguous(), + v_sizes, s_sizes, + layer_idx, + attn_dropout, + softmax_scale=None, + granularity=granularity, + causal=True, + window_size=(-1, -1), + deterministic=False, + return_attn_probs=False, + group=group, + ) # expect: b {q_anno} l^ vd^' + else: + raise NotImplementedError("Triton-only version is not implemented for MInfer w. 
zigzag") + return attn_output.contiguous() + +def minfer_dr_stripe_op( + query_states: torch.Tensor, # [batch_size, num_heads, num_tokens, head_dim] + key_states: torch.Tensor, + value_states: torch.Tensor, + head_indices: torch.Tensor, + bsz: int, + q_len: int, + head_dim: int, + layer_idx: int, + + pattern_dict: Dict[int, Tuple[str, int, int, int]], + attn_dropout: float=0., + granularity: int = 128, + process_group: Optional[torch.distributed.ProcessGroup] = None, +): + if (process_group is None or len(process_group) == 1): + # there is an additional checker for the `softmax_scale`, which is equivalent + # to the behavior of the original flash_attn_func. + softmax_scale = query_states.shape[-1] ** (-0.5) + + output = flash_attn_func( + query_states.transpose(1, 2), + key_states.transpose(1, 2), + value_states.transpose(1, 2), + attn_dropout, softmax_scale, causal=True) + return output + + group = DeviceGroup().get_group(process_group) + v_sizes = [pattern_dict[head_indices[idx].item()][1] for idx in range(query_states.size(1))] + s_sizes = [pattern_dict[head_indices[idx].item()][2] for idx in range(query_states.size(1))] + + if torch.version.hip is None: + attn_output = minfer_dr_stripe_func( + query_states.transpose(1, 2).contiguous(), + key_states.transpose(1, 2).contiguous(), + value_states.transpose(1, 2).contiguous(), + v_sizes, s_sizes, + layer_idx, + attn_dropout, + softmax_scale=None, + granularity=granularity, + causal=True, + window_size=(-1, -1), + deterministic=False, + return_attn_probs=False, + group=group, + ) # expect: b {q_anno} l^ vd^' + else: + attn_output = minfer_dr_stripe_triton_func( + query_states.transpose(1, 2).contiguous(), + key_states.transpose(1, 2).contiguous(), + value_states.transpose(1, 2).contiguous(), + v_sizes, s_sizes, + layer_idx, + attn_dropout, + softmax_scale=None, + granularity=granularity, + causal=True, + window_size=(-1, -1), + deterministic=False, + return_attn_probs=False, + group=group, + ) + + return attn_output.contiguous() + + + +MINFER_IMPLEMENTATIONS: Dict[str, Callable] = { + "default": minfer_op, + "stripe": minfer_stripe_op, + "zigzag": minfer_zigzag_op, + "dr_stripe": minfer_dr_stripe_op, +} + +def emit_minfer_ring(node: IRFwOperation, args: List[str], kwargs: Dict[str, str], runtime_devid: int, plan_ndevs: int, runtime_ndevs: int) -> str: + """Special rule to generate zigzag_attn node""" + + signature = node.signature + + offset = (runtime_devid // plan_ndevs) * plan_ndevs + scale_unit_dev_ids = [local_rank + offset for local_rank in range(plan_ndevs)] + + kw_pairs = list() + for key, val in kwargs.items(): + code = f'{key}={val}' + kw_pairs.append(code) + + sub_input = node.inputs()[0] + full_input = sub_input.parent + partition_dims = [i for i, (s, f) in enumerate(zip(sub_input.shape, full_input.shape)) if s != f] + assert len(partition_dims) <= 1, f"support no more than one partition dim, but got {partition_dims}" + if not partition_dims: + kw_pairs.append("process_group=None") + else: + # if the 'process_group' is None, we will use the local attention (flash_attn_func) + if partition_dims[0] == 0: # partition on batch dim + # partition the bsz dim, use local flash_attn_func + kw_pairs.append("process_group=None") + elif partition_dims[0] == 1: + # partition on num_head dim + kw_pairs.append("process_group=None") + elif partition_dims[0] == 2: # partition on sequence dim + # the synchronization should occur across scaleunits + kw_pairs.append(f"process_group={scale_unit_dev_ids}") + else: + raise ValueError(f'unsupported 
partition dim: {partition_dims[0]}') + + args = ", ".join(list(args) + kw_pairs) + return f"{signature}({args})" + +def minfer_attn_anno(query_states, key_states, value_states, *args, **kwargs) -> str: + if query_states.shape[1] != key_states.shape[1]: + assert query_states.shape[1] % key_states.shape[1] == 0 + group_size = query_states.shape[1] // key_states.shape[1] + assert query_states.shape[1] == value_states.shape[1] * group_size + q_anno = f'(group_num {group_size})' + kv_anno = 'group_num' + else: + q_anno = kv_anno = 'num_heads' + + return f'b {q_anno} l^ hd^, b {kv_anno} s^ hd^, b {kv_anno} s^ vd^, {q_anno} -> b l^ {q_anno} vd^' + +def minfer_attn_ring_anno(query_states, key_states, value_states, *args, **kwargs) -> str: + if query_states.shape[1] != key_states.shape[1]: + assert query_states.shape[1] % key_states.shape[1] == 0 + group_size = query_states.shape[1] // key_states.shape[1] + assert query_states.shape[1] == value_states.shape[1] * group_size + q_anno = f'(group_num {group_size})' + kv_anno = 'group_num' + else: + q_anno = kv_anno = 'num_heads' + + return f'b {q_anno} l hd^, b {kv_anno} l hd^, b {kv_anno} l vd^, {q_anno} -> b l {q_anno} vd^' + +if __name__ != "__main__": + register_op(minfer_attn_anno)(minfer_op) + register_op(minfer_attn_ring_anno, emit_fn=emit_minfer_ring)(minfer_stripe_op) + register_op(minfer_attn_ring_anno, emit_fn=emit_minfer_ring)(minfer_zigzag_op) + register_op(minfer_attn_ring_anno, emit_fn=emit_minfer_ring)(minfer_dr_stripe_op) + +class MInferAttnFunc: + def __init__(self): + self.initialized = False + + def init_minfer_params( + self, + config_path: str, + minfer_implementation: str, # "fa", "stripe", "zigzag" + granularity: int = 128, + ): + assert minfer_implementation in MINFER_IMPLEMENTATIONS, f"minfer_implementation should be one of {MINFER_IMPLEMENTATIONS}, but got {self.minfer_implementation}" + self.minfer_implementation: str = minfer_implementation + + self.config_path = config_path + self.all_pattern_dict = json.load(open(self.config_path)) + self.granularity = granularity + + self.initialized = True + + def get_pattern_dict(self, layer_idx): + return {int(ii): jj for ii, jj in self.all_pattern_dict[layer_idx].items()} + + def forward( + self, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + head_indices: torch.Tensor, + attn_module_config: Dict[str, int], + attn_dropout: float=0.0, + ): + bsz, q_len = query_states.shape[0], query_states.shape[2] + head_dim, layer_idx = attn_module_config["head_dim"], attn_module_config["layer_idx"] + + pattern_dict = self.get_pattern_dict(layer_idx) + minfer_args = ( + query_states, key_states, value_states, + head_indices, + bsz, q_len, head_dim, layer_idx, + pattern_dict, attn_dropout, self.granularity, + ) + + if self.minfer_implementation == "default": + return minfer_op(*minfer_args) + elif self.minfer_implementation == "stripe": + return minfer_stripe_op(*minfer_args) + elif self.minfer_implementation == "zigzag": + return minfer_zigzag_op(*minfer_args) + elif self.minfer_implementation == "dr_stripe": + return minfer_dr_stripe_op(*minfer_args) + else: + raise ValueError(f"Unsupported minfer_implementation: {self.minfer_implementation}") + +def minfer_attention_forward( + module: torch.nn.Module, + query: torch.Tensor, # [B, H, N, D] + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + dropout: float = 0.0, + scaling: Optional[float] = None, + sliding_window: Optional[int] = None, + softcap: Optional[float] = None, 
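+    # NOTE: attention_mask, scaling, sliding_window and softcap are accepted only to match
+    # the HF-style attention-forward signature; they are not forwarded to the MInference
+    # kernel below, which is driven by dropout and the per-layer pattern config held by
+    # module.minfer_attn_func.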
+ **kwargs, +) -> Tuple[torch.Tensor, None]: + attn_module_config = { + "num_heads": module.config.num_attention_heads, + "head_dim": module.head_dim, + "layer_idx": module.layer_idx, + } + head_indices = torch.arange(attn_module_config["num_heads"], device=query.device, dtype=torch.int32) + + return module.minfer_attn_func.forward( + query, key, value, head_indices, + attn_module_config, + dropout, + ), None diff --git a/mtraining/attn_funcs/moba_func.py b/mtraining/attn_funcs/moba_func.py new file mode 100644 index 0000000..41a5993 --- /dev/null +++ b/mtraining/attn_funcs/moba_func.py @@ -0,0 +1,200 @@ +import os +import yaml +import torch +import torch.distributed as dist + +from torch import Tensor +from typing import Tuple, Optional, List, Dict, Any + +from nnscaler.ir.operator import IRFwOperation +from nnscaler.runtime.device import DeviceGroup +from nnscaler.graph.parser.register import register_op + +from minference.ops.moba import moba_attn_varlen, moba_layer +from minference.dist_ops.moba_zigzag import moba_zigzag_func +from minference.dist_ops.op_utils.moba_utils import MoBAConfig + +def load_moba_config(moba_config_dict: Dict[str, Any]): + moba_config = MoBAConfig(**moba_config_dict) + return moba_config + +def moba_attention_forward( + module: torch.nn.Module, + query: torch.Tensor, # [B, H, N, D] + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + dropout: float = 0.0, + scaling: Optional[float] = None, + sliding_window: Optional[int] = None, + softcap: Optional[float] = None, + **kwargs, +) -> Tuple[torch.Tensor, None]: + moba_topk, moba_chunk_size = module.moba_topk, module.moba_chunk_size + implementation = module.implementation + if implementation == "default": + return wrapped_moba_func( + query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), + moba_topk, moba_chunk_size, + attention_mask, dropout, scaling, sliding_window, softcap, **kwargs + ), None + else: + seq_len = query.shape[2] + layer_idx = module.layer_idx + return wrapped_moba_zigzag_func( + query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), + seq_len, + moba_topk, moba_chunk_size, + layer_idx, + attention_mask, dropout, scaling, sliding_window, softcap, + ), None + +# ------------------------------------------ +def wrapped_moba_func( + q: Tensor, k: Tensor, v: Tensor, + moba_topk: int, moba_chunk_size: int, + attention_mask: Optional[torch.Tensor], + dropout: float = 0.0, + scaling: Optional[float] = None, + sliding_window: Optional[int] = None, + softcap: Optional[float] = None, + **kwargs, +): + return moba_layer( + moba_attn_varlen, + moba_chunk_size, moba_topk, + q, k, v, + attention_mask, dropout, scaling, sliding_window, softcap, + **kwargs + ) + +def wrapped_moba_zigzag_func( + query: Tensor, # [B, N, H, D] + key: Tensor, + value: Tensor, + seq_len: int, + moba_topk: int, moba_chunk_size: int, + layer_idx: int, + attention_mask: Optional[torch.Tensor], + dropout: float = 0.0, + softmax_scale: Optional[float] = None, + sliding_window: Optional[int] = None, + softcap: Optional[float] = None, + process_group: Tuple[int]=None, +): + if process_group is None or len(process_group) == 1: + # there is an additional checker for the `softmax_scale`, which is equivalent + # to the behavior of the original flash_attn_func. 
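+        # Single-rank fallback: without a sequence-parallel group there is no zigzag
+        # exchange, so MoBA chunk selection is skipped and dense flash attention is run
+        # over the local sequence with the default 1/sqrt(head_dim) scale.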
+ from flash_attn import flash_attn_func + if softmax_scale is None: + softmax_scale = query.shape[-1] ** (-0.5) + output = flash_attn_func(query, key, value, 0.0, softmax_scale, True) + return output + + batch, block_seq_len, q_heads, head_dim = query.shape + assert batch == 1, "Current implementation only supports batch size = 1" + + # [0, BLK_SZ, 2 * BLK_SZ, 3 * BLK_SZ, ..., seq_len - 1] + rank = dist.get_rank() + world_size = dist.get_world_size() + seq_offsets = torch.arange(0, seq_len, seq_len // world_size)[rank:rank+1] + + _, _, kv_heads, _ = key.shape + + query = query.reshape(-1, q_heads, head_dim) # [B * N, H, D] + key = key.reshape(-1, kv_heads, head_dim) + value = value.reshape(-1, kv_heads, head_dim) + + # Assume only one batch or all batches have the same length + cu_seqlens = torch.cumsum( + torch.tensor([0] + [seq_len] * batch, device=query.device), + dim=0, + dtype=torch.int32, + ) + + local_process_group = DeviceGroup().get_group(process_group) + output = moba_zigzag_func( + query, key, value, + seq_offsets, + layer_idx, + cu_seqlens, + moba_chunk_size, + moba_topk, + dropout, softmax_scale, + True, # causal, + (-1, -1), # window_size, + None, # alibi_slopes, + False, # deterministic, + False, # return_softmax, + local_process_group, # group + ).contiguous() + return output.view(batch, block_seq_len, q_heads, head_dim) + + +# -------------------------------------------------- +def moba_attn_anno(query_states, key_states, value_states, *args, **kwargs) -> str: + if query_states.shape[1] != key_states.shape[1]: + assert query_states.shape[1] % key_states.shape[1] == 0 + group_size = query_states.shape[1] // key_states.shape[1] + assert query_states.shape[1] == value_states.shape[1] * group_size + q_anno = f'(group_num {group_size})' + kv_anno = 'group_num' + else: + q_anno = kv_anno = 'num_heads' + + return f'b {q_anno} l^ hd^, b {kv_anno} s^ hd^, b {kv_anno} s^ vd^, {q_anno} -> b l^ {q_anno} vd^' + +def moba_zigzag_attn_anno(query_states, key_states, value_states, *args, **kwargs) -> str: + num_q_heads, num_kv_heads = query_states.shape[2], key_states.shape[2] + if num_q_heads != num_kv_heads: + assert num_q_heads % num_kv_heads == 0 + group_size = num_q_heads // num_kv_heads + assert num_q_heads == value_states.shape[2] * group_size + q_anno = f'(group_num {group_size})' + kv_anno = 'group_num' + else: + q_anno = kv_anno = 'num_heads' + + attn_anno = f'b l {q_anno} hd^, b l {kv_anno} hd^, b l {kv_anno} vd^ -> b l {q_anno} vd^' + return attn_anno + +def emit_moba_zigzag(node: IRFwOperation, args: List[str], kwargs: Dict[str, str], runtime_devid: int, plan_ndevs: int, runtime_ndevs: int) -> str: + """Special rule to generate zigzag_attn node""" + + signature = node.signature + + offset = (runtime_devid // plan_ndevs) * plan_ndevs + scale_unit_dev_ids = [local_rank + offset for local_rank in range(plan_ndevs)] + + kw_pairs = list() + for key, val in kwargs.items(): + code = f'{key}={val}' + kw_pairs.append(code) + + sub_input = node.inputs()[0] + full_input = sub_input.parent + partition_dims = [i for i, (s, f) in enumerate(zip(sub_input.shape, full_input.shape)) if s != f] + assert len(partition_dims) <= 1, f"support no more than one partition dim, but got {partition_dims}" + if not partition_dims: + kw_pairs.append("process_group=None") + else: + # if the 'process_group' is None, we will use the local attention (flash_attn_func) + if partition_dims[0] == 0: # partition on batch dim + # partition the bsz dim, use local flash_attn_func + 
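+            # (each rank already holds complete sequences for its batch shard, so no
+            # cross-rank ring communication is needed and the local flash_attn path suffices)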
kw_pairs.append("process_group=None") + elif partition_dims[0] == 1: # partition on sequence dim + # the synchronization should occur across scaleunits + kw_pairs.append(f"process_group={scale_unit_dev_ids}") + elif partition_dims[0] == 2: + # partition on num_head dim + kw_pairs.append("process_group=None") + else: + raise ValueError(f'unsupported partition dim: {partition_dims[0]}') + + args = ", ".join(list(args) + kw_pairs) + return f"{signature}({args})" + + +if __name__ != "__main__": + register_op(moba_attn_anno)(wrapped_moba_func) + register_op(moba_zigzag_attn_anno, emit_fn=emit_moba_zigzag)(wrapped_moba_zigzag_func) diff --git a/mtraining/attn_funcs/utils.py b/mtraining/attn_funcs/utils.py new file mode 100644 index 0000000..f01fc63 --- /dev/null +++ b/mtraining/attn_funcs/utils.py @@ -0,0 +1,49 @@ +import torch +import logging +logger = logging.getLogger(__name__) + +from transformers.utils import is_flash_attn_2_available, logging +from transformers.modeling_flash_attention_utils import _get_unpad_data +if is_flash_attn_2_available(): + from flash_attn.bert_padding import index_first_axis, unpad_input # noqa + + + +def nnscaler_upad_input(query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + _, _, num_heads, _ = query_layer.shape + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. 
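For reference, the cu_seqlens bookkeeping that this unpadding helper hands to the varlen flash-attention kernels (mirroring transformers' `_get_unpad_data`) can be pictured on a toy mask; the tensors below are illustrative only:

    import torch

    # two sequences of true lengths 3 and 2 in a padded batch
    attention_mask = torch.tensor([[1, 1, 1, 0],
                                   [1, 1, 0, 0]])
    seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)        # tensor([3, 2])
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    cu_seqlens = torch.nn.functional.pad(
        torch.cumsum(seqlens, 0, dtype=torch.int32), (1, 0)
    )
    # indices    -> positions of real tokens in the flattened batch
    # cu_seqlens -> tensor([0, 3, 5]); max_seqlen -> int(seqlens.max()) == 3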
+ attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) \ No newline at end of file diff --git a/mtraining/attn_funcs/xattn_func.py b/mtraining/attn_funcs/xattn_func.py new file mode 100644 index 0000000..abbf854 --- /dev/null +++ b/mtraining/attn_funcs/xattn_func.py @@ -0,0 +1,217 @@ +import os +import yaml +import copy +import torch +from torch import Tensor +import torch.distributed as dist + +from functools import partial +from flash_attn import flash_attn_func +from typing import Tuple, Optional, Dict, Any, List + +from nnscaler.runtime.device import DeviceGroup +from nnscaler.graph.parser.register import register_op +from nnscaler.ir import IRTensor +from nnscaler.ir.operator import IRFwOperation + +from minference.ops.xattention_fa import xattn_flash_attn_func +from minference.dist_ops.xattn_zigzag import xattn_zigzag_func + +def xattn_attention_forward( + module: torch.nn.Module, + query: torch.Tensor, # [B, H, N, D] + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + dropout: float = 0.0, + scaling: Optional[float] = None, + sliding_window: Optional[int] = None, + softcap: Optional[float] = None, + **kwargs, +) -> Tuple[torch.Tensor, None]: + granularity, xattn_params = module.granularity, module.xattn_params + implementation = module.implementation + layer_idx = module.layer_idx + + if implementation == "default": + head_indices = torch.arange(module.config.num_attention_heads, device=query.device) + return wrapped_xattn_func_( + query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), + head_indices, + granularity, xattn_params, + dropout, scaling, sliding_window + ) + elif implementation == "zigzag": + return wrapped_xattn_zigzag_func_( + query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), + layer_idx, + granularity, xattn_params, + dropout=dropout, scaling=scaling, sliding_window=sliding_window, + ) + else: + raise NotImplementedError(f"Unsupported implementation for xattn_attention_forward: {implementation}") + +# ------------------------------------------ +# Non-CP version +def wrapped_xattn_func_( + q: Tensor, k: Tensor, v: Tensor, # [B, N, H, D] + head_indices: torch.Tensor, + granularity: int, + xattn_params: Dict[str, Any], + dropout: float = 0.0, + scaling: Optional[float] = None, + sliding_window: Optional[int] = None, +): + return wrapped_xattn_func( + q, k, v, + head_indices, + granularity, xattn_params, + dropout, scaling, sliding_window + ), None + +def wrapped_xattn_func( + q: Tensor, k: Tensor, v: Tensor, + head_indices: torch.Tensor, + granularity: int, + xattn_params: Dict[str, Any], + dropout: float = 0.0, + scaling: Optional[float] = None, + sliding_window: Optional[int] = None, +): + sliding_window = -1 if sliding_window is None else sliding_window + return xattn_flash_attn_func( + q, k, v, + head_indices.detach().cpu().numpy().tolist(), + xattn_params, + granularity, + dropout_p=dropout, + softmax_scale=scaling, + causal=True, + window_size=(sliding_window, sliding_window), + alibi_slopes=None, + deterministic=False, + ) + + + + +# ------------------------------------------ +# Zigzag Version +def wrapped_xattn_zigzag_func_( + q: Tensor, k: Tensor, v: Tensor, # [B, N, H, D] + layer_idx: int, + granularity: int, + xattn_params: Dict[str, Any], + 
causal: bool=True, + dropout: float = 0.0, + scaling: Optional[float] = None, + sliding_window: Optional[int] = None, + process_group: Optional[dist.ProcessGroup] = None, +): + return wrapped_xattn_zigzag_func( + q, k, v, + layer_idx, + granularity, xattn_params, + causal=causal, + dropout=dropout, scaling=scaling, sliding_window=sliding_window, + ), None + + +def wrapped_xattn_zigzag_func( + q: Tensor, k: Tensor, v: Tensor, # [B, N, H, D] + layer_idx: int, + granularity: int, + xattn_params: Dict[str, Any], + causal: bool=True, + dropout: float = 0.0, + scaling: Optional[float] = None, + sliding_window: Optional[int] = None, + process_group: Tuple[int]=None, +): + if process_group is None or len(process_group) == 1: + # there is an additional checker for the `scaling`, which is equivalent + # to the behavior of the original flash_attn_func. + if scaling is None: + scaling = q.shape[-1] ** (-0.5) + output = flash_attn_func(q, k, v, 0.0, scaling, causal) + return output + + group = DeviceGroup().get_group(process_group) + + xattn_params = copy.copy(xattn_params) + xattn_params.pop("chunk_size", None) + return xattn_zigzag_func( + q, k, v, + layer_idx, + xattn_params, + granularity, + dropout_p=dropout, + softmax_scale=scaling, + causal=causal, + group=group, + ).contiguous() + +def xattn_attn_anno(query_states, key_states, value_states, *args, **kwargs) -> str: + if query_states.shape[2] != key_states.shape[2]: + assert query_states.shape[2] % key_states.shape[2] == 0 + group_size = query_states.shape[2] // key_states.shape[2] + assert query_states.shape[2] == value_states.shape[2] * group_size + q_anno = f'(group_num {group_size})' + kv_anno = 'group_num' + else: + q_anno = kv_anno = 'num_heads' + + return f'b l^ {q_anno} hd^, b s^ {kv_anno} hd^, b s^ {kv_anno} vd^, {q_anno} -> b l^ {q_anno} vd^' + +def emit_xattn_zigzag(node: IRFwOperation, args: List[str], kwargs: Dict[str, str], runtime_devid: int, plan_ndevs: int, runtime_ndevs: int) -> str: + """Special rule to generate zigzag_attn node""" + + signature = node.signature + + offset = (runtime_devid // plan_ndevs) * plan_ndevs + scale_unit_dev_ids = [local_rank + offset for local_rank in range(plan_ndevs)] + + kw_pairs = list() + for key, val in kwargs.items(): + code = f'{key}={val}' + kw_pairs.append(code) + + sub_input = node.inputs()[0] + full_input = sub_input.parent + partition_dims = [i for i, (s, f) in enumerate(zip(sub_input.shape, full_input.shape)) if s != f] + assert len(partition_dims) <= 1, f"support no more than one partition dim, but got {partition_dims}" + if not partition_dims: + kw_pairs.append("process_group=None") + else: + # if the 'process_group' is None, we will use the local attention (flash_attn_func) + if partition_dims[0] == 0: # partition on batch dim + # partition the bsz dim, use local flash_attn_func + kw_pairs.append("process_group=None") + elif partition_dims[0] == 1: # partition on sequence dim + # the synchronization should occur across scaleunits + kw_pairs.append(f"process_group={scale_unit_dev_ids}") + elif partition_dims[0] == 2: + # partition on num_head dim + kw_pairs.append("process_group=None") + else: + raise ValueError(f'unsupported partition dim: {partition_dims[0]}') + + args = ", ".join(list(args) + kw_pairs) + return f"{signature}({args})" + +def xattn_zigzag_anno(query_states, key_states, value_states, *args, **kwargs) -> str: + if query_states.shape[2] != key_states.shape[2]: + assert query_states.shape[2] % key_states.shape[2] == 0 + group_size = query_states.shape[2] // 
key_states.shape[2] + + assert query_states.shape[2] == value_states.shape[2] * group_size + q_anno = f'(group_num {group_size})' + kv_anno = 'group_num' + else: + q_anno = kv_anno = 'num_heads' + + return f'b l {q_anno} hd^, b l {kv_anno} hd^, b l {kv_anno} vd^ -> b l {q_anno} vd^' + +if __name__ != "__main__": + register_op(xattn_attn_anno)(wrapped_xattn_func) + register_op(xattn_zigzag_anno, emit_fn=emit_xattn_zigzag)(wrapped_xattn_zigzag_func) diff --git a/mtraining/train.py b/mtraining/train.py index 8993d04..8590def 100644 --- a/mtraining/train.py +++ b/mtraining/train.py @@ -3,17 +3,15 @@ import os import yaml import torch -import shutil +import logging import argparse import numpy as np -import torch.distributed as dist -from typing import Dict, List, Optional from datasets import load_from_disk -from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling +from typing import Dict, List, Optional from transformers.modeling_utils import PreTrainedModel +from transformers import AutoConfig, DataCollatorForLanguageModeling -from nnscaler.utils import set_default_logger_level from nnscaler.cli.trainer_args import ( CheckpointConfig, DatasetConfig, @@ -25,45 +23,40 @@ DatasetSamplerConfig, ) from nnscaler.parallel import ComputeConfig +from nnscaler.utils import set_default_logger_level from nnscaler.runtime.f16_optimizer import MixedPrecisionAdamW from nnscaler.cli.loggers.tensorboard import TensorBoardLogger +from minference.models_patch import MInference +from minference.minference_configuration import MInferenceConfig +from minference.configs.model2path import BASE_DIR as SPARSE_PATTERN_CONFIG_DIR + +from .attn_funcs import AttnType, overwrite_attn_implementation from .trainer import CustomTrainer as Trainer, CustomTrainerArgs as TrainerArgs +from .models import MODEL_TO_ATTN_FUNC, MODEL_ID_TO_MODEL_CLS, MODEL_ID_TO_PREFIX + +from .utils.expr_data import update_expr_data +from .utils.general import freeze_model_params, load_comm_profile_data from .utils import chunk_linear_cross_entropy, get_tokenizer, aggregate_outputs_fn, get_resume_path -from MTraining.ops.minfer import ExprMInferConfig as MInferenceConfig, ExprMInference as MInference -from MTraining.ops import AttnType, overwrite_attn_implementation, load_moba_config, MoBAConfig -from MTraining.models import MODEL_TO_ATTN_FUNC, MODEL_ID_TO_MODEL_CLS, MODEL_ID_TO_PREFIX -from .utils.paths import ( - MINFER_CONFIG_DIR, SPARSE_PATTERN_CONFIG_DIR, SPARSE_HEAD_MAP_DIR, - update_expr_data_save_path -) -from MTraining.utils.train_utils import freeze_model_params, load_comm_profile_data -from MTraining.utils.expr_data import update_expr_data +from .utils.paths import TRAIN_ATTN_CONFIG_DIR, update_expr_data_save_path -import logging +IGNORE_IDX = -100 logger = logging.getLogger(__name__) set_default_logger_level('INFO') -IGNORE_IDX = -100 - - def init_by_attn_type(model_id: str, attn_type: AttnType): attn_dict = MODEL_TO_ATTN_FUNC[model_id] if attn_type == AttnType.BASELINE: print(f"{__name__} | Using Baseline Model...") - elif attn_type == AttnType.FLEX_PREFILL: - print(f"{__name__} | Using FlexPrefill-equipped Model ...") - elif attn_type == AttnType.RING_ATTN: - print(f"{__name__} | Using Ring Attention Zigzag-equipped Model ...") - elif attn_type == AttnType.RING_ATTN_STRIPE: - print(f"{__name__} | Using Ring Attention Stripe-equipped Model ...") - elif attn_type == AttnType.MF_MB: + elif attn_type == AttnType.ZIGZAG_RING: + print(f"{__name__} | Using Ring Zigzag Attention-equipped Model 
...") + elif attn_type == AttnType.STRIPE_RING: + print(f"{__name__} | Using Ring Stripe Attention-equipped Model ...") + elif attn_type == AttnType.MINFER: print(f"{__name__} | Using MInference-equipped Model ...") elif attn_type == AttnType.MOBA: print(f"{__name__} | Using MoBA-equipped Model ...") - elif attn_type == AttnType.ZIGZAG_MOBA: - print(f"{__name__} | Using ZigZag MoBA-equipped Model ...") elif attn_type == AttnType.XATTN: print(f"{__name__} | Using XAttention-equipped Model ...") else: @@ -128,33 +121,13 @@ def __init__( **kwargs, ) - # -------------------------------------------- - # MInference implementation: "fa", "stripe" - minfer_implementation: str = minfer_config.pop('implementation', 'fa') - - # -------------------------------------------- - # Sparse iteratio, layer and head control - start_sparse_iter: int = minfer_config.pop('start_sparse_iter', 0) - start_sparse_layer: int = minfer_config.pop('start_sparse_layer', 0) - adaptive_sparse: bool = minfer_config.pop('adaptive_sparse', False) - sparse_head_map_name: str = minfer_config.pop('sparse_head_map_name', None) - if sparse_head_map_name is not None: - active_sparse_map_path: str = os.path.join( - SPARSE_HEAD_MAP_DIR, - f'{sparse_head_map_name}.npy', - ) - print(f"{__name__} | Active Sparse Head Map Path: {active_sparse_map_path}") - active_sparse_map: np.ndarray = np.load(active_sparse_map_path) - active_sparse_map: List[List[bool]] = active_sparse_map.tolist() - else: - active_sparse_map: List[List[bool]] = None - # ---------------------------------------------- # Ring Attention specific granularity: int = minfer_config.pop('granularity', 128) # -------------------------------------------- - # Standard MInference Setup + # MInference Setup + minfer_implementation: str = minfer_config.pop('implementation', 'default') minfer_attn_type = minfer_config.pop('attn_type', 'minference') minfer_config['config_path'] = os.path.join( SPARSE_PATTERN_CONFIG_DIR, @@ -172,7 +145,7 @@ def __init__( # We still need to attach the function object to the model # otherwise the states of the function will be lost as nnscaler will only load the model from file # but not call this procedure again - from ops.minfer_func import MInferAttnFunc + from .attn_funcs.minfer_func import MInferAttnFunc Attention = self.model.model.layers[0].self_attn.__class__ def update_module(m): if isinstance(m, Attention): @@ -180,37 +153,10 @@ def update_module(m): m.minfer_attn_func.init_minfer_params( config_path=minfer_config.config_path, minfer_implementation=minfer_implementation, - - start_sparse_iter=start_sparse_iter, - start_sparse_layer=start_sparse_layer, - adaptive_sparse=adaptive_sparse, - active_sparse_map=active_sparse_map, - granularity=granularity, ) self.model.apply(update_module) -class FlexPrefillModel(BaselineModel): - def __init__( - self, - model_id, - config_path: str=None, - attn_config: Dict={}, - **kwargs, - ): - super().__init__( - model_id=model_id, - config_path=config_path, - **kwargs, - ) - - from ops.flex_prefill_func import FlexPrefillFunc - Attention = self.model.model.layers[0].self_attn.__class__ - def update_module(m): - if isinstance(m, Attention): - m.flex_prefill_attn_func = FlexPrefillFunc(attn_config) - self.model.apply(update_module) - class XAttnModel(BaselineModel): def __init__( self, @@ -252,6 +198,7 @@ def __init__( config_path=config_path, **kwargs, ) + from minference.dist_ops.op_utils.moba_utils import MoBAConfig # -------------------------------------------- print(f"MoBAConfig: {moba_config_dict}") @@ 
-271,46 +218,44 @@ def update_module(m): ATTN_TO_MODEL = { AttnType.BASELINE: BaselineModel, - AttnType.FLEX_PREFILL: FlexPrefillModel, - AttnType.MF_MB: MInferModel, - AttnType.RING_ATTN: BaselineModel, - AttnType.RING_ATTN_STRIPE: BaselineModel, + AttnType.STRIPE_RING: BaselineModel, + AttnType.ZIGZAG_RING: BaselineModel, + + AttnType.MINFER: MInferModel, AttnType.MOBA: MoBAModel, - AttnType.ZIGZAG_MOBA: MoBAModel, AttnType.XATTN: XAttnModel, } -def load_minfer_config(minfer_config_name: str) -> MInferenceConfig: - minfer_config_path = os.path.join(MINFER_CONFIG_DIR, f'{minfer_config_name}.yaml') - if not os.path.exists(minfer_config_path): - print(f"{__name__} | MInference config {minfer_config_name} not found in {minfer_config_path}. Use empty minfer config") - minfer_config = {} +def load_train_attn_config(train_attn_config_name: str) -> MInferenceConfig: + train_attn_config_path = os.path.join(TRAIN_ATTN_CONFIG_DIR, f'{train_attn_config_name}.yaml') + if not os.path.exists(train_attn_config_path): + print(f"{__name__} | MInference config {train_attn_config_name} not found in {train_attn_config_path}. Use empty minfer config") + train_attn_config = {} else: - print(f"{__name__} | MInference config {minfer_config_name} found in {minfer_config_path}") - with open(minfer_config_path, 'r') as f: - minfer_config = yaml.safe_load(f) + print(f"{__name__} | MInference config {train_attn_config_name} found in {train_attn_config_path}") + with open(train_attn_config_path, 'r') as f: + train_attn_config = yaml.safe_load(f) print('-' * 20) - print("MInference Config:") - print(minfer_config) + print("Training Attention Config:") + print(train_attn_config) print('-' * 20) + return train_attn_config - return minfer_config - -def build_model_args(args, minfer_config: MInferenceConfig) -> Dict: +def build_model_args(args, train_attn_config: MInferenceConfig) -> Dict: model_args = { 'model_id': args.model_id, 'config_path': args.model_config_path, "active_param_config_name": args.active_param_config_name, } if args.attn_type == AttnType.MF_MB: - model_args['minfer_config'] = minfer_config + model_args['minfer_config'] = train_attn_config elif args.attn_type == AttnType.FLEX_PREFILL: - model_args['attn_config'] = minfer_config + model_args['attn_config'] = train_attn_config elif args.attn_type == AttnType.XATTN: - model_args['xattn_params'] = minfer_config + model_args['xattn_params'] = train_attn_config elif args.attn_type == AttnType.MOBA or args.attn_type == AttnType.ZIGZAG_MOBA: - model_args['moba_config_dict'] = minfer_config + model_args['moba_config_dict'] = train_attn_config return model_args @@ -324,10 +269,9 @@ def main(args): load_comm_profile_data(args) init_by_attn_type(args.model_id, args.attn_type) - minfer_config = load_minfer_config(args.minfer_config_name) - - # broadcast_strategy = 'all' if args.run_mode == 'run' else 'none' + train_attn_config = load_train_attn_config(args.train_attn_config_name) broadcast_strategy = 'all' + # --------------------------------- # Compute config if args.run_mode == 'compile': @@ -361,7 +305,6 @@ def main(args): constant_folding=True, use_zero=True, use_end2end=True, - # autodist config: pas_config=pas_config, ) @@ -411,7 +354,7 @@ def collate_fn(samples): # --------------------------------- # Model Config - model_args = build_model_args(args, minfer_config) + model_args = build_model_args(args, train_attn_config) model_config = ModelConfig( type=ATTN_TO_MODEL[args.attn_type], args=model_args, @@ -527,7 +470,7 @@ def print_args(args: 
argparse.Namespace): print('-' * 40) print(f"Model Config Path:\t{args.model_config_path}") print(f"Dataset path:\t{args.dataset_path}") - print(f'MInferenece Config Name:\t{args.minfer_config_name}') + print(f'Training Attention Config Name:\t{args.train_attn_config_name}') print(f"Compile Save Path:\t{args.compile_save_path}") print(f"Attention Save Path:\t{args.attn_save_path}") print(f"Tensorboard Log Path:\t{args.tf_log_dir}") @@ -577,7 +520,7 @@ def print_args(args: argparse.Namespace): parser.add_argument('--model_config_path', type=str, default=None, help='path to the model config') parser.add_argument('-s', '--attn_save_step', type=int, default=1, help='Save attention data every n steps') - parser.add_argument('--minfer_config_name', type=str, default=None, help='Name of Minference config file') + parser.add_argument('--train_attn_config_name', type=str, default=None, help='Name of Minference config file') parser.add_argument('--compile_save_path', type=str, default='./.nnscaler', help='path to save compiled code') parser.add_argument('--attn_save_path', type=str, default=None, help='path to save attention data') parser.add_argument('--tf_log_dir', type=str, default=None, help='path to save tensorboard logs') @@ -610,7 +553,7 @@ def print_args(args: argparse.Namespace): if args.n_iter <= 0: args.n_iter = None if args.n_epochs <= 0: args.n_epochs = None - if args.minfer_config_name is None or args.minfer_config_name.lower() == 'none': args.minfer_config_name = None + if args.train_attn_config_name is None or args.train_attn_config_name.lower() == 'none': args.train_attn_config_name = None if args.transfer_config_dir.lower() == 'none': args.transfer_config_dir = None if args.active_param_config_name.lower() == 'none': args.active_param_config_name = None @@ -620,6 +563,5 @@ def print_args(args: argparse.Namespace): args.check_resume, args.resume_from, args.ckpt_save_dir, args.runtime_ngpus ) - print_args(args) main(args) \ No newline at end of file diff --git a/mtraining/train_attn_configs/moba_256k_s95.yaml b/mtraining/train_attn_configs/moba_256k_s95.yaml new file mode 100644 index 0000000..712cae5 --- /dev/null +++ b/mtraining/train_attn_configs/moba_256k_s95.yaml @@ -0,0 +1,2 @@ +moba_chunk_size: 4096 +moba_topk: 6 \ No newline at end of file diff --git a/mtraining/train_attn_configs/moba_512k_s95.yaml b/mtraining/train_attn_configs/moba_512k_s95.yaml new file mode 100644 index 0000000..348f394 --- /dev/null +++ b/mtraining/train_attn_configs/moba_512k_s95.yaml @@ -0,0 +1,2 @@ +moba_chunk_size: 4096 +moba_topk: 12 \ No newline at end of file diff --git a/mtraining/train_attn_configs/qwen_flex_090.yaml b/mtraining/train_attn_configs/qwen_flex_090.yaml new file mode 100644 index 0000000..f584796 --- /dev/null +++ b/mtraining/train_attn_configs/qwen_flex_090.yaml @@ -0,0 +1,2 @@ +pattern_config_name: Qwen2.5_3B_flex_0.90 +implementation: stripe \ No newline at end of file diff --git a/mtraining/train_attn_configs/qwen_flex_095.yaml b/mtraining/train_attn_configs/qwen_flex_095.yaml new file mode 100644 index 0000000..b3685e6 --- /dev/null +++ b/mtraining/train_attn_configs/qwen_flex_095.yaml @@ -0,0 +1,2 @@ +pattern_config_name: Qwen2.5_3B_flex_0.95 +implementation: stripe \ No newline at end of file diff --git a/mtraining/train_attn_configs/qwen_mf_dr_stripe.yaml b/mtraining/train_attn_configs/qwen_mf_dr_stripe.yaml new file mode 100644 index 0000000..b0c2051 --- /dev/null +++ b/mtraining/train_attn_configs/qwen_mf_dr_stripe.yaml @@ -0,0 +1,2 @@ +pattern_config_name: 
Qwen2.5_3B_kv_out_v32_fit_o_best_pattern +implementation: dr_stripe \ No newline at end of file diff --git a/mtraining/train_attn_configs/qwen_mf_stripe.yaml b/mtraining/train_attn_configs/qwen_mf_stripe.yaml new file mode 100644 index 0000000..3601afd --- /dev/null +++ b/mtraining/train_attn_configs/qwen_mf_stripe.yaml @@ -0,0 +1,2 @@ +pattern_config_name: Qwen2.5_3B_kv_out_v32_fit_o_best_pattern +implementation: stripe \ No newline at end of file diff --git a/mtraining/train_attn_configs/qwen_mf_zigzag.yaml b/mtraining/train_attn_configs/qwen_mf_zigzag.yaml new file mode 100644 index 0000000..a6afdaa --- /dev/null +++ b/mtraining/train_attn_configs/qwen_mf_zigzag.yaml @@ -0,0 +1,2 @@ +pattern_config_name: Qwen2.5_3B_kv_out_v32_fit_o_best_pattern +implementation: zigzag \ No newline at end of file diff --git a/mtraining/train_attn_configs/xattn_default.yaml b/mtraining/train_attn_configs/xattn_default.yaml new file mode 100644 index 0000000..6cf6980 --- /dev/null +++ b/mtraining/train_attn_configs/xattn_default.yaml @@ -0,0 +1,12 @@ +granularity: 128 +stride: 16 +norm: 1 +softmax: true +threshold: 0.9 +chunk_size: 16384 +select_mode: inverse +use_triton: true +causal: true +kdb: 1 +keep_sink: false +keep_recent: false diff --git a/mtraining/train_attn_configs/xattn_zigzag_s16.yaml b/mtraining/train_attn_configs/xattn_zigzag_s16.yaml new file mode 100644 index 0000000..52df90b --- /dev/null +++ b/mtraining/train_attn_configs/xattn_zigzag_s16.yaml @@ -0,0 +1,12 @@ +implementation: zigzag +granularity: 128 +stride: 16 +norm: 1 +softmax: true +threshold: 0.9 +select_mode: inverse +use_triton: true +causal: true +kdb: 1 +keep_sink: false +keep_recent: false diff --git a/mtraining/train_attn_configs/xattn_zigzag_s16_t85.yaml b/mtraining/train_attn_configs/xattn_zigzag_s16_t85.yaml new file mode 100644 index 0000000..d420943 --- /dev/null +++ b/mtraining/train_attn_configs/xattn_zigzag_s16_t85.yaml @@ -0,0 +1,12 @@ +implementation: zigzag +granularity: 128 +stride: 16 +norm: 1 +softmax: true +threshold: 0.85 +select_mode: inverse +use_triton: true +causal: true +kdb: 1 +keep_sink: false +keep_recent: false diff --git a/mtraining/trainer.py b/mtraining/trainer.py index 9fa5dd2..ea45b64 100644 --- a/mtraining/trainer.py +++ b/mtraining/trainer.py @@ -24,54 +24,17 @@ Trainer, _StepStat, TrainerArgs, TrainStatus, AggregatedTrainHook, TrainHook ) -from MTraining.utils.custom_parallel import parallelize as custom_parallelize -from MTraining.utils.val_utils import fix_model_state_dict -from MTraining.utils.paths import EXPR_DATA_SAVE_PATH +from .utils.paths import EXPR_DATA_SAVE_PATH +from .utils.general import fix_model_state_dict +from .utils.custom_parallel import parallelize as custom_parallelize logger = logging.getLogger(__name__) -def save_qkv_start_iter_idx(): - if os.getenv("COLLECT_QKV_DATA", "0") == "1": - # check QKV save dir and return the maximum iter idex corresponding to which the complete set of QKV shards are saved - qkv_save_dir = os.path.join(os.getenv("QKV_STORE_DIR"), str(torch.distributed.get_world_size())) - print(f"Rank {torch.distributed.get_rank()} | {__name__} | qkv_save_dir={qkv_save_dir}") - if not os.path.exists(qkv_save_dir): - print(f"Rank {torch.distributed.get_rank()} | {__name__} | qkv_save_dir does not exist") - return -1 - - # check subdirectories in this save dir (each subdir is named as `sample_{iter_idx}`) - num_gpus, num_layers = int(os.getenv("ORIG_GPU_SET").split('_')[-1]), int(os.getenv("NUM_LAYERS")) - subdirs = [d for d in os.listdir(qkv_save_dir) if 
os.path.isdir(os.path.join(qkv_save_dir, d)) and d.startswith("sample_")] - print(f"Rank {torch.distributed.get_rank()} | {__name__} | subdirs={subdirs}") - for iter_idx in range(len(subdirs) - 1, -1, -1): - subdir = os.path.join(qkv_save_dir, f"sample_{iter_idx}") - - # check if all layers are saved in this subdir - layer_dirs = [os.path.join(subdir, d) for d in os.listdir(subdir) if d.startswith('layer_') and os.path.isdir(os.path.join(subdir, d))] - if len(layer_dirs) == num_layers: - # check if all GPUs are saved in this subdir - print(f"Rank {torch.distributed.get_rank()} | {__name__} | layer_dirs={layer_dirs}") - for layer_dir in layer_dirs: - shard_paths = [sp for sp in os.listdir(layer_dir) - if (sp.startswith('q_') or sp.startswith('k_') or sp.startswith('v_') or sp.startswith('dout_')) \ - and sp.endswith('.pt') \ - and os.path.isfile(os.path.join(layer_dir, sp))] - if len(shard_paths) == 4 * num_gpus: - # all shards are saved - return iter_idx - - return -1 - @dataclass class CustomTrainerArgs(TrainerArgs): transfer_config: Optional[Dict[str, Any]] = None merged_ckpt_path: Optional[str] = None - - -EXPR_NAME: str -PT_LOG_SAVE_DIR: str - ITERATOR_COUNTER = defaultdict(int) def get_iter_cnt(rank: int): global ITERATOR_COUNTER @@ -82,82 +45,6 @@ def get_iter_batch_idx(rank: int, iter_cnt: int): global ITER_BATCH_IDX_DICT return ITER_BATCH_IDX_DICT.get(rank, {}).get(iter_cnt, 0) -SAVE_ITERVAL = -1 -def need_save_data(rank: int): - global SAVE_ITERVAL - if SAVE_ITERVAL <= 0: return False - return get_iter_cnt(rank) % SAVE_ITERVAL == 0 - -EXECUTOR = ThreadPoolExecutor(max_workers=4) # Adjust max_workers as needed -def save_iter_losses(epoch_idx: int, iter_idx: int, losses: List[Any], latencies: Optional[List[Any]]): - if torch.distributed.get_rank() != 0: return - - loss_save_dir = os.path.join(EXPR_DATA_SAVE_PATH['base_path'], 'losses', f"epoch_{epoch_idx}") - os.makedirs(loss_save_dir, exist_ok=True) - loss_save_path = os.path.join(loss_save_dir, f'iter_{iter_idx}.csv') - print(f"Rank {torch.distributed.get_rank()} | {__name__} | Saving iter losses to {loss_save_path} ...") - - loss_dict = {} - for sample_idx in range(len(losses)): - loss_dict[sample_idx] = { - 'loss': losses[sample_idx][1].item(), - 'num_tokens': losses[sample_idx][2], - } - if latencies is not None: - loss_dict[sample_idx]['latency'] = latencies[sample_idx] - loss_df = pd.DataFrame.from_dict(loss_dict, orient='index') - loss_df.index.name = "Sample" - - loss_df.to_csv(loss_save_path) - -def prof_train_step( - model: ParallelModule, - rank: int, iter_idx: int, - samples: List[Any], - is_dummy_batch: Optional[List[bool]] = None, - scale_fn: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, -) -> List[Any]: - global ITER_BATCH_IDX_DICT - model._warn_uninitialized_non_persistent_buffers(raise_error=True) - - if not model.compute_config.use_end2end: - raise RuntimeError("train_step() is only supported in end2end mode") - if is_dummy_batch and len(samples) != len(is_dummy_batch): - raise ValueError("The length of samples and is_dummy_batch should be the same") - - model._scale_loss(is_dummy_batch, scale_fn) - - # sync_grad will be done in _train_step - # so we never need to call it manually - model._sync_grad_required = False - sample_count = len(samples) - dataloader = microbatches(samples, cycle=False) - - outputs = [] - trace_path = os.path.join(PT_LOG_SAVE_DIR, f'iter_{iter_idx}', f'trace_{rank}.log') - os.makedirs(os.path.dirname(trace_path), exist_ok=True) - with profile( - activities=[ProfilerActivity.CPU, 
ProfilerActivity.CUDA], - schedule=torch.profiler.schedule(wait=1, warmup=1, active=2, repeat=1), - on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_path), - record_shapes=True, - with_stack=True - ) as prof: - for idx in range(sample_count): - ITER_BATCH_IDX_DICT[rank][iter_idx] = idx - - sample_start_time = time.perf_counter() - with accum_mode(begin=(idx==0), end=(idx==sample_count-1)): - output = model._train_step(dataloader) - sample_time = time.perf_counter() - sample_start_time - - if rank == 0: - print(f"| {__name__} | rank={rank} | iter_idx={iter_idx}, batch_idx={idx}, loss={output[1]}, latency={sample_time:.4f}s") - - outputs.append(output) - prof.step() - return outputs - def custom_train_step( model: ParallelModule, rank: int, iter_idx: int, @@ -196,7 +83,6 @@ def custom_train_step( sample_count = len(samples) dataloader = microbatches(samples, cycle=False) - qkv_start_iter = save_qkv_start_iter_idx() if model.use_scheduler: if len(samples) != model.nmicros_per_scheduler_step: raise ValueError(f"Expected {model.nmicros_per_scheduler_step} samples, but got {sample_count}") @@ -208,7 +94,6 @@ def custom_train_step( latencies = [] for idx in range(sample_count): ITER_BATCH_IDX_DICT[rank][iter_idx] = idx - if idx <= qkv_start_iter: continue sample_start_time = time.perf_counter() with accum_mode(begin=(idx==0), end=(idx==sample_count-1)): @@ -229,12 +114,10 @@ def custom_train_step( class CustomTrainer(Trainer): def __init__( - self, - argv: Optional[List[str]] = None, - *, - train_args: Optional[Union[Dict[str, Any], CustomTrainerArgs]] = None, - save_data_steps: int = 1, - enable_prof: bool = False, + self, + argv: Optional[List[str]] = None, + *, + train_args: Optional[Union[Dict[str, Any], CustomTrainerArgs]] = None, ): """ Custom trainer with an additional parameter. 
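Stripped of the nnscaler-specific pieces (accum_mode, the scheduler path, dummy-batch scaling), the bookkeeping that the retained custom_train_step performs is essentially per-micro-batch timing around the step function. A framework-agnostic sketch, with illustrative names that are not nnscaler APIs:

    import time
    from typing import Any, Callable, Iterable, List, Tuple

    def timed_micro_steps(step_fn: Callable[[Any], Any],
                          microbatches: Iterable[Any]) -> Tuple[List[Any], List[float]]:
        # run one optimizer step's worth of micro-batches, recording per-batch latency
        outputs, latencies = [], []
        for mb in microbatches:
            t0 = time.perf_counter()
            outputs.append(step_fn(mb))
            latencies.append(time.perf_counter() - t0)
        return outputs, latencies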
@@ -248,31 +131,11 @@ def __init__( super().__init__(argv=argv, train_args=train_args) self.train_args: CustomTrainerArgs - if int(os.getenv('NCCL_DEBUG_MODE', '0')) == 1 and int(os.getenv("CODE_GEN", '0')) != 1: - torch.distributed.init_process_group( - backend='nccl', - timeout=timedelta(seconds=30), - ) - print(f"Rank {self.rank} | {__name__} | nccl timeout is set to 30s for debugging") - else: - torch.distributed.init_process_group( - backend='nccl', - timeout=timedelta(hours=2), - ) - - global SAVE_ITERVAL, EXPR_NAME, PT_LOG_SAVE_DIR - SAVE_ITERVAL = save_data_steps - self.save_data_steps = save_data_steps - - EXPR_NAME = train_args.instance_name - PT_LOG_SAVE_DIR = os.path.join(train_args.log[0].args['root_dir'].replace('tf_logs', 'pt_logs'), EXPR_NAME) - os.makedirs(PT_LOG_SAVE_DIR, exist_ok=True) - - self.enable_prof = enable_prof - if self.enable_prof: - self.train_step_func = prof_train_step - else: - self.train_step_func = custom_train_step + torch.distributed.init_process_group( + backend='nccl', + timeout=timedelta(hours=2), + ) + self.train_step_func = custom_train_step def _train_epoch(self, epoch): VAL_STATUS_NO = 0 # not validated or saved @@ -331,13 +194,7 @@ def _train_epoch(self, epoch): global ITERATOR_COUNTER, ITER_BATCH_IDX_DICT ITERATOR_COUNTER[self.rank] = idx ITER_BATCH_IDX_DICT[self.rank] = {idx: 0} - # print(f"|{__name__}| rank={self.rank}, ITERATOR_COUNTER[self.rank]={ITERATOR_COUNTER[self.rank]}") - if self.rank == 0: - # looks manually update progress bar is easier - # than using tqdm directly - # the difference is we update progress bar at the beginning of the loop - # instead of the end of the loop progress.update(1) step_start_at = time.perf_counter() step_stat = _StepStat() @@ -353,13 +210,7 @@ def _train_epoch(self, epoch): self.hook.after_zero_grad(self) self.hook.on_train_step_start(self, batches[:num_batches], idx) - # losses = self.model.train_step(batches, is_dummy_batch) - losses, latencies = self.train_step_func(self.model, self.rank, idx, batches, is_dummy_batch) - # EXECUTOR.submit( - # save_iter_losses, - # idx, losses, latencies - # ) - # save_iter_losses(idx, losses, latencies) + losses, latencies = self.train_step_func(self.model, self.rank, idx, batches, is_dummy_batch) self.hook.on_train_step_end(self, losses[:num_batches], batches[:num_batches], idx) aggregate_outputs = self.train_args.resolved_aggregate_outputs_fn or self.aggregate_outputs @@ -373,9 +224,6 @@ def _train_epoch(self, epoch): self.hook.after_aggregate_train_step_outputs(self, aggregated_outputs, loss, idx) self.hook.before_sync_grad(self) - # actually `sync_shard_grad` is no-op here - # because trainer only supports end2end model - # and syncing grad in end2end model is done in `_train_step`. self.optimizer.sync_shard_grad() self.hook.after_sync_grad(self) @@ -584,11 +432,7 @@ def _create_model(): self.model = pmodel_class() self.model.cuda() self.optimizer = self.train_args.create_parallel_optimizer(self.model) - # Here we carefully scale down the gradient locally with 1/scale_factor before reduce, - # (the reduce op is `sum` by default, follow torch's c10d, grad is divided by scaling_factor before allreduce) - # and scale up the gradient after reduce - # (see `train_args.optimizer.grad_reduction`` handling in `train_epoch`). - # This is useful to avoid overflow when the gradients are large. 
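The reducer hook kept below only performs the down-scaling (grad.div_(scaling_factor)); per the comment being dropped above, the matching up-scaling happens after the reduction via the optimizer's grad_reduction handling. Purely for illustration, the same "scale before the sum all-reduce, undo afterwards" idea expressed with plain torch.distributed (assumes an initialized process group; not the nnscaler code path):

    import torch
    import torch.distributed as dist

    def allreduce_with_prescale(grad: torch.Tensor, scaling_factor: float) -> torch.Tensor:
        # scale down locally so the summed values stay in a safe range ...
        grad = grad / scaling_factor
        dist.all_reduce(grad, op=dist.ReduceOp.SUM)
        # ... and undo the scaling once the reduction is done
        return grad * scaling_factor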
+ def reducer_pre_hook(reducer, grad): grad.div_(self.train_args.scaling_factor) self.optimizer.register_reducer_pre_hook(reducer_pre_hook) @@ -670,9 +514,4 @@ def _load_checkpoint(self): if self.lr_scheduler: self.lr_scheduler.load_state_dict(state_dict['lr_scheduler']) self.train_status = TrainStatus(**state_dict['train_status']) - - # Assume in efficiency measuring mode, no checkpoint is saved and we only need to load the originally trained checkpoint - if int(os.getenv("E2E_MEASURE", "0")) == 1 or int(os.getenv("COLLECT_RING_COMP_DATA", "0")) == 1: - self.train_status.finished_train_steps = 0 - self.rng_states_from_resume = state_dict.get('rng_states') # resumed in _global_batch_iterator() \ No newline at end of file diff --git a/mtraining/utils/expr_data.py b/mtraining/utils/expr_data.py new file mode 100644 index 0000000..e82f025 --- /dev/null +++ b/mtraining/utils/expr_data.py @@ -0,0 +1,15 @@ +import os +from dataclasses import dataclass + +@dataclass +class ExprData: + global_batch_size: int + micro_batch_size: int + reuse_type: str + +EXPR_DATA = ExprData(64, 1, "match") + +def update_expr_data(args): + global EXPR_DATA + EXPR_DATA = ExprData(args.global_batch_size, args.micro_batch_size, args.reuse_type) + # print(f"Updated global batch size to {EXPR_DATA.global_batch_size}, micro batch size to {EXPR_DATA.micro_batch_size}, reuse type to {EXPR_DATA.reuse_type}") diff --git a/mtraining/utils/general.py b/mtraining/utils/general.py index e5e1869..28ec24f 100644 --- a/mtraining/utils/general.py +++ b/mtraining/utils/general.py @@ -135,3 +135,70 @@ def freeze_model_params(model, active_param_config_name: str, prefix=""): print(f"keep active: {keep_active}") freeze_model_params_(model, keep_active, prefix) + +def get_resume_path( + check_resume: bool, + resume_from: str, + ckpt_save_dir: str, + num_gpus: int, +): + if not check_resume: + return None + elif resume_from is not None: + return resume_from + + # Detect the last checkpoint in CKPT_PATH + ckpt_dirs = [ckpt_dir for ckpt_dir in os.listdir(ckpt_save_dir) if len(ckpt_dir.split('-')) == 2 and ckpt_dir.split('-')[0].isdigit()] + + # Filter out directories that do not contain number of ckpts (file name ending with .ckpt) equal to num_gpus (check by os.listdir) + filtered_ckpt_dirs = [] + for ckpt_dir in ckpt_dirs: + ckpt_dir_path = os.path.join(ckpt_save_dir, ckpt_dir) + if os.path.isdir(ckpt_dir_path): + ckpt_files = [f for f in os.listdir(ckpt_dir_path) if f.endswith('.ckpt')] + if len(ckpt_files) == num_gpus: + filtered_ckpt_dirs.append(ckpt_dir) + + print(f"get_resume_path | filtered_ckpt_dirs = {filtered_ckpt_dirs}") + if len(filtered_ckpt_dirs) == 0: + return None + + target_ckpt_dir = sorted(filtered_ckpt_dirs, key=lambda x: (int(x.split('-')[0]), int(x.split('-')[1])))[-1] + target_ckpt_dir = os.path.join(ckpt_save_dir, target_ckpt_dir) + print(f"get_resume_path | target_ckpt_dir = {target_ckpt_dir}") + return target_ckpt_dir + + + +def fix_model_state_dict(model, model_state_dict): + if isinstance(model, ParallelModule): + required_keys = list(model.dist_param_map.values()) + required_keys_under = {k[6:]: v for k, v in model.dist_param_map.items()} + else: + required_keys = model.state_dict().keys() + required_keys_under = {k.replace('.', '_'): k for k in required_keys} + + has_model_prefix = 'model' in model_state_dict + model_state_dict = model_state_dict if not has_model_prefix else model_state_dict['model'] + model_sd_copy = model_state_dict.copy() + + if dist.is_initialized() and dist.get_rank() % 
int(os.getenv("GPU_PER_NODE", "8")) == 0: + print(f"{__name__} | required_keys[:10]: {required_keys[:10]}") + print(f"{__name__} | required_keys_under: {required_keys_under}") + print(f"{__name__} | model_state_dict.keys()[:10]: {list(model_state_dict.keys())[:10]}") + + for k in model_state_dict.keys(): + model_sd_copy.pop(k) + + under_k_start = 0 if not has_model_prefix else 1 + under_k = '_'.join(k.split('_')[under_k_start:-1]) + + if under_k in required_keys_under: + model_sd_copy[required_keys_under[under_k]] = model_state_dict[k] + + if 'lm_head_weight' in required_keys_under: + for k in model_state_dict.keys(): + if 'model_embed_tokens_weight' in k: + model_sd_copy[required_keys_under['lm_head_weight']] = model_state_dict[k] + + return model_sd_copy diff --git a/mtraining/utils/paths.py b/mtraining/utils/paths.py index ba24d7e..8035f50 100644 --- a/mtraining/utils/paths.py +++ b/mtraining/utils/paths.py @@ -1,13 +1,9 @@ import os BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) - +TRAIN_ATTN_CONFIG_DIR = os.path.join(BASE_DIR, 'train_attn_configs') ACTIVE_PARAM_CONFIG_DIR = os.path.join(BASE_DIR, "models", "active_param_configs") -MINFER_CONFIG_DIR = os.path.join(BASE_DIR, 'sparse_configs') -SPARSE_PATTERN_CONFIG_DIR = os.path.join(BASE_DIR, 'ops', 'minfer', 'configs') -SPARSE_HEAD_MAP_DIR = os.path.join(BASE_DIR, 'ops', 'minfer', 'sparse_head_maps') - EXPR_DATA_SAVE_PATH = { 'base_path': None, From 3cdbf7a92f2ed6f83385c39f686207a688821ce1 Mon Sep 17 00:00:00 2001 From: Wenxuan Li Date: Thu, 12 Jun 2025 09:54:36 +0000 Subject: [PATCH 03/12] restructured mtraining --- minference/dist_ops/moba_zigzag.py | 4 +- minference/dist_ops/op_utils/xattn_utils.py | 521 ------ minference/dist_ops/xattn_zigzag.py | 2 +- .../{dist_ops => ops}/op_utils/__init__.py | 0 .../{dist_ops => ops}/op_utils/moba_utils.py | 12 +- minference/ops/op_utils/xattn_utils.py | 592 +++++++ minference/ops/xattention_fa.py | 480 ++--- mtraining/.gitignore | 3 +- mtraining/attn_funcs/moba_func.py | 2 +- .../active_param_configs/attn_only.txt | 0 .../active_param_configs/qk_proj_only.txt | 0 .../scripts/train_qwen_mini_ProLong512K.sh | 102 ++ .../train_attn_configs/moba_256k_s95.yaml | 0 .../train_attn_configs/moba_512k_s95.yaml | 0 .../train_attn_configs/qwen_flex_090.yaml | 0 .../train_attn_configs/qwen_flex_095.yaml | 0 .../train_attn_configs/qwen_mf_dr_stripe.yaml | 0 .../train_attn_configs/qwen_mf_stripe.yaml | 0 .../train_attn_configs/qwen_mf_zigzag.yaml | 0 .../train_attn_configs/xattn_default.yaml | 0 .../train_attn_configs/xattn_zigzag_s16.yaml | 0 .../xattn_zigzag_s16_t85.yaml | 0 .../{models => model_configs}/__init__.py | 0 .../phi3/__init__.py | 0 .../phi3/configuration_phi3.py | 0 .../phi3/lc_config/configuration_phi3.py | 0 .../phi3/lc_config_mini/configuration_phi3.py | 0 .../phi3/modelling_phi.py | 0 .../qwen2/__init__.py | 0 .../qwen2/configuration_qwen2.py | 0 .../qwen2/lc_config/configuration_qwen2.py | 0 .../lc_config_mini/configuration_qwen2.py | 0 .../qwen2/modeling_qwen2.py | 0 mtraining/models/phi3/modelling_phi_legacy.py | 1568 ----------------- .../qwen2/mi_config/configuration_qwen2.py | 185 -- .../models/qwen2/mi_config/modeling_qwen2.py | 1490 ---------------- mtraining/models/qwen2/vllm_sparse_qwen2.py | 465 ----- mtraining/models/sparse_ops/.gitignore | 7 - .../mtraining_sparse_ops/__init__.py | 2 - .../mtraining_sparse_ops/minference_config.py | 23 - mtraining/models/sparse_ops/setup.py | 27 - mtraining/setup.py | 3 +- mtraining/setup.sh | 7 +- mtraining/train.py | 
87 +- mtraining/trainer.py | 6 +- mtraining/utils/general.py | 16 +- mtraining/utils/paths.py | 2 - 47 files changed, 877 insertions(+), 4729 deletions(-) delete mode 100644 minference/dist_ops/op_utils/xattn_utils.py rename minference/{dist_ops => ops}/op_utils/__init__.py (100%) rename minference/{dist_ops => ops}/op_utils/moba_utils.py (98%) create mode 100644 minference/ops/op_utils/xattn_utils.py rename mtraining/{models => experiments}/active_param_configs/attn_only.txt (100%) rename mtraining/{models => experiments}/active_param_configs/qk_proj_only.txt (100%) create mode 100755 mtraining/experiments/scripts/train_qwen_mini_ProLong512K.sh rename mtraining/{ => experiments}/train_attn_configs/moba_256k_s95.yaml (100%) rename mtraining/{ => experiments}/train_attn_configs/moba_512k_s95.yaml (100%) rename mtraining/{ => experiments}/train_attn_configs/qwen_flex_090.yaml (100%) rename mtraining/{ => experiments}/train_attn_configs/qwen_flex_095.yaml (100%) rename mtraining/{ => experiments}/train_attn_configs/qwen_mf_dr_stripe.yaml (100%) rename mtraining/{ => experiments}/train_attn_configs/qwen_mf_stripe.yaml (100%) rename mtraining/{ => experiments}/train_attn_configs/qwen_mf_zigzag.yaml (100%) rename mtraining/{ => experiments}/train_attn_configs/xattn_default.yaml (100%) rename mtraining/{ => experiments}/train_attn_configs/xattn_zigzag_s16.yaml (100%) rename mtraining/{ => experiments}/train_attn_configs/xattn_zigzag_s16_t85.yaml (100%) rename mtraining/{models => model_configs}/__init__.py (100%) rename mtraining/{models => model_configs}/phi3/__init__.py (100%) rename mtraining/{models => model_configs}/phi3/configuration_phi3.py (100%) rename mtraining/{models => model_configs}/phi3/lc_config/configuration_phi3.py (100%) rename mtraining/{models => model_configs}/phi3/lc_config_mini/configuration_phi3.py (100%) rename mtraining/{models => model_configs}/phi3/modelling_phi.py (100%) rename mtraining/{models => model_configs}/qwen2/__init__.py (100%) rename mtraining/{models => model_configs}/qwen2/configuration_qwen2.py (100%) rename mtraining/{models => model_configs}/qwen2/lc_config/configuration_qwen2.py (100%) rename mtraining/{models => model_configs}/qwen2/lc_config_mini/configuration_qwen2.py (100%) rename mtraining/{models => model_configs}/qwen2/modeling_qwen2.py (100%) delete mode 100644 mtraining/models/phi3/modelling_phi_legacy.py delete mode 100644 mtraining/models/qwen2/mi_config/configuration_qwen2.py delete mode 100644 mtraining/models/qwen2/mi_config/modeling_qwen2.py delete mode 100644 mtraining/models/qwen2/vllm_sparse_qwen2.py delete mode 100644 mtraining/models/sparse_ops/.gitignore delete mode 100644 mtraining/models/sparse_ops/mtraining_sparse_ops/__init__.py delete mode 100644 mtraining/models/sparse_ops/mtraining_sparse_ops/minference_config.py delete mode 100644 mtraining/models/sparse_ops/setup.py diff --git a/minference/dist_ops/moba_zigzag.py b/minference/dist_ops/moba_zigzag.py index c60b981..f3d60e4 100644 --- a/minference/dist_ops/moba_zigzag.py +++ b/minference/dist_ops/moba_zigzag.py @@ -8,19 +8,17 @@ from einops import rearrange from typing import List, Tuple, Dict -from time import perf_counter from flash_attn.flash_attn_interface import ( _flash_attn_varlen_forward, _flash_attn_varlen_backward, ) - from .utils import ( RingComm, update_out_and_lse, recover_zigzag_output, get_default_args, ) -from .op_utils.moba_utils import ( +from minference.ops.op_utils.moba_utils import ( shuffle_input_all, shuffle_input_only, compute_moba_gate ) diff 
--git a/minference/dist_ops/op_utils/xattn_utils.py b/minference/dist_ops/op_utils/xattn_utils.py deleted file mode 100644 index 444c36f..0000000 --- a/minference/dist_ops/op_utils/xattn_utils.py +++ /dev/null @@ -1,521 +0,0 @@ -import math -import torch -import torch.nn.functional as F -import torch.distributed as dist - -from minference.dist_ops.utils import RingComm -from minference.ops.xattention_fa import ( - softmax_fuse_block_sum, - flat_group_gemm_fuse_reshape, -) - - -LN2 = 1 / 1.4426950408889634 -def create_causal_mask(batch_size, head_num, block_size, block_num, divide_block_num): - """ - Creates a causal attention mask used in transformer-based models. - - Parameters: - - batch_size (int): The number of sequences in the batch. - - head_num (int): The number of attention heads. - - block_size (int): The size of each block in the sequence. - - block_num (int): The total number of blocks in the sequence. - - divide_block_num (int): The block index at which causality is applied. - - Returns: - - torch.Tensor: A mask tensor of shape (batch_size, head_num, block_size, total_size) - where total_size = block_size * block_num. The mask enforces causal attention by - setting certain positions to `-inf` to prevent information leakage from future tokens. - """ - divide_block_num += 1 - if divide_block_num < 1 or divide_block_num > block_num: - raise ValueError( - f"divide_block_num ({divide_block_num}) must be between 1 and block_num ({block_num})." - ) - - total_size = block_size * block_num - device = "cuda" - mask = torch.zeros(block_size, total_size, device=device) - if divide_block_num < block_num: - mask[:, divide_block_num * block_size :] = float("-inf") - - if divide_block_num - 1 < block_num: - start_col = (divide_block_num - 1) * block_size - end_col = start_col + block_size - upper_tri_mask = torch.triu( - torch.full((block_size, block_size), float("-inf"), device=device), - diagonal=1, - ) - mask[:, start_col:end_col] = upper_tri_mask - - mask = mask.unsqueeze(0).unsqueeze(0) - mask = mask.expand(batch_size, head_num, block_size, total_size) - return mask - -def find_blocks_chunked( - input_tensor: torch.Tensor, # (batch_size, num_heads, num_block_q, num_block_k) - current_index, # - threshold, num_to_choose, decoding: bool, mode: str = "both", causal=True -): - """ - Finds and selects relevant blocks of attention for transformer-based models based on a - threshold or a predefined number of blocks. - - Parameters: - - input_tensor (torch.Tensor): The input tensor of shape (batch_size, head_num, num_block_q, num_block_k). - - current_index (int): The current index in the sequence processing. - - threshold (float or None): A threshold value used to determine the minimum attention weight sum. - - num_to_choose (int or None): The number of blocks to be selected, ensuring sufficient information retrieval. - - decoding (bool): If True, operates in decoding mode; otherwise, it's in encoding mode. - - mode (str): Defines the processing mode, either 'both', 'prefill', or 'decode'. - - causal (bool): If True, applies causal masking to prevent future information leakage. - - Returns: - - torch.Tensor: A boolean mask of shape (batch_size, head_num, num_block_q, num_block_k), - indicating which blocks should be attended to. 
- """ - assert threshold is None or num_to_choose is None - batch_size, head_num, num_block_q, num_block_k = input_tensor.shape - input_tensor = input_tensor.to(float) - - total_sum = input_tensor.sum(dim=-1, keepdim=True) - if isinstance(threshold, torch.Tensor): - threshold = threshold.to(float) - required_sum = total_sum * threshold.unsqueeze(0).unsqueeze(-1).unsqueeze( - -1 - ).expand((batch_size, head_num, num_block_q, 1)).to(input_tensor.device) - else: - required_sum = total_sum * threshold - - - mask = torch.zeros_like(input_tensor, dtype=torch.bool) - mask[:, :, :, 0] = 1 - mask[:, :, :, current_index : current_index + num_block_q] = ( - torch.eye(num_block_q, device=mask.device) - .unsqueeze(0) - .unsqueeze(0) - .expand(1, head_num, num_block_q, num_block_q) - ) - # Note that other_values only contains the values of the current block - # (the sink blocks and diagonal are filled with 0) - other_values = input_tensor.masked_fill(mask, 0) - - - # Get sorted values - sorted_values, _ = torch.sort(other_values, dim=-1, descending=True) - sorted_values = sorted_values.to(input_tensor.device) - sorted_values = torch.cat( - [ - torch.zeros( - (batch_size, head_num, num_block_q, 1), device=input_tensor.device - ), - torch.where(mask, input_tensor, 0).sum(dim=-1, keepdim=True), # shape: (batch_size, head_num, num_block_q, 1) - sorted_values[:, :, :, :-2], # :-2 excludes the first and diagonal (which are marked 0 in other_values) - ], - dim=-1, - ) - - # Get sorted indices - # index will select the already-masked (sink and diagonal) at the beginning - _, index = torch.sort( - torch.where(mask, 100000 * (1 + input_tensor), input_tensor), - dim=-1, - descending=True, - ) - - # [batch_size, head_num, num_block_q, num_block_k] - cumulative_sum_without_self = torch.cat( - [ - torch.zeros( - (batch_size, head_num, num_block_q, 1), device=input_tensor.device - ), - sorted_values[:, :, :, 0:-1], - ], - dim=-1, - ).cumsum(dim=-1) - - # Mask for indices where cumulative sum is below the required threshold. 
- index_mask = cumulative_sum_without_self < required_sum - index = torch.where(index_mask, index, 0) - - mask = mask.view(batch_size, head_num * num_block_q, num_block_k) - index = index.view(batch_size, head_num * num_block_q, num_block_k) - mask[:, torch.arange(mask.shape[1], device=mask.device).unsqueeze(dim=-1), index] = True - mask = mask.view(batch_size, head_num, num_block_q, num_block_k) - - - assert bool((torch.where(mask,input_tensor,0).sum(dim=-1, keepdim=True) >= required_sum * 0.99).all()), \ - f"mask sum {torch.where(mask,input_tensor,0).sum(dim=-1, keepdim=True)} < required_sum {required_sum}" - - try: - if causal: - assert (~mask[:, :, :, current_index + num_block_q :]).all() - except: - mask[:, :, :, current_index + num_block_q :] = False - - if causal: - if decoding: - assert mask[:, :, :, 0].all() and mask[:, :, :, -1].all() - else: - lambda_mask = torch.zeros_like(input_tensor,dtype=bool,device=input_tensor.device) - lambda_mask[:,:,:,0] = 1 - lambda_mask[:,:,:,current_index:current_index+num_block_q] = torch.eye(num_block_q, device=lambda_mask.device).unsqueeze(0).unsqueeze(0).expand(1,head_num,num_block_q,num_block_q) - assert(torch.where(lambda_mask,mask,True).all()) - - return mask - - -def xattn_estimate( - query_states: torch.Tensor, # (batch_size, num_q_head, q_len, head_dim) - key_states: torch.Tensor, # (batch_size, num_kv_head, k_len, head_dim) - block_size, - stride, - norm=1, - softmax=True, - threshold=0.9, - chunk_size=16384, - select_mode="inverse", - use_triton=True, - causal=True, - kdb: int = 1, - keep_sink=False, - keep_recent=False, -) -> torch.Tensor: - batch_size, num_kv_head, k_len, head_dim = key_states.shape - batch_size, num_q_head, q_len, head_dim = query_states.shape - if num_q_head > num_kv_head: - key_states = torch.repeat_interleave(key_states.contiguous(), num_q_head // num_kv_head, dim=1) - - assert q_len % chunk_size == 0 - assert k_len % chunk_size == 0 - - q_chunk_num = q_len // chunk_size - q_block_num = q_len // block_size - - # assert num_kv_head == num_q_head - attn_sum_list = [] - simple_mask_list = [] - - if use_triton and ( - "100" not in torch.cuda.get_device_properties(torch.cuda.current_device()).name - ): - use_triton = False - print( - "setting use triton to false. Triton kernel not surpported on this device" - ) - - num_strides_in_k = k_len // stride - - num_strides_per_chunk = chunk_size // stride - num_strides_per_block = block_size // stride - num_blocks_per_chunk = num_strides_per_chunk // num_strides_per_block - - for chunk_idx in range(q_chunk_num): - if kdb != 1: - raise ValueError("use_triton and kdb cannot be used together") - - q_chunk_start = chunk_idx * num_strides_per_chunk * stride - q_chunk_end = (chunk_idx + 1) * num_strides_per_chunk * stride - - q_chunk_start_stride = chunk_idx * num_strides_per_chunk - q_chunk_end_stride = (chunk_idx + 1) * num_strides_per_chunk - - # attn_weights_slice: (batch_size, num_heads, chunk_size // stride, kv_len // stride) - # (i.e. 
the attention sum of each SxS stride block) - # This step is agnostic to block size and just computes the attention sum in each stride block - attn_weights_slice = flat_group_gemm_fuse_reshape( - # query_states, key_states, stride, chunk_start, chunk_end, is_causal=True - query_states[:, :, q_chunk_start : q_chunk_end, :,], - key_states, - stride, - q_chunk_start_stride, - q_chunk_end_stride, - is_causal=causal, - ) - - # (batch_size, num_heads, q_block_num, k_block_num), - attn_sum = softmax_fuse_block_sum( - attn_weights_slice, # (batch_size, num_heads, chunk_size // stride, kv_len // stride) - num_strides_per_block, - min(4096, num_strides_per_block), - q_chunk_start_stride, q_chunk_end_stride, - num_strides_in_k, - 1 / LN2 / math.sqrt(head_dim) / stride / norm, - is_causal=causal, - ) - - - # (batch_size, head_num, num_blocks_per_chunk, block_num) - simple_mask = find_blocks_chunked( - attn_sum, - chunk_idx * num_blocks_per_chunk, - threshold, - None, - decoding=False, - mode="prefill", - causal=causal, - ) - - attn_sum_list.append(attn_sum) - simple_mask_list.append(simple_mask) - - del attn_weights_slice - - attn_sums = torch.cat(attn_sum_list, dim=-2) - - # (batch_size, head_num, num_blocks_per_chunk * q_chunk_num, block_num) - # i.e. (batch_size, head_num, q_block_num, q_block_num) - simple_masks = torch.cat(simple_mask_list, dim=-2) - - if causal: - simple_masks[:, :, -q_block_num:, -q_block_num:] = torch.where( - torch.tril( - torch.ones( - q_block_num, q_block_num, dtype=bool, device=key_states.device - ), - diagonal=0, - ), - simple_masks[:, :, -q_block_num:, -q_block_num:], - False, - ) - # print(f"{__name__} | simple_masks[:, :, -q_block_num:, -q_block_num:].shape {simple_masks[:, :, -q_block_num:, -q_block_num:].shape} after torch.where") - - - if keep_sink: - simple_masks[:, :, 0, :] = True - if keep_recent: - eye_matrix = torch.eye(q_block_num, device=simple_masks.device, dtype=bool) - eye_matrix_expanded = ( - eye_matrix.unsqueeze(0) - .unsqueeze(0) - .expand(1, num_kv_head, q_block_num, q_block_num) - ) - simple_masks[:, :, -q_block_num:, -q_block_num:] = torch.where( - eye_matrix_expanded, True, simple_masks[:, :, -q_block_num:, -q_block_num:] - ) - - # simple_masks -> (batch_size, head_num, q_block_num, q_block_num) - return attn_sums, simple_masks - -def check_device(use_triton: bool): - avail = use_triton and ( - "100" not in torch.cuda.get_device_properties(torch.cuda.current_device()).name - ) - if not avail: - print("Setting use triton to false. 
Triton kernel not surpported on this device") - return avail - - - -def xattn_zigzag_estimate( - query_states: torch.Tensor, # (batch_size, num_q_head, q_len, head_dim) - key_states: torch.Tensor, # (batch_size, num_kv_head, k_len, head_dim) - block_size, - stride, - norm=1, - softmax=True, - threshold=0.9, - select_mode="inverse", - use_triton=True, - causal=True, - kdb: int = 1, - keep_sink=False, - keep_recent=False, - group: dist.group = None, -) -> torch.Tensor: - batch_size, num_kv_head, k_len_local, head_dim = key_states.shape - batch_size, num_q_head, q_len_local, head_dim = query_states.shape - - rank = dist.get_rank(group) - world_size = dist.get_world_size(group) - - k_gather_list = [torch.empty_like(key_states) for _ in range(world_size)] - dist.all_gather(k_gather_list, key_states.contiguous(), group=group) - k_gathered = torch.cat(k_gather_list, dim=2) - k_len = k_gathered.shape[2] - - if num_q_head > num_kv_head: - k_gathered = torch.repeat_interleave(k_gathered.contiguous(), num_q_head // num_kv_head, dim=1) - - chunk_size = q_len_local // 2 - q_chunk_num = 2 - q_block_num = q_len_local // block_size - q_block_num_per_chunk = chunk_size // block_size - - # assert num_kv_head == num_q_head - attn_sum_list = [] - simple_mask_list = [] - - num_strides_in_k = k_len // stride - num_strides_per_chunk = chunk_size // stride - num_strides_per_block = block_size // stride - num_blocks_per_chunk = num_strides_per_chunk // num_strides_per_block - - attn_weight_slices = [None, None] - for chunk_idx in range(q_chunk_num): - global_chunk_idx = rank * 2 + chunk_idx - - # Local start index - q_chunk_start = chunk_idx * chunk_size - q_chunk_end = (chunk_idx + 1) * chunk_size - - # Global start index (stride-level) - q_chunk_start_stride_global = global_chunk_idx * num_strides_per_chunk - q_chunk_end_stride_global = (global_chunk_idx + 1) * num_strides_per_chunk - - # attn_weights_slice: (batch_size, num_heads, chunk_size // stride, kv_len // stride) - # (i.e. 
the attention sum of each SxS stride block) - # This step is agnostic to block size and just computes the attention sum in each stride block - attn_weight_slice = flat_group_gemm_fuse_reshape( - # query_states, key_states, stride, chunk_start, chunk_end, is_causal=True - query_states[:, :, q_chunk_start : q_chunk_end, :,], - k_gathered, - stride, - q_chunk_start_stride_global, q_chunk_end_stride_global, - is_causal=causal, - ) - attn_weight_slices[chunk_idx] = attn_weight_slice - del k_gathered, k_gather_list - - for chunk_idx in range(q_chunk_num): - global_chunk_idx = rank * 2 + chunk_idx - - # Local start index - q_chunk_start = chunk_idx * chunk_size - q_chunk_end = (chunk_idx + 1) * chunk_size - - # Global start index (block-level) - q_block_start = global_chunk_idx * q_block_num_per_chunk - q_block_end = (global_chunk_idx + 1) * q_block_num_per_chunk - - # Global start index (stride-level) - q_chunk_start_stride_global = global_chunk_idx * num_strides_per_chunk - q_chunk_end_stride_global = (global_chunk_idx + 1) * num_strides_per_chunk - - attn_weight_slice = attn_weight_slices[chunk_idx] - - # (batch_size, num_heads, q_block_num, k_block_num), - attn_sum = softmax_fuse_block_sum( - attn_weight_slice, # (batch_size, num_heads, chunk_size // stride, kv_len // stride) - num_strides_per_block, - min(4096, num_strides_per_block), - q_chunk_start_stride_global, q_chunk_end_stride_global, - num_strides_in_k, - 1 / LN2 / math.sqrt(head_dim) / stride / norm, - is_causal=causal, - ) - - # (batch_size, head_num, num_blocks_per_chunk, block_num) - simple_mask = find_blocks_chunked( - attn_sum, - global_chunk_idx * num_blocks_per_chunk, - threshold, - None, - decoding=False, - mode="prefill", - causal=causal, - ) - - del attn_weight_slice - if causal: - simple_mask[:, :, :, q_block_start:q_block_end] = torch.where( - torch.tril( - torch.ones( - q_block_num_per_chunk, q_block_num_per_chunk, - dtype=bool, device=key_states.device - ), - diagonal=0, - ), - simple_mask[:, :, :, q_block_start:q_block_end], - False, - ) - simple_mask[:, :, :, q_block_end:] = 0 - if keep_sink: - simple_mask[:, :, 0, :] = True - if keep_recent: - eye_matrix = torch.eye(q_block_num_per_chunk, device=simple_mask.device, dtype=bool) - eye_matrix_expanded = ( - eye_matrix.unsqueeze(0) - .unsqueeze(0) - .expand(1, num_kv_head, q_block_num_per_chunk, q_block_num_per_chunk) - ) - simple_mask[:, :, :, q_block_start:q_block_end] = torch.where( - eye_matrix_expanded, True, simple_mask[:, :, :, q_block_start:q_block_end] - ) - - attn_sum_list.append(attn_sum) - simple_mask_list.append(simple_mask) - - attn_sums = torch.cat(attn_sum_list, dim=-2) - simple_masks = torch.cat(simple_mask_list, dim=-2) # (batch_size, head_num, q_local_block_num, k_global_block_num) - return attn_sums, simple_masks - - -def shuffle_zigzag_masks( - block_masks: torch.Tensor, # [batch_size, num_qo_heads, num_blocks_local, num_blocks] - process_group: dist.ProcessGroup = None - ): - dim = len(block_masks.shape) - 1 - if not block_masks.is_contiguous(): - block_masks = block_masks.contiguous() - - # We must use outplace, otherwise it will raise error at backward due to inplace operations. - # We can not change to_send directly and create a new tensor to store the result. 
- to_send_f = torch.zeros_like(block_masks) - - # assume the input sequence length is 8, and computation runs on 4 GPUs - # the seq is represented as [0 1 2 3 4 5 6 7], world size is 4 - # the input status before `shuffle_zigzag_input` is - # - gpu A: [0 1] - # - gpu B: [2 3] - # - gpu C: [4 5] - # - gpu D: [6 7] - # the value of `to_send_slice` is - # - gpu A: [1] - # - gpu B: [3] - # - gpu C: [5] - # - gpu D: [7] - block_seq_len = block_masks.shape[dim] // 2 - left_slicer = [slice(None)] * dim + [slice(None, block_seq_len)] - right_slicer = [slice(None)] * dim + [slice(block_seq_len, None)] - to_send_slice = block_masks[right_slicer].contiguous() - - rank = dist.get_rank(process_group) - world_size = dist.get_world_size(process_group) - - res = torch.zeros_like(to_send_slice) - - _ops = [] - offset = ((dist.get_rank() // world_size) * world_size) - # rank src_rank - # 0 3 - # 1 2 - # 2 1 - # 3 0 - src_rank = (world_size - rank - 1) % world_size + offset - send_op = dist.P2POp( - dist.isend, to_send_slice, src_rank, group=process_group - ) - recv_op = dist.P2POp( - dist.irecv, res, src_rank, group=process_group) - - _ops.append(send_op) - _ops.append(recv_op) - - response = dist.batch_isend_irecv(_ops) - for resp in response: - resp.wait() - - if rank >= world_size // 2: # D: 6 7, -> 1 6 - to_send_f[right_slicer] = block_masks[left_slicer] - to_send_f[left_slicer] = res - else: # A: 0 1, -> 0 7 - to_send_f[left_slicer] = block_masks[left_slicer] - to_send_f[right_slicer] = res - # after shuffle, the status of `to_send_f` - # GPU A: [0 7] - # GPU B: [2 5] - # GPU C: [3 4] - # GPU D: [1 6] - - return to_send_f diff --git a/minference/dist_ops/xattn_zigzag.py b/minference/dist_ops/xattn_zigzag.py index b92d011..ed03719 100644 --- a/minference/dist_ops/xattn_zigzag.py +++ b/minference/dist_ops/xattn_zigzag.py @@ -10,9 +10,9 @@ shuffle_zigzag_input, recover_zigzag_output, shuffle_block_mask_zigzag, ) -from .op_utils.xattn_utils import LN2, find_blocks_chunked from minference.ops.utils import convert_blockmask +from minference.ops.op_utils.xattn_utils import LN2, find_blocks_chunked from minference.ops.minference_attn import block_attn_fwd, block_attn_bwd from minference.ops.minference_attn_triton import triton_block_attn_fwd, triton_block_attn_bwd from minference.ops.xattention_fa import flat_group_gemm_fuse_reshape, softmax_fuse_block_sum diff --git a/minference/dist_ops/op_utils/__init__.py b/minference/ops/op_utils/__init__.py similarity index 100% rename from minference/dist_ops/op_utils/__init__.py rename to minference/ops/op_utils/__init__.py diff --git a/minference/dist_ops/op_utils/moba_utils.py b/minference/ops/op_utils/moba_utils.py similarity index 98% rename from minference/dist_ops/op_utils/moba_utils.py rename to minference/ops/op_utils/moba_utils.py index 8c981f1..8cf91d1 100644 --- a/minference/dist_ops/op_utils/moba_utils.py +++ b/minference/ops/op_utils/moba_utils.py @@ -2,19 +2,11 @@ # Licensed under the MIT License. 
# Credits: This logger implementation is inspired by project https://github.com/zhuzilin/ring-flash-attention -import os -import math import torch -import inspect -import operator -import contextlib -import pandas as pd -import torch.nn.functional as F import torch.distributed as dist -from dataclasses import dataclass -from functools import reduce, cache, lru_cache -from typing import Optional, Tuple, List, Dict +from functools import lru_cache +from dataclasses import dataclass @dataclass class MoBAConfig: diff --git a/minference/ops/op_utils/xattn_utils.py b/minference/ops/op_utils/xattn_utils.py new file mode 100644 index 0000000..f80d3e7 --- /dev/null +++ b/minference/ops/op_utils/xattn_utils.py @@ -0,0 +1,592 @@ +import math +import torch +import triton +import triton.language as tl +import torch.nn.functional as F +import torch.distributed as dist + +LN2 = 1 / 1.4426950408889634 +def create_causal_mask(batch_size, head_num, block_size, block_num, divide_block_num): + """ + Creates a causal attention mask used in transformer-based models. + + Parameters: + - batch_size (int): The number of sequences in the batch. + - head_num (int): The number of attention heads. + - block_size (int): The size of each block in the sequence. + - block_num (int): The total number of blocks in the sequence. + - divide_block_num (int): The block index at which causality is applied. + + Returns: + - torch.Tensor: A mask tensor of shape (batch_size, head_num, block_size, total_size) + where total_size = block_size * block_num. The mask enforces causal attention by + setting certain positions to `-inf` to prevent information leakage from future tokens. + """ + divide_block_num += 1 + if divide_block_num < 1 or divide_block_num > block_num: + raise ValueError( + f"divide_block_num ({divide_block_num}) must be between 1 and block_num ({block_num})." + ) + + total_size = block_size * block_num + device = "cuda" + mask = torch.zeros(block_size, total_size, device=device) + if divide_block_num < block_num: + mask[:, divide_block_num * block_size :] = float("-inf") + + if divide_block_num - 1 < block_num: + start_col = (divide_block_num - 1) * block_size + end_col = start_col + block_size + upper_tri_mask = torch.triu( + torch.full((block_size, block_size), float("-inf"), device=device), + diagonal=1, + ) + mask[:, start_col:end_col] = upper_tri_mask + + mask = mask.unsqueeze(0).unsqueeze(0) + mask = mask.expand(batch_size, head_num, block_size, total_size) + return mask + +def find_blocks_chunked( + input_tensor: torch.Tensor, # (batch_size, num_heads, num_block_q, num_block_k) + current_index, # + threshold, num_to_choose, decoding: bool, mode: str = "both", causal=True +): + """ + Finds and selects relevant blocks of attention for transformer-based models based on a + threshold or a predefined number of blocks. + + Parameters: + - input_tensor (torch.Tensor): The input tensor of shape (batch_size, head_num, num_block_q, num_block_k). + - current_index (int): The current index in the sequence processing. + - threshold (float or None): A threshold value used to determine the minimum attention weight sum. + - num_to_choose (int or None): The number of blocks to be selected, ensuring sufficient information retrieval. + - decoding (bool): If True, operates in decoding mode; otherwise, it's in encoding mode. + - mode (str): Defines the processing mode, either 'both', 'prefill', or 'decode'. + - causal (bool): If True, applies causal masking to prevent future information leakage. 
+ + Returns: + - torch.Tensor: A boolean mask of shape (batch_size, head_num, num_block_q, num_block_k), + indicating which blocks should be attended to. + """ + assert threshold is None or num_to_choose is None + batch_size, head_num, num_block_q, num_block_k = input_tensor.shape + input_tensor = input_tensor.to(float) + + total_sum = input_tensor.sum(dim=-1, keepdim=True) + if isinstance(threshold, torch.Tensor): + threshold = threshold.to(float) + required_sum = total_sum * threshold.unsqueeze(0).unsqueeze(-1).unsqueeze( + -1 + ).expand((batch_size, head_num, num_block_q, 1)).to(input_tensor.device) + else: + required_sum = total_sum * threshold + + + mask = torch.zeros_like(input_tensor, dtype=torch.bool) + mask[:, :, :, 0] = 1 + mask[:, :, :, current_index : current_index + num_block_q] = ( + torch.eye(num_block_q, device=mask.device) + .unsqueeze(0) + .unsqueeze(0) + .expand(1, head_num, num_block_q, num_block_q) + ) + # Note that other_values only contains the values of the current block + # (the sink blocks and diagonal are filled with 0) + other_values = input_tensor.masked_fill(mask, 0) + + + # Get sorted values + sorted_values, _ = torch.sort(other_values, dim=-1, descending=True) + sorted_values = sorted_values.to(input_tensor.device) + sorted_values = torch.cat( + [ + torch.zeros( + (batch_size, head_num, num_block_q, 1), device=input_tensor.device + ), + torch.where(mask, input_tensor, 0).sum(dim=-1, keepdim=True), # shape: (batch_size, head_num, num_block_q, 1) + sorted_values[:, :, :, :-2], # :-2 excludes the first and diagonal (which are marked 0 in other_values) + ], + dim=-1, + ) + + # Get sorted indices + # index will select the already-masked (sink and diagonal) at the beginning + _, index = torch.sort( + torch.where(mask, 100000 * (1 + input_tensor), input_tensor), + dim=-1, + descending=True, + ) + + # [batch_size, head_num, num_block_q, num_block_k] + cumulative_sum_without_self = torch.cat( + [ + torch.zeros( + (batch_size, head_num, num_block_q, 1), device=input_tensor.device + ), + sorted_values[:, :, :, 0:-1], + ], + dim=-1, + ).cumsum(dim=-1) + + # Mask for indices where cumulative sum is below the required threshold. 
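+ # Simplified illustration (ignoring the sink/diagonal columns that are force-selected
+ # above): if a row of sorted block sums is [0.5, 0.3, 0.1, 0.1] and threshold = 0.9, then
+ # required_sum = 0.9 and cumulative_sum_without_self = [0.0, 0.5, 0.8, 0.9], so only the
+ # first three ranked blocks satisfy `< required_sum`; a block is kept as long as the
+ # blocks ranked above it have not yet accumulated the required attention mass.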
+ index_mask = cumulative_sum_without_self < required_sum + index = torch.where(index_mask, index, 0) + + mask = mask.view(batch_size, head_num * num_block_q, num_block_k) + index = index.view(batch_size, head_num * num_block_q, num_block_k) + mask[:, torch.arange(mask.shape[1], device=mask.device).unsqueeze(dim=-1), index] = True + mask = mask.view(batch_size, head_num, num_block_q, num_block_k) + + + assert bool((torch.where(mask,input_tensor,0).sum(dim=-1, keepdim=True) >= required_sum * 0.99).all()), \ + f"mask sum {torch.where(mask,input_tensor,0).sum(dim=-1, keepdim=True)} < required_sum {required_sum}" + + try: + if causal: + assert (~mask[:, :, :, current_index + num_block_q :]).all() + except: + mask[:, :, :, current_index + num_block_q :] = False + + if causal: + if decoding: + assert mask[:, :, :, 0].all() and mask[:, :, :, -1].all() + else: + lambda_mask = torch.zeros_like(input_tensor,dtype=bool,device=input_tensor.device) + lambda_mask[:,:,:,0] = 1 + lambda_mask[:,:,:,current_index:current_index+num_block_q] = torch.eye(num_block_q, device=lambda_mask.device).unsqueeze(0).unsqueeze(0).expand(1,head_num,num_block_q,num_block_q) + assert(torch.where(lambda_mask,mask,True).all()) + + return mask + + +def shuffle_zigzag_masks( + block_masks: torch.Tensor, # [batch_size, num_qo_heads, num_blocks_local, num_blocks] + process_group: dist.ProcessGroup = None + ): + dim = len(block_masks.shape) - 1 + if not block_masks.is_contiguous(): + block_masks = block_masks.contiguous() + + # We must use outplace, otherwise it will raise error at backward due to inplace operations. + # We can not change to_send directly and create a new tensor to store the result. + to_send_f = torch.zeros_like(block_masks) + + # assume the input sequence length is 8, and computation runs on 4 GPUs + # the seq is represented as [0 1 2 3 4 5 6 7], world size is 4 + # the input status before `shuffle_zigzag_input` is + # - gpu A: [0 1] + # - gpu B: [2 3] + # - gpu C: [4 5] + # - gpu D: [6 7] + # the value of `to_send_slice` is + # - gpu A: [1] + # - gpu B: [3] + # - gpu C: [5] + # - gpu D: [7] + block_seq_len = block_masks.shape[dim] // 2 + left_slicer = [slice(None)] * dim + [slice(None, block_seq_len)] + right_slicer = [slice(None)] * dim + [slice(block_seq_len, None)] + to_send_slice = block_masks[right_slicer].contiguous() + + rank = dist.get_rank(process_group) + world_size = dist.get_world_size(process_group) + + res = torch.zeros_like(to_send_slice) + + _ops = [] + offset = ((dist.get_rank() // world_size) * world_size) + # rank src_rank + # 0 3 + # 1 2 + # 2 1 + # 3 0 + src_rank = (world_size - rank - 1) % world_size + offset + send_op = dist.P2POp( + dist.isend, to_send_slice, src_rank, group=process_group + ) + recv_op = dist.P2POp( + dist.irecv, res, src_rank, group=process_group) + + _ops.append(send_op) + _ops.append(recv_op) + + response = dist.batch_isend_irecv(_ops) + for resp in response: + resp.wait() + + if rank >= world_size // 2: # D: 6 7, -> 1 6 + to_send_f[right_slicer] = block_masks[left_slicer] + to_send_f[left_slicer] = res + else: # A: 0 1, -> 0 7 + to_send_f[left_slicer] = block_masks[left_slicer] + to_send_f[right_slicer] = res + # after shuffle, the status of `to_send_f` + # GPU A: [0 7] + # GPU B: [2 5] + # GPU C: [3 4] + # GPU D: [1 6] + + return to_send_f + + + +@triton.jit +def softmax_fuse_block_sum_kernel_causal( + In, + Out, + scale, + input_stride_0, + input_stride_1, + input_stride_2, + output_stride_0, + output_stride_1, + output_stride_2, + real_q_len, + k_len, # we assume 
k_len is divisible by chunk size + chunk_start, + chunk_end, + segment_size: tl.constexpr, + block_size: tl.constexpr, +): + block_id = tl.program_id(0) + head_id = tl.program_id(1) + batch_id = tl.program_id(2) + + offs_q = tl.arange(0, block_size) + chunk_start + block_id * block_size + offs_k = tl.arange(0, segment_size) + + num_iters = k_len // segment_size + num_iters_before_causal = (chunk_start + (block_id + 1) * block_size - 1) // segment_size + + m_i = tl.zeros([block_size], dtype=tl.float32) - float("inf") + l_i = tl.zeros([block_size], dtype=tl.float32) + 1.0 + + input_ptr = In + batch_id * input_stride_0 + head_id * input_stride_1 + block_id * block_size * input_stride_2 + input_ptr = input_ptr + tl.arange(0, segment_size) + tl.arange(0, block_size)[:, None] * input_stride_2 + + output_ptr = Out + batch_id * output_stride_0 + head_id * output_stride_1 + block_id * output_stride_2 + output_ptr = output_ptr + tl.arange(0, segment_size // block_size) + + for iter in range(0, num_iters_before_causal): + X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale + m_local = tl.max(X, 1) + m_new = tl.maximum(m_i, m_local) + alpha = tl.math.exp2(m_i - m_new) + + X = X - m_new[:, None] + l_local = tl.sum(tl.math.exp2(X), 1) + l_i = l_i * alpha + l_local + + m_i = m_new + + for iter in range(num_iters_before_causal, num_iters_before_causal + 1): + X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale + mask = offs_q[:, None] >= (offs_k[None, :] + iter * segment_size) + X = tl.where(mask, X, -1.0e6) + m_local = tl.max(X, 1) + m_new = tl.maximum(m_i, m_local) + alpha = tl.math.exp2(m_i - m_new) + + X = X - m_new[:, None] + l_local = tl.sum(tl.math.exp2(X), 1) + l_i = l_i * alpha + l_local + + m_i = m_new + + l_i_inv = 1.0 / l_i + + sum_mask = offs_q[:, None] < real_q_len + + for iter in range(0, num_iters_before_causal): + X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale + X = tl.exp2(X - m_i[:, None]) * l_i_inv[:, None] + X = tl.where(sum_mask, X, 0) + X = tl.reshape(X, (block_size, segment_size // block_size, block_size)) + X = tl.sum(X, 2) + X = tl.sum(X, 0) + tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty)) + + for iter in range(num_iters_before_causal, num_iters_before_causal + 1): + X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale + mask = offs_q[:, None] >= (offs_k[None, :] + iter * segment_size) + X = tl.where(mask, X, -1.0e6) + X = tl.exp2(X - m_i[:, None]) * l_i_inv[:, None] + X = tl.where(sum_mask, X, 0) + X = tl.reshape(X, (block_size, segment_size // block_size, block_size)) + X = tl.sum(X, 2) + X = tl.sum(X, 0) + tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty)) + + for iter in range(num_iters_before_causal + 1, num_iters): + X = tl.zeros([segment_size // block_size], dtype=tl.float32) + tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty)) + + +@triton.jit +def softmax_fuse_block_sum_kernel_non_causal( + In, + Out, + scale, + input_stride_0, + input_stride_1, + input_stride_2, + output_stride_0, + output_stride_1, + output_stride_2, + real_q_len, + k_len, # we assume k_len is divisible by chunk size + chunk_start, + chunk_end, + segment_size: tl.constexpr, + block_size: tl.constexpr, +): + block_id = tl.program_id(0) + head_id = tl.program_id(1) + batch_id = tl.program_id(2) + + offs_q = tl.arange(0, block_size) + chunk_start + block_id * block_size + offs_k = tl.arange(0, segment_size) + + num_iters = k_len // 
segment_size + + m_i = tl.zeros([block_size], dtype=tl.float32) - float("inf") + l_i = tl.zeros([block_size], dtype=tl.float32) + 1.0 + + input_ptr = In + batch_id * input_stride_0 + head_id * input_stride_1 + block_id * block_size * input_stride_2 + input_ptr = input_ptr + tl.arange(0, segment_size) + tl.arange(0, block_size)[:, None] * input_stride_2 + + output_ptr = Out + batch_id * output_stride_0 + head_id * output_stride_1 + block_id * output_stride_2 + output_ptr = output_ptr + tl.arange(0, segment_size // block_size) + + for iter in range(0, num_iters): + X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale + m_local = tl.max(X, 1) + m_new = tl.maximum(m_i, m_local) + alpha = tl.math.exp2(m_i - m_new) + + X = X - m_new[:, None] + l_local = tl.sum(tl.math.exp2(X), 1) + l_i = l_i * alpha + l_local + + m_i = m_new + + l_i_inv = 1.0 / l_i + + sum_mask = offs_q[:, None] < real_q_len + + for iter in range(0, num_iters): + X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale + X = tl.exp2(X - m_i[:, None]) * l_i_inv[:, None] + X = tl.where(sum_mask, X, 0) + X = tl.reshape(X, (block_size, segment_size // block_size, block_size)) + X = tl.sum(X, 2) + X = tl.sum(X, 0) + tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty)) + +@triton.jit +def flat_group_gemm_kernel(Q, K, Out, + stride_qz, stride_qh, stride_qn, + stride_kz, stride_kh, stride_kn, + stride_oz, stride_oh, stride_on, + chunk_start, chunk_end, + H: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + ): + block_m = tl.program_id(0).to(tl.int64) + block_n = tl.program_id(1).to(tl.int64) + batch_id = tl.program_id(2).to(tl.int64) // H + head_id = tl.program_id(2).to(tl.int64) % H + + if chunk_start + (block_m + 1) * BLOCK_M <= block_n * BLOCK_N: + return + + Q_ptrs = Q + batch_id * stride_qz + head_id * stride_qh + block_m * BLOCK_M * stride_qn + K_ptrs = K + batch_id * stride_kz + head_id * stride_kh + block_n * BLOCK_N * stride_kn + + Q_ptrs = Q_ptrs + tl.arange(0, BLOCK_M)[:, None] * stride_qn + tl.arange(0, BLOCK_K)[None, :] + K_ptrs = K_ptrs + tl.arange(0, BLOCK_N)[None, :] * stride_kn + tl.arange(0, BLOCK_K)[:, None] + + num_iters = HEAD_DIM // BLOCK_K + o = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + + for iter in range(num_iters): + q = tl.load(Q_ptrs + iter * BLOCK_K) + k = tl.load(K_ptrs + iter * BLOCK_K) + o += tl.dot(q, k) + + O_ptrs = Out + batch_id * stride_oz + head_id * stride_oh + block_m * BLOCK_M * stride_on + block_n * BLOCK_N + O_ptrs = O_ptrs + tl.arange(0, BLOCK_M)[:, None] * stride_on + tl.arange(0, BLOCK_N)[None, :] + + tl.store(O_ptrs, o.to(Out.type.element_ty)) + +@triton.jit +def flat_group_gemm_fuse_reshape_kernel(Q, K, Out, + stride_qz, stride_qh, stride_qn, + stride_kz, stride_kh, stride_kn, + stride_oz, stride_oh, stride_on, + chunk_start, chunk_end, + H: tl.constexpr, + STRIDE: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + is_caual: tl.constexpr, + ): + block_m = tl.program_id(0).to(tl.int64) + block_n = tl.program_id(1).to(tl.int64) + batch_id = tl.program_id(2).to(tl.int64) // H + head_id = tl.program_id(2).to(tl.int64) % H + + if is_caual: + if chunk_start + (block_m + 1) * BLOCK_M <= block_n * BLOCK_N: + return + + Q_ptrs = Q + batch_id * stride_qz + head_id * stride_qh + block_m * BLOCK_M * STRIDE * stride_qn + K_ptrs = K + batch_id * stride_kz + head_id * stride_kh + block_n * BLOCK_N * STRIDE * stride_kn + + Q_ptrs = 
Q_ptrs + tl.arange(0, BLOCK_M)[:, None] * (stride_qn * STRIDE) + tl.arange(0, HEAD_DIM)[None, :] + stride_qn * (STRIDE - 1) + K_ptrs = K_ptrs + tl.arange(0, BLOCK_N)[None, :] * (stride_kn * STRIDE) + tl.arange(0, HEAD_DIM)[:, None] + + o = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + + for iter in range(STRIDE): + q = tl.load(Q_ptrs - iter * stride_qn) + k = tl.load(K_ptrs + iter * stride_kn) + o += tl.dot(q, k) + + O_ptrs = Out + batch_id * stride_oz + head_id * stride_oh + block_m * BLOCK_M * stride_on + block_n * BLOCK_N + O_ptrs = O_ptrs + tl.arange(0, BLOCK_M)[:, None] * stride_on + tl.arange(0, BLOCK_N)[None, :] + + tl.store(O_ptrs, o.to(Out.type.element_ty)) + + +def softmax_fuse_block_sum(attn_weights_slice, reshaped_block_size, segment_size, chunk_start, chunk_end, real_q_len, scale, is_causal=True): + batch_size, num_heads, q_len, k_len = attn_weights_slice.shape + assert q_len % reshaped_block_size == 0 + try: + assert k_len % segment_size == 0 + except: + assert False, f"xAttention error, k_len: {k_len}, segment size: {segment_size}" + assert segment_size % reshaped_block_size == 0 + assert attn_weights_slice.stride(-1) == 1 + + output = torch.empty((batch_size, num_heads, q_len // reshaped_block_size, k_len // reshaped_block_size), dtype=attn_weights_slice.dtype, device=attn_weights_slice.device) + + grid = (q_len // reshaped_block_size, num_heads, batch_size) + + if is_causal: + softmax_fuse_block_sum_kernel_causal[grid]( + attn_weights_slice, + output, + scale, + attn_weights_slice.stride(0), + attn_weights_slice.stride(1), + attn_weights_slice.stride(2), + output.stride(0), + output.stride(1), + output.stride(2), + real_q_len, + k_len, + chunk_start, + chunk_end, + segment_size, + reshaped_block_size, + ) + else: + softmax_fuse_block_sum_kernel_non_causal[grid]( + attn_weights_slice, + output, + scale, + attn_weights_slice.stride(0), + attn_weights_slice.stride(1), + attn_weights_slice.stride(2), + output.stride(0), + output.stride(1), + output.stride(2), + real_q_len, + k_len, + chunk_start, + chunk_end, + segment_size, + reshaped_block_size, + ) + + return output + +def flat_group_gemm(query_states, key_states, chunk_start, chunk_end): + batch_size, num_heads, q_len, head_dim = query_states.shape + kv_len = key_states.shape[2] + + output = torch.empty((batch_size, num_heads, q_len, kv_len), dtype=query_states.dtype, device=query_states.device) + BLOCK_M = 128 + BLOCK_N = 128 + BLOCK_K = 64 + + grid = (q_len // BLOCK_M, kv_len // BLOCK_N, batch_size * num_heads) + flat_group_gemm_kernel[grid]( + query_states, + key_states, + output, + query_states.stride(0), + query_states.stride(1), + query_states.stride(2), + key_states.stride(0), + key_states.stride(1), + key_states.stride(2), + output.stride(0), + output.stride(1), + output.stride(2), + chunk_start, + chunk_end, + num_heads, + head_dim, + BLOCK_M, + BLOCK_N, + BLOCK_K, + ) + + return output + +def flat_group_gemm_fuse_reshape(query_states, key_states, stride, chunk_start, chunk_end, is_causal=True): + batch_size, num_heads, q_len, head_dim = query_states.shape + kv_len = key_states.shape[2] + + assert (key_states.shape[0] == batch_size) + assert (key_states.shape[1] == num_heads) + assert (key_states.shape[3] == head_dim) + + output = torch.empty((batch_size, num_heads, q_len // stride, kv_len // stride), dtype=query_states.dtype, device=query_states.device) + BLOCK_M = 128 + BLOCK_N = 128 + assert (q_len % (stride * BLOCK_M) == 0) + assert (kv_len % (stride * BLOCK_N) == 0) + + grid = (q_len // stride // BLOCK_M, 
kv_len // stride // BLOCK_N, batch_size * num_heads) + flat_group_gemm_fuse_reshape_kernel[grid]( + query_states, + key_states, + output, + query_states.stride(0), + query_states.stride(1), + query_states.stride(2), + key_states.stride(0), + key_states.stride(1), + key_states.stride(2), + output.stride(0), + output.stride(1), + output.stride(2), + chunk_start, + chunk_end, + num_heads, + stride, + head_dim, + BLOCK_M, + BLOCK_N, + is_causal, + ) + + return output \ No newline at end of file diff --git a/minference/ops/xattention_fa.py b/minference/ops/xattention_fa.py index b24f829..d565adb 100644 --- a/minference/ops/xattention_fa.py +++ b/minference/ops/xattention_fa.py @@ -1,374 +1,146 @@ # Copyright (c) 2025 Microsoft # Licensed under The MIT License [see LICENSE for details] # Refer to the code in https://github.com/mit-han-lab/x-attention - +import math import torch -import triton -import triton.language as tl from typing import List, Tuple, Dict, Any -from minference.dist_ops.op_utils.xattn_utils import xattn_estimate from minference.ops.minference_attn import block_attn_fwd, block_attn_bwd +from .op_utils.xattn_utils import ( + LN2, find_blocks_chunked, flat_group_gemm_fuse_reshape, softmax_fuse_block_sum +) + +def xattn_estimate( + query_states: torch.Tensor, # (batch_size, num_q_head, q_len, head_dim) + key_states: torch.Tensor, # (batch_size, num_kv_head, k_len, head_dim) + block_size, + stride, + norm=1, + softmax=True, + threshold=0.9, + chunk_size=16384, + select_mode="inverse", + use_triton=True, + causal=True, + kdb: int = 1, + keep_sink=False, + keep_recent=False, +) -> torch.Tensor: + batch_size, num_kv_head, k_len, head_dim = key_states.shape + batch_size, num_q_head, q_len, head_dim = query_states.shape + if num_q_head > num_kv_head: + key_states = torch.repeat_interleave(key_states.contiguous(), num_q_head // num_kv_head, dim=1) + + assert q_len % chunk_size == 0 + assert k_len % chunk_size == 0 + + q_chunk_num = q_len // chunk_size + q_block_num = q_len // block_size + + # assert num_kv_head == num_q_head + attn_sum_list = [] + simple_mask_list = [] + + if use_triton and ( + "100" not in torch.cuda.get_device_properties(torch.cuda.current_device()).name + ): + use_triton = False + print( + "setting use triton to false. 
Triton kernel not surpported on this device" + ) -@triton.jit -def softmax_fuse_block_sum_kernel_causal( - In, - Out, - scale, - input_stride_0, - input_stride_1, - input_stride_2, - output_stride_0, - output_stride_1, - output_stride_2, - real_q_len, - k_len, # we assume k_len is divisible by chunk size - chunk_start, - chunk_end, - segment_size: tl.constexpr, - block_size: tl.constexpr, -): - block_id = tl.program_id(0) - head_id = tl.program_id(1) - batch_id = tl.program_id(2) - - offs_q = tl.arange(0, block_size) + chunk_start + block_id * block_size - offs_k = tl.arange(0, segment_size) - - num_iters = k_len // segment_size - num_iters_before_causal = (chunk_start + (block_id + 1) * block_size - 1) // segment_size - - m_i = tl.zeros([block_size], dtype=tl.float32) - float("inf") - l_i = tl.zeros([block_size], dtype=tl.float32) + 1.0 - - input_ptr = In + batch_id * input_stride_0 + head_id * input_stride_1 + block_id * block_size * input_stride_2 - input_ptr = input_ptr + tl.arange(0, segment_size) + tl.arange(0, block_size)[:, None] * input_stride_2 - - output_ptr = Out + batch_id * output_stride_0 + head_id * output_stride_1 + block_id * output_stride_2 - output_ptr = output_ptr + tl.arange(0, segment_size // block_size) - - for iter in range(0, num_iters_before_causal): - X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale - m_local = tl.max(X, 1) - m_new = tl.maximum(m_i, m_local) - alpha = tl.math.exp2(m_i - m_new) - - X = X - m_new[:, None] - l_local = tl.sum(tl.math.exp2(X), 1) - l_i = l_i * alpha + l_local - - m_i = m_new - - for iter in range(num_iters_before_causal, num_iters_before_causal + 1): - X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale - mask = offs_q[:, None] >= (offs_k[None, :] + iter * segment_size) - X = tl.where(mask, X, -1.0e6) - m_local = tl.max(X, 1) - m_new = tl.maximum(m_i, m_local) - alpha = tl.math.exp2(m_i - m_new) - - X = X - m_new[:, None] - l_local = tl.sum(tl.math.exp2(X), 1) - l_i = l_i * alpha + l_local - - m_i = m_new - - l_i_inv = 1.0 / l_i - - sum_mask = offs_q[:, None] < real_q_len - - for iter in range(0, num_iters_before_causal): - X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale - X = tl.exp2(X - m_i[:, None]) * l_i_inv[:, None] - X = tl.where(sum_mask, X, 0) - X = tl.reshape(X, (block_size, segment_size // block_size, block_size)) - X = tl.sum(X, 2) - X = tl.sum(X, 0) - tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty)) - - for iter in range(num_iters_before_causal, num_iters_before_causal + 1): - X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale - mask = offs_q[:, None] >= (offs_k[None, :] + iter * segment_size) - X = tl.where(mask, X, -1.0e6) - X = tl.exp2(X - m_i[:, None]) * l_i_inv[:, None] - X = tl.where(sum_mask, X, 0) - X = tl.reshape(X, (block_size, segment_size // block_size, block_size)) - X = tl.sum(X, 2) - X = tl.sum(X, 0) - tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty)) - - for iter in range(num_iters_before_causal + 1, num_iters): - X = tl.zeros([segment_size // block_size], dtype=tl.float32) - tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty)) - - -@triton.jit -def softmax_fuse_block_sum_kernel_non_causal( - In, - Out, - scale, - input_stride_0, - input_stride_1, - input_stride_2, - output_stride_0, - output_stride_1, - output_stride_2, - real_q_len, - k_len, # we assume k_len is divisible by chunk size - chunk_start, - chunk_end, - segment_size: 
tl.constexpr, - block_size: tl.constexpr, -): - block_id = tl.program_id(0) - head_id = tl.program_id(1) - batch_id = tl.program_id(2) - - offs_q = tl.arange(0, block_size) + chunk_start + block_id * block_size - offs_k = tl.arange(0, segment_size) - - num_iters = k_len // segment_size - - m_i = tl.zeros([block_size], dtype=tl.float32) - float("inf") - l_i = tl.zeros([block_size], dtype=tl.float32) + 1.0 - - input_ptr = In + batch_id * input_stride_0 + head_id * input_stride_1 + block_id * block_size * input_stride_2 - input_ptr = input_ptr + tl.arange(0, segment_size) + tl.arange(0, block_size)[:, None] * input_stride_2 - - output_ptr = Out + batch_id * output_stride_0 + head_id * output_stride_1 + block_id * output_stride_2 - output_ptr = output_ptr + tl.arange(0, segment_size // block_size) - - for iter in range(0, num_iters): - X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale - m_local = tl.max(X, 1) - m_new = tl.maximum(m_i, m_local) - alpha = tl.math.exp2(m_i - m_new) - - X = X - m_new[:, None] - l_local = tl.sum(tl.math.exp2(X), 1) - l_i = l_i * alpha + l_local - - m_i = m_new - - l_i_inv = 1.0 / l_i - - sum_mask = offs_q[:, None] < real_q_len - - for iter in range(0, num_iters): - X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale - X = tl.exp2(X - m_i[:, None]) * l_i_inv[:, None] - X = tl.where(sum_mask, X, 0) - X = tl.reshape(X, (block_size, segment_size // block_size, block_size)) - X = tl.sum(X, 2) - X = tl.sum(X, 0) - tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty)) - -@triton.jit -def flat_group_gemm_kernel(Q, K, Out, - stride_qz, stride_qh, stride_qn, - stride_kz, stride_kh, stride_kn, - stride_oz, stride_oh, stride_on, - chunk_start, chunk_end, - H: tl.constexpr, - HEAD_DIM: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - ): - block_m = tl.program_id(0).to(tl.int64) - block_n = tl.program_id(1).to(tl.int64) - batch_id = tl.program_id(2).to(tl.int64) // H - head_id = tl.program_id(2).to(tl.int64) % H - - if chunk_start + (block_m + 1) * BLOCK_M <= block_n * BLOCK_N: - return - - Q_ptrs = Q + batch_id * stride_qz + head_id * stride_qh + block_m * BLOCK_M * stride_qn - K_ptrs = K + batch_id * stride_kz + head_id * stride_kh + block_n * BLOCK_N * stride_kn - - Q_ptrs = Q_ptrs + tl.arange(0, BLOCK_M)[:, None] * stride_qn + tl.arange(0, BLOCK_K)[None, :] - K_ptrs = K_ptrs + tl.arange(0, BLOCK_N)[None, :] * stride_kn + tl.arange(0, BLOCK_K)[:, None] - - num_iters = HEAD_DIM // BLOCK_K - o = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - - for iter in range(num_iters): - q = tl.load(Q_ptrs + iter * BLOCK_K) - k = tl.load(K_ptrs + iter * BLOCK_K) - o += tl.dot(q, k) - - O_ptrs = Out + batch_id * stride_oz + head_id * stride_oh + block_m * BLOCK_M * stride_on + block_n * BLOCK_N - O_ptrs = O_ptrs + tl.arange(0, BLOCK_M)[:, None] * stride_on + tl.arange(0, BLOCK_N)[None, :] - - tl.store(O_ptrs, o.to(Out.type.element_ty)) - -@triton.jit -def flat_group_gemm_fuse_reshape_kernel(Q, K, Out, - stride_qz, stride_qh, stride_qn, - stride_kz, stride_kh, stride_kn, - stride_oz, stride_oh, stride_on, - chunk_start, chunk_end, - H: tl.constexpr, - STRIDE: tl.constexpr, - HEAD_DIM: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - is_caual: tl.constexpr, - ): - block_m = tl.program_id(0).to(tl.int64) - block_n = tl.program_id(1).to(tl.int64) - batch_id = tl.program_id(2).to(tl.int64) // H - head_id = tl.program_id(2).to(tl.int64) % H - - if is_caual: - if chunk_start 
+ (block_m + 1) * BLOCK_M <= block_n * BLOCK_N: - return - - Q_ptrs = Q + batch_id * stride_qz + head_id * stride_qh + block_m * BLOCK_M * STRIDE * stride_qn - K_ptrs = K + batch_id * stride_kz + head_id * stride_kh + block_n * BLOCK_N * STRIDE * stride_kn - - Q_ptrs = Q_ptrs + tl.arange(0, BLOCK_M)[:, None] * (stride_qn * STRIDE) + tl.arange(0, HEAD_DIM)[None, :] + stride_qn * (STRIDE - 1) - K_ptrs = K_ptrs + tl.arange(0, BLOCK_N)[None, :] * (stride_kn * STRIDE) + tl.arange(0, HEAD_DIM)[:, None] - - o = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - - for iter in range(STRIDE): - q = tl.load(Q_ptrs - iter * stride_qn) - k = tl.load(K_ptrs + iter * stride_kn) - o += tl.dot(q, k) - - O_ptrs = Out + batch_id * stride_oz + head_id * stride_oh + block_m * BLOCK_M * stride_on + block_n * BLOCK_N - O_ptrs = O_ptrs + tl.arange(0, BLOCK_M)[:, None] * stride_on + tl.arange(0, BLOCK_N)[None, :] - - tl.store(O_ptrs, o.to(Out.type.element_ty)) - - -def softmax_fuse_block_sum(attn_weights_slice, reshaped_block_size, segment_size, chunk_start, chunk_end, real_q_len, scale, is_causal=True): - batch_size, num_heads, q_len, k_len = attn_weights_slice.shape - assert q_len % reshaped_block_size == 0 - try: - assert k_len % segment_size == 0 - except: - assert False, f"xAttention error, k_len: {k_len}, segment size: {segment_size}" - assert segment_size % reshaped_block_size == 0 - assert attn_weights_slice.stride(-1) == 1 - - output = torch.empty((batch_size, num_heads, q_len // reshaped_block_size, k_len // reshaped_block_size), dtype=attn_weights_slice.dtype, device=attn_weights_slice.device) - - grid = (q_len // reshaped_block_size, num_heads, batch_size) + num_strides_in_k = k_len // stride + + num_strides_per_chunk = chunk_size // stride + num_strides_per_block = block_size // stride + num_blocks_per_chunk = num_strides_per_chunk // num_strides_per_block + + for chunk_idx in range(q_chunk_num): + if kdb != 1: + raise ValueError("use_triton and kdb cannot be used together") + + q_chunk_start = chunk_idx * num_strides_per_chunk * stride + q_chunk_end = (chunk_idx + 1) * num_strides_per_chunk * stride + + q_chunk_start_stride = chunk_idx * num_strides_per_chunk + q_chunk_end_stride = (chunk_idx + 1) * num_strides_per_chunk + + # attn_weights_slice: (batch_size, num_heads, chunk_size // stride, kv_len // stride) + # (i.e. 
the attention sum of each SxS stride block) + # This step is agnostic to block size and just computes the attention sum in each stride block + attn_weights_slice = flat_group_gemm_fuse_reshape( + # query_states, key_states, stride, chunk_start, chunk_end, is_causal=True + query_states[:, :, q_chunk_start : q_chunk_end, :,], + key_states, + stride, + q_chunk_start_stride, + q_chunk_end_stride, + is_causal=causal, + ) - if is_causal: - softmax_fuse_block_sum_kernel_causal[grid]( - attn_weights_slice, - output, - scale, - attn_weights_slice.stride(0), - attn_weights_slice.stride(1), - attn_weights_slice.stride(2), - output.stride(0), - output.stride(1), - output.stride(2), - real_q_len, - k_len, - chunk_start, - chunk_end, - segment_size, - reshaped_block_size, + # (batch_size, num_heads, q_block_num, k_block_num), + attn_sum = softmax_fuse_block_sum( + attn_weights_slice, # (batch_size, num_heads, chunk_size // stride, kv_len // stride) + num_strides_per_block, + min(4096, num_strides_per_block), + q_chunk_start_stride, q_chunk_end_stride, + num_strides_in_k, + 1 / LN2 / math.sqrt(head_dim) / stride / norm, + is_causal=causal, ) - else: - softmax_fuse_block_sum_kernel_non_causal[grid]( - attn_weights_slice, - output, - scale, - attn_weights_slice.stride(0), - attn_weights_slice.stride(1), - attn_weights_slice.stride(2), - output.stride(0), - output.stride(1), - output.stride(2), - real_q_len, - k_len, - chunk_start, - chunk_end, - segment_size, - reshaped_block_size, + + + # (batch_size, head_num, num_blocks_per_chunk, block_num) + simple_mask = find_blocks_chunked( + attn_sum, + chunk_idx * num_blocks_per_chunk, + threshold, + None, + decoding=False, + mode="prefill", + causal=causal, ) - return output - -def flat_group_gemm(query_states, key_states, chunk_start, chunk_end): - batch_size, num_heads, q_len, head_dim = query_states.shape - kv_len = key_states.shape[2] + attn_sum_list.append(attn_sum) + simple_mask_list.append(simple_mask) - output = torch.empty((batch_size, num_heads, q_len, kv_len), dtype=query_states.dtype, device=query_states.device) - BLOCK_M = 128 - BLOCK_N = 128 - BLOCK_K = 64 + del attn_weights_slice - grid = (q_len // BLOCK_M, kv_len // BLOCK_N, batch_size * num_heads) - flat_group_gemm_kernel[grid]( - query_states, - key_states, - output, - query_states.stride(0), - query_states.stride(1), - query_states.stride(2), - key_states.stride(0), - key_states.stride(1), - key_states.stride(2), - output.stride(0), - output.stride(1), - output.stride(2), - chunk_start, - chunk_end, - num_heads, - head_dim, - BLOCK_M, - BLOCK_N, - BLOCK_K, - ) - - return output - -def flat_group_gemm_fuse_reshape(query_states, key_states, stride, chunk_start, chunk_end, is_causal=True): - batch_size, num_heads, q_len, head_dim = query_states.shape - kv_len = key_states.shape[2] - - assert (key_states.shape[0] == batch_size) - assert (key_states.shape[1] == num_heads) - assert (key_states.shape[3] == head_dim) + attn_sums = torch.cat(attn_sum_list, dim=-2) - output = torch.empty((batch_size, num_heads, q_len // stride, kv_len // stride), dtype=query_states.dtype, device=query_states.device) - BLOCK_M = 128 - BLOCK_N = 128 - assert (q_len % (stride * BLOCK_M) == 0) - assert (kv_len % (stride * BLOCK_N) == 0) + # (batch_size, head_num, num_blocks_per_chunk * q_chunk_num, block_num) + # i.e. 
(batch_size, head_num, q_block_num, q_block_num) + simple_masks = torch.cat(simple_mask_list, dim=-2) - grid = (q_len // stride // BLOCK_M, kv_len // stride // BLOCK_N, batch_size * num_heads) - flat_group_gemm_fuse_reshape_kernel[grid]( - query_states, - key_states, - output, - query_states.stride(0), - query_states.stride(1), - query_states.stride(2), - key_states.stride(0), - key_states.stride(1), - key_states.stride(2), - output.stride(0), - output.stride(1), - output.stride(2), - chunk_start, - chunk_end, - num_heads, - stride, - head_dim, - BLOCK_M, - BLOCK_N, - is_causal, - ) + if causal: + simple_masks[:, :, -q_block_num:, -q_block_num:] = torch.where( + torch.tril( + torch.ones( + q_block_num, q_block_num, dtype=bool, device=key_states.device + ), + diagonal=0, + ), + simple_masks[:, :, -q_block_num:, -q_block_num:], + False, + ) + # print(f"{__name__} | simple_masks[:, :, -q_block_num:, -q_block_num:].shape {simple_masks[:, :, -q_block_num:, -q_block_num:].shape} after torch.where") + + + if keep_sink: + simple_masks[:, :, 0, :] = True + if keep_recent: + eye_matrix = torch.eye(q_block_num, device=simple_masks.device, dtype=bool) + eye_matrix_expanded = ( + eye_matrix.unsqueeze(0) + .unsqueeze(0) + .expand(1, num_kv_head, q_block_num, q_block_num) + ) + simple_masks[:, :, -q_block_num:, -q_block_num:] = torch.where( + eye_matrix_expanded, True, simple_masks[:, :, -q_block_num:, -q_block_num:] + ) - return output + # simple_masks -> (batch_size, head_num, q_block_num, q_block_num) + return attn_sums, simple_masks class XAttnFunc(torch.autograd.Function): @staticmethod diff --git a/mtraining/.gitignore b/mtraining/.gitignore index 6fc96e7..736265e 100644 --- a/mtraining/.gitignore +++ b/mtraining/.gitignore @@ -9,4 +9,5 @@ MTraining.egg-info/ output/ **/ring_attn_comp_data/ **/ring_attn_comp_data/ -**/ring_attn_pt_logs/ \ No newline at end of file +**/ring_attn_pt_logs/ +expr_data_store/ \ No newline at end of file diff --git a/mtraining/attn_funcs/moba_func.py b/mtraining/attn_funcs/moba_func.py index 41a5993..9b68bc5 100644 --- a/mtraining/attn_funcs/moba_func.py +++ b/mtraining/attn_funcs/moba_func.py @@ -10,9 +10,9 @@ from nnscaler.runtime.device import DeviceGroup from nnscaler.graph.parser.register import register_op +from minference.ops.op_utils.moba_utils import MoBAConfig from minference.ops.moba import moba_attn_varlen, moba_layer from minference.dist_ops.moba_zigzag import moba_zigzag_func -from minference.dist_ops.op_utils.moba_utils import MoBAConfig def load_moba_config(moba_config_dict: Dict[str, Any]): moba_config = MoBAConfig(**moba_config_dict) diff --git a/mtraining/models/active_param_configs/attn_only.txt b/mtraining/experiments/active_param_configs/attn_only.txt similarity index 100% rename from mtraining/models/active_param_configs/attn_only.txt rename to mtraining/experiments/active_param_configs/attn_only.txt diff --git a/mtraining/models/active_param_configs/qk_proj_only.txt b/mtraining/experiments/active_param_configs/qk_proj_only.txt similarity index 100% rename from mtraining/models/active_param_configs/qk_proj_only.txt rename to mtraining/experiments/active_param_configs/qk_proj_only.txt diff --git a/mtraining/experiments/scripts/train_qwen_mini_ProLong512K.sh b/mtraining/experiments/scripts/train_qwen_mini_ProLong512K.sh new file mode 100755 index 0000000..0075484 --- /dev/null +++ b/mtraining/experiments/scripts/train_qwen_mini_ProLong512K.sh @@ -0,0 +1,102 @@ +#!/usr/bin/bash +i=$(hostname | awk -F'-' '{print $2}') +NODE_RANK=$i +export NUM_NODES=1 
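+# NOTE: NODE_RANK above assumes hostnames of the form node-<rank> (e.g. node-0, which also
+# serves as MASTER_ADDR below); with NUM_NODES=1 and GPU_PER_NODE=4, WORLD_SIZE is set to
+# the total of 4 GPUs.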
+export REUSE_TYPE="match" + +export HF_TRUST_REMOTE_CODE=true +export HF_DATASETS_TRUST_REMOTE_CODE=true + +export MASTER_ADDR="node-0" +export MASTER_PORT="12345" + +export NNSCALER_HOME="/home/aiscuser/.conda/envs/mtrain/lib/python3.10/site-packages/nnscaler/autodist/" +export PYTHONPATH="${NNSCALER_HOME}:${PYTHONPATH}" + +# ----------------------------------------------- +# TODO: Basic Environment Settings +SEQUENCE_LENGTH=524288 +export GPU_NAME=A100 +export GPU_PER_NODE=4 +export WORLD_SIZE=4 +export GPU_SET="${GPU_NAME}_${WORLD_SIZE}" +export EXPR_HOME="/scratch/MInference/mtraining" +export NNSCALER_STORE="${EXPR_HOME}/experiments/expr_data_store/${GPU_SET}" +mkdir -p $NNSCALER_STORE +cd $EXPR_HOME + +# ------------------------------------------ +# /blob/nnscaler_store/MI300_8/minfer_qwen/qwen_fp090_512K_mini +export EXPR_DIR="minfer_qwen" # Name for the experiment set +export EXPR_NAME="qwen_fp090_512K_mini" # Name for the single experiment run +export MODEL_ID="Qwen/Qwen2.5-3B" +export DATASET_PATH="/scratch/nnscaler_store/prolong_fixed_filter_qwen_524288" +export MODEL_CONFIG_PATH="${EXPR_HOME}/model_configs/qwen2/lc_config_mini" +echo "Using model config path: $MODEL_CONFIG_PATH" +TRANSFER_CONFIG_DIR="none" +export TRAIN_ATTN_CONFIG_PATH="/scratch/MInference/mtraining/experiments/train_attn_configs/qwen_flex_090.yaml" +export ATTN_TYPE="minfer" + +# ------------------------------------------ +# Training Path settings +export TF_LOG_PATH="$NNSCALER_STORE/$EXPR_DIR/tf_logs" +export CKPT_PATH="$NNSCALER_STORE/$EXPR_DIR/$EXPR_NAME/checkpoints" +export COMPILE_PATH="$NNSCALER_STORE/compile_config/rank_${NODE_RANK}" +mkdir -p $TF_LOG_PATH +mkdir -p $CKPT_PATH +mkdir -p $COMPILE_PATH + +# ------------------------------------------- +# Training Settings +export GLOBAL_BATCH_SIZE=4 # TODO +export MICRO_BATCH_SIZE=1 +export MEM_CONSTRAINT=72 + +export NUM_ITER=10 +export NUM_EPOCH=0 + +export CKPT_SAVE_STEP=5 +export CKPT_SAVE_EPOCH=0 + +export CHECK_RESUME=1 +if [ "$CHECK_RESUME" -eq 1 ]; then + CHECK_RESUME="--check_resume" +else + CHECK_RESUME="" +fi + +# ------------------------------------------- +# Logging Path +export LOG_PATH="${NNSCALER_STORE}/${EXPR_DIR}/${EXPR_NAME}/rank_${NODE_RANK}" +mkdir -p $LOG_PATH +echo "Logging directed to $LOG_PATH/train.log" + +export TRACE_STRATEGY="reuse_cache" +torchrun --nproc_per_node=$GPU_PER_NODE \ + --nnodes=$NUM_NODES \ + --node_rank=$NODE_RANK \ + --master_addr=$MASTER_ADDR \ + --master_port=$MASTER_PORT \ + train.py --plan_ngpus $WORLD_SIZE \ + --runtime_ngpus $WORLD_SIZE \ + --name $EXPR_NAME \ + --seq_len $SEQUENCE_LENGTH \ + --attn_type $ATTN_TYPE \ + --train_attn_config_path $TRAIN_ATTN_CONFIG_PATH \ + --reuse_type $REUSE_TYPE \ + --model_id $MODEL_ID \ + --n_iter $NUM_ITER \ + --n_epochs $NUM_EPOCH \ + --global_batch_size $GLOBAL_BATCH_SIZE \ + --micro_batch_size $MICRO_BATCH_SIZE \ + --dataset_path $DATASET_PATH \ + --compile_save_path $COMPILE_PATH \ + --tf_log_dir $TF_LOG_PATH \ + --model_config_path $MODEL_CONFIG_PATH \ + --ckpt_save_dir $CKPT_PATH \ + --ckpt_n_step $CKPT_SAVE_STEP \ + --ckpt_n_epoch $CKPT_SAVE_EPOCH \ + --trace_strategy $TRACE_STRATEGY \ + --transfer_config_dir $TRANSFER_CONFIG_DIR \ + --mem_constraint $MEM_CONSTRAINT \ + $CHECK_RESUME > $LOG_PATH/train.log 2>&1 \ No newline at end of file diff --git a/mtraining/train_attn_configs/moba_256k_s95.yaml b/mtraining/experiments/train_attn_configs/moba_256k_s95.yaml similarity index 100% rename from mtraining/train_attn_configs/moba_256k_s95.yaml rename to 
mtraining/experiments/train_attn_configs/moba_256k_s95.yaml diff --git a/mtraining/train_attn_configs/moba_512k_s95.yaml b/mtraining/experiments/train_attn_configs/moba_512k_s95.yaml similarity index 100% rename from mtraining/train_attn_configs/moba_512k_s95.yaml rename to mtraining/experiments/train_attn_configs/moba_512k_s95.yaml diff --git a/mtraining/train_attn_configs/qwen_flex_090.yaml b/mtraining/experiments/train_attn_configs/qwen_flex_090.yaml similarity index 100% rename from mtraining/train_attn_configs/qwen_flex_090.yaml rename to mtraining/experiments/train_attn_configs/qwen_flex_090.yaml diff --git a/mtraining/train_attn_configs/qwen_flex_095.yaml b/mtraining/experiments/train_attn_configs/qwen_flex_095.yaml similarity index 100% rename from mtraining/train_attn_configs/qwen_flex_095.yaml rename to mtraining/experiments/train_attn_configs/qwen_flex_095.yaml diff --git a/mtraining/train_attn_configs/qwen_mf_dr_stripe.yaml b/mtraining/experiments/train_attn_configs/qwen_mf_dr_stripe.yaml similarity index 100% rename from mtraining/train_attn_configs/qwen_mf_dr_stripe.yaml rename to mtraining/experiments/train_attn_configs/qwen_mf_dr_stripe.yaml diff --git a/mtraining/train_attn_configs/qwen_mf_stripe.yaml b/mtraining/experiments/train_attn_configs/qwen_mf_stripe.yaml similarity index 100% rename from mtraining/train_attn_configs/qwen_mf_stripe.yaml rename to mtraining/experiments/train_attn_configs/qwen_mf_stripe.yaml diff --git a/mtraining/train_attn_configs/qwen_mf_zigzag.yaml b/mtraining/experiments/train_attn_configs/qwen_mf_zigzag.yaml similarity index 100% rename from mtraining/train_attn_configs/qwen_mf_zigzag.yaml rename to mtraining/experiments/train_attn_configs/qwen_mf_zigzag.yaml diff --git a/mtraining/train_attn_configs/xattn_default.yaml b/mtraining/experiments/train_attn_configs/xattn_default.yaml similarity index 100% rename from mtraining/train_attn_configs/xattn_default.yaml rename to mtraining/experiments/train_attn_configs/xattn_default.yaml diff --git a/mtraining/train_attn_configs/xattn_zigzag_s16.yaml b/mtraining/experiments/train_attn_configs/xattn_zigzag_s16.yaml similarity index 100% rename from mtraining/train_attn_configs/xattn_zigzag_s16.yaml rename to mtraining/experiments/train_attn_configs/xattn_zigzag_s16.yaml diff --git a/mtraining/train_attn_configs/xattn_zigzag_s16_t85.yaml b/mtraining/experiments/train_attn_configs/xattn_zigzag_s16_t85.yaml similarity index 100% rename from mtraining/train_attn_configs/xattn_zigzag_s16_t85.yaml rename to mtraining/experiments/train_attn_configs/xattn_zigzag_s16_t85.yaml diff --git a/mtraining/models/__init__.py b/mtraining/model_configs/__init__.py similarity index 100% rename from mtraining/models/__init__.py rename to mtraining/model_configs/__init__.py diff --git a/mtraining/models/phi3/__init__.py b/mtraining/model_configs/phi3/__init__.py similarity index 100% rename from mtraining/models/phi3/__init__.py rename to mtraining/model_configs/phi3/__init__.py diff --git a/mtraining/models/phi3/configuration_phi3.py b/mtraining/model_configs/phi3/configuration_phi3.py similarity index 100% rename from mtraining/models/phi3/configuration_phi3.py rename to mtraining/model_configs/phi3/configuration_phi3.py diff --git a/mtraining/models/phi3/lc_config/configuration_phi3.py b/mtraining/model_configs/phi3/lc_config/configuration_phi3.py similarity index 100% rename from mtraining/models/phi3/lc_config/configuration_phi3.py rename to mtraining/model_configs/phi3/lc_config/configuration_phi3.py diff --git 
a/mtraining/models/phi3/lc_config_mini/configuration_phi3.py b/mtraining/model_configs/phi3/lc_config_mini/configuration_phi3.py similarity index 100% rename from mtraining/models/phi3/lc_config_mini/configuration_phi3.py rename to mtraining/model_configs/phi3/lc_config_mini/configuration_phi3.py diff --git a/mtraining/models/phi3/modelling_phi.py b/mtraining/model_configs/phi3/modelling_phi.py similarity index 100% rename from mtraining/models/phi3/modelling_phi.py rename to mtraining/model_configs/phi3/modelling_phi.py diff --git a/mtraining/models/qwen2/__init__.py b/mtraining/model_configs/qwen2/__init__.py similarity index 100% rename from mtraining/models/qwen2/__init__.py rename to mtraining/model_configs/qwen2/__init__.py diff --git a/mtraining/models/qwen2/configuration_qwen2.py b/mtraining/model_configs/qwen2/configuration_qwen2.py similarity index 100% rename from mtraining/models/qwen2/configuration_qwen2.py rename to mtraining/model_configs/qwen2/configuration_qwen2.py diff --git a/mtraining/models/qwen2/lc_config/configuration_qwen2.py b/mtraining/model_configs/qwen2/lc_config/configuration_qwen2.py similarity index 100% rename from mtraining/models/qwen2/lc_config/configuration_qwen2.py rename to mtraining/model_configs/qwen2/lc_config/configuration_qwen2.py diff --git a/mtraining/models/qwen2/lc_config_mini/configuration_qwen2.py b/mtraining/model_configs/qwen2/lc_config_mini/configuration_qwen2.py similarity index 100% rename from mtraining/models/qwen2/lc_config_mini/configuration_qwen2.py rename to mtraining/model_configs/qwen2/lc_config_mini/configuration_qwen2.py diff --git a/mtraining/models/qwen2/modeling_qwen2.py b/mtraining/model_configs/qwen2/modeling_qwen2.py similarity index 100% rename from mtraining/models/qwen2/modeling_qwen2.py rename to mtraining/model_configs/qwen2/modeling_qwen2.py diff --git a/mtraining/models/phi3/modelling_phi_legacy.py b/mtraining/models/phi3/modelling_phi_legacy.py deleted file mode 100644 index 10d1c7d..0000000 --- a/mtraining/models/phi3/modelling_phi_legacy.py +++ /dev/null @@ -1,1568 +0,0 @@ -# coding=utf-8 -# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -""" PyTorch Phi-3 model.""" - -import inspect -import math -import warnings -from typing import List, Optional, Tuple, Union - -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss - -from transformers.activations import ACT2FN -from transformers.cache_utils import Cache, DynamicCache -from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, - SequenceClassifierOutputWithPast, - TokenClassifierOutput, -) -from transformers.modeling_utils import PreTrainedModel -from transformers.utils import ( - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, - logging, - replace_return_docstrings, -) -from .configuration_phi3 import Phi3Config - -from nnscaler.graph.parser.register import register_op -from nnscaler.ir import IRTensor - - - -logger = logging.get_logger(__name__) - -# Transformers scans dependencies in the modeling file, causing issues on conditional loading. The regex only ignores try/catch blocks, but not if statements -# if is_flash_attn_2_available(): -_flash_supports_window_size = False -try: - from flash_attn import flash_attn_func, flash_attn_varlen_func - from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa - - _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) -except ImportError as error: - logger.warning( - f"`flash-attention` package not found, consider installing for better performance: {error}." - ) - if not _flash_supports_window_size: - logger.warning( - "Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`." 
- ) - -_CHECKPOINT_FOR_DOC = "microsoft/Phi-3-mini-4k-instruct" -_CONFIG_FOR_DOC = "Phi3Config" - -PHI3_PRETRAINED_MODEL_ARCHIVE_LIST = [ - "microsoft/Phi-3-mini-4k-instruct", - "microsoft/Phi-3-mini-128k-instruct", - # See all Phi-3 models at https://huggingface.co/models?filter=Phi-3 -] - - - -# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Phi3 -class Phi3RMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - Phi3RMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - -# Copied from transformers.models.llama.modeling_llama._get_unpad_data -def _get_unpad_data(attention_mask): - seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) - return ( - indices, - cu_seqlens, - max_seqlen_in_batch, - ) - - -# Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with gemma->phi3, Gemma->Phi3 -class Phi3RotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - self.register_buffer("inv_freq", None, persistent=False) - - @torch.no_grad() - def forward(self, x, position_ids, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if self.inv_freq is None: - self.inv_freq = 1.0 / ( - self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim) - ) - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - -class Phi3LongRoPEScaledRotaryEmbedding(Phi3RotaryEmbedding): - def __init__(self, dim, config, device=None): - super().__init__(dim, config.max_position_embeddings, config.rope_theta, device) - - self.short_factor = config.rope_scaling["short_factor"] - self.long_factor = config.rope_scaling["long_factor"] - self.original_max_position_embeddings = config.original_max_position_embeddings - - @torch.no_grad() - def forward(self, x, position_ids, seq_len=None): - seq_len = torch.max(position_ids) + 1 - if seq_len > self.original_max_position_embeddings: - ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device) - else: - ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device) - - inv_freq_shape = 
torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim - self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape) - - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - - scale = self.max_position_embeddings / self.original_max_position_embeddings - if scale <= 1.0: - scaling_factor = 1.0 - else: - scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings)) - - cos = emb.cos() * scaling_factor - sin = emb.sin() * scaling_factor - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - -# Copied from transformers.models.llama.modeling_llama.rotate_half -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
- """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -class Phi3MLP(nn.Module): - def __init__(self, config): - super().__init__() - - self.config = config - self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False) - self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) - - self.activation_fn = ACT2FN[config.hidden_act] - - def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: - up_states = self.gate_up_proj(hidden_states) - - gate, up_states = up_states.chunk(2, dim=-1) - up_states = up_states * self.activation_fn(gate) - - return self.down_proj(up_states) - - -# Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -class Phi3Attention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config: Phi3Config, layer_idx: Optional[int] = None): - super().__init__() - self.config = config - self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - - self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.original_max_position_embeddings = config.original_max_position_embeddings - self.rope_theta = config.rope_theta - self.rope_scaling = config.rope_scaling - self.is_causal = True - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." 
- ) - - op_size = self.num_heads * self.head_dim + 2 * (self.num_key_value_heads * self.head_dim) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - self.qkv_proj = nn.Linear(self.hidden_size, op_size, bias=False) - self._init_rope() - - def _init_rope(self): - if self.rope_scaling is None: - self.rotary_emb = Phi3RotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) - else: - scaling_type = self.config.rope_scaling["type"] - if scaling_type == "longrope": - self.rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(self.head_dim, self.config) - else: - raise ValueError(f"Unknown RoPE scaling type {scaling_type}") - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - logger.warning_once("You are not running the flash-attention implementation, expect numerical differences.") - - bsz, q_len, _ = hidden_states.size() - - qkv = self.qkv_proj(hidden_states) - query_pos = self.num_heads * self.head_dim - query_states = qkv[..., :query_pos] - key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim] - value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :] - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." 
- ) - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len) - - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class Phi3FlashAttention2(Phi3Attention): - """ - Phi-3 flash attention module. This module inherits from `Phi3Attention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # Phi3FlashAttention2 attention does not support output_attentions - - if not _flash_supports_window_size: - logger.warning_once( - "The current flash attention version does not support sliding window attention. Please use `attn_implementation='eager'` or upgrade flash-attn library." - ) - raise ValueError("The current flash attention version does not support sliding window attention.") - - output_attentions = False - - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - - # overwrite attention_mask with padding_mask - attention_mask = kwargs.pop("padding_mask") - - bsz, q_len, _ = hidden_states.size() - - qkv = self.qkv_proj(hidden_states) - query_pos = self.num_heads * self.head_dim - query_states = qkv[..., :query_pos] - key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim] - value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :] - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." - ) - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - - # Because the input can be padded, the absolute sequence length depends on the max position id. 
- rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 - cos, sin = self.rotary_emb(value_states, position_ids, seq_len=rotary_seq_len) - - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - use_sliding_windows = ( - _flash_supports_window_size - and getattr(self.config, "sliding_window", None) is not None - and kv_seq_len > self.config.sliding_window - ) - - if past_key_value is not None: - # Activate slicing cache only if the config has a value `sliding_windows` attribute - cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 - if ( - getattr(self.config, "sliding_window", None) is not None - and kv_seq_len > self.config.sliding_window - and cache_has_contents - ): - slicing_tokens = 1 - self.config.sliding_window - - past_key = past_key_value[self.layer_idx][0] - past_value = past_key_value[self.layer_idx][1] - - past_key = past_key[:, :, slicing_tokens:, :].contiguous() - past_value = past_value[:, :, slicing_tokens:, :].contiguous() - - if past_key.shape[-2] != self.config.sliding_window - 1: - raise ValueError( - f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" - f" {past_key.shape}" - ) - - if attention_mask is not None: - attention_mask = attention_mask[:, slicing_tokens:] - attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) - - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_dropout = self.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. - - if query_states.dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.qkv_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." 
- ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - # Reashape to the expected shape for Flash Attention - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - attn_output = self._flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=attn_dropout, - use_sliding_windows=use_sliding_windows, - ) - - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._flash_attention_forward - def _flash_attention_forward( - self, - query_states, - key_states, - value_states, - attention_mask, - query_length, - dropout=0.0, - softmax_scale=None, - use_sliding_windows=False, - ): - """ - Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token - first unpad the input, then computes the attention scores and pad the final attention scores. - - Args: - query_states (`torch.Tensor`): - Input query states to be passed to Flash Attention API - key_states (`torch.Tensor`): - Input key states to be passed to Flash Attention API - value_states (`torch.Tensor`): - Input value states to be passed to Flash Attention API - attention_mask (`torch.Tensor`): - The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the - position of padding tokens and 1 for the position of non-padding tokens. - dropout (`float`): - Attention dropout - softmax_scale (`float`, *optional*): - The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) - use_sliding_windows (`bool`, *optional*): - Whether to activate sliding window attention. - """ - if not self._flash_attn_uses_top_left_mask: - causal = self.is_causal - else: - # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. 
- causal = self.is_causal and query_length != 1 - - # Contains at least one padding token in the sequence - if attention_mask is not None: - batch_size = query_states.shape[0] - query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( - query_states, key_states, value_states, attention_mask, query_length - ) - - cu_seqlens_q, cu_seqlens_k = cu_seq_lens - max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens - - if not use_sliding_windows: - attn_output_unpad = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, - max_seqlen_k=max_seqlen_in_batch_k, - dropout_p=dropout, - softmax_scale=softmax_scale, - causal=causal, - ) - else: - attn_output_unpad = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, - max_seqlen_k=max_seqlen_in_batch_k, - dropout_p=dropout, - softmax_scale=softmax_scale, - causal=causal, - window_size=(self.config.sliding_window, self.config.sliding_window), - ) - - attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) - else: - if not use_sliding_windows: - attn_output = flash_attn_func( - query_states, - key_states, - value_states, - dropout, - softmax_scale=softmax_scale, - causal=causal, - ) - else: - attn_output = flash_attn_func( - query_states, - key_states, - value_states, - dropout, - softmax_scale=softmax_scale, - causal=causal, - window_size=(self.config.sliding_window, self.config.sliding_window), - ) - - return attn_output - - # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._upad_input - def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): - batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape - - # On the first iteration we need to properly re-create the padding mask - # by slicing it on the proper place - if kv_seq_len != attention_mask.shape[-1]: - attention_mask_num_tokens = attention_mask.shape[-1] - attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :] - - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) - - key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) - value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) - - if query_length == kv_seq_len: - query_layer = index_first_axis( - query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k - ) - cu_seqlens_q = cu_seqlens_k - max_seqlen_in_batch_q = max_seqlen_in_batch_k - indices_q = indices_k - elif query_length == 1: - max_seqlen_in_batch_q = 1 - cu_seqlens_q = torch.arange( - batch_size + 1, dtype=torch.int32, device=query_layer.device - ) # There is a memcpy here, that is very bad. - indices_q = cu_seqlens_q[:-1] - query_layer = query_layer.squeeze(1) - else: - # The -q_len: slice assumes left padding. 
- attention_mask = attention_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) - - return ( - query_layer, - key_layer, - value_layer, - indices_q, - (cu_seqlens_q, cu_seqlens_k), - (max_seqlen_in_batch_q, max_seqlen_in_batch_k), - ) - - -# copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Phi3 -# TODO @Arthur no longer copied from LLama after static cache -class Phi3SdpaAttention(Phi3Attention): - """ - Phi3 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `Phi3Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from Phi3Attention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "Phi3Model is using Phi3SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - bsz, q_len, _ = hidden_states.size() - - qkv = self.qkv_proj(hidden_states) - query_pos = self.num_heads * self.head_dim - query_states = qkv[..., :query_pos] - key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim] - value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :] - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len) - - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with 
custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and attention_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. - is_causal=self.is_causal and attention_mask is None and q_len > 1, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -PHI3_ATTENTION_CLASSES = { - "eager": Phi3Attention, - "flash_attention_2": Phi3FlashAttention2, - "sdpa": Phi3SdpaAttention, -} - -class Phi3DecoderLayer(nn.Module): - def __init__(self, config: Phi3Config, layer_idx: int): - super().__init__() - - self.config = config - self.self_attn = PHI3_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx) - - self.mlp = Phi3MLP(config) - self.input_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - self.resid_attn_dropout = nn.Dropout(config.resid_pdrop) - self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop) - self.post_attention_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - **kwargs, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - """ - Args: - hidden_states (`torch.FloatTensor`): - input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - position_ids (`torch.LongTensor` of shape `({0})`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range - `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). 
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - """ - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - attn_outputs, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - hidden_states = residual + self.resid_attn_dropout(attn_outputs) - - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + self.resid_mlp_dropout(hidden_states) - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -PHI3_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`Phi3Config`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - - -@add_start_docstrings( - "The bare Phi-3 model outputting raw hidden-states without any specific head on top.", - PHI3_START_DOCSTRING, -) -class Phi3PreTrainedModel(PreTrainedModel): - config_class = Phi3Config - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["Phi3DecoderLayer"] - _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_sdpa = False - _supports_cache_class = True - - _version = "0.0.5" - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - -PHI3_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. 
- - If `past_key_values` is used, optionally only the last `input_ids` have to be input (see - `past_key_values`). - - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. - - [What are position IDs?](../glossary#position-ids) - past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): - Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` - returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. - - Two formats are allowed: - - a [`~cache_utils.Cache`] instance; - - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of - shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy - cache format. - - The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the - legacy cache format will be returned. - - If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't - have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` - of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -@add_start_docstrings( - "The bare Phi-3 model outputting raw hidden-states without any specific head on top.", - PHI3_START_DOCSTRING, -) -class Phi3Model(Phi3PreTrainedModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`Phi3DecoderLayer`] - - Args: - config: Phi3Config - """ - - def __init__(self, config: Phi3Config): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.embed_dropout = nn.Dropout(config.embd_pdrop) - self.layers = nn.ModuleList( - [Phi3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) - self._attn_implementation = config._attn_implementation - self.norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - self.gradient_checkpointing = False - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape[:2] - elif inputs_embeds is not None: - batch_size, seq_length = inputs_embeds.shape[:2] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - past_key_values_length = 0 - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - - if use_cache: - use_legacy_cache = not isinstance(past_key_values, Cache) - if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_usable_length(seq_length) - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: - is_padding_right = attention_mask[:, -1].sum().item() != batch_size - if is_padding_right: - raise ValueError( - "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Phi3. Make sure to " - " call `tokenizer.padding_side = 'left'` before tokenizing the input. " - ) - - if self._attn_implementation == "flash_attention_2": - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - sliding_window=self.config.sliding_window, - ) - - hidden_states = inputs_embeds - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = None - - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = None - if use_cache: - next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class Phi3ForCausalLM(Phi3PreTrainedModel): - _tied_weights_keys = ["lm_head.weight"] - - # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi3 - def __init__(self, config): - super().__init__(config) - self.model = Phi3Model(config) - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, 
bias=False) - - # Initialize weights and apply final processing - self.post_init() - - # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings - def get_input_embeddings(self): - return self.model.embed_tokens - - # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings - def get_output_embeddings(self): - return self.lm_head - - # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder - def set_decoder(self, decoder): - self.model = decoder - - # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder - def get_decoder(self): - return self.model - - # Ignore copy - @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, Phi3ForCausalLM - - >>> model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct") - >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct") - - >>> prompt = "This is an example script ." - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - 'This is an example script .\n Certainly! 
Below is a sample script that demonstrates a simple task, such as calculating the sum' - ```""" - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - logits = logits.float() - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - # Copied from transformers.models.persimmon.modeling_persimmon.PersimmonForCausalLM.prepare_inputs_for_generation - def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs - ): - if past_key_values is not None: - if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - max_cache_length = past_key_values.get_max_length() - else: - cache_length = past_length = past_key_values[0][0].shape[2] - max_cache_length = None - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 
- if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + input_ids.shape[1] > max_cache_length - ): - attention_mask = attention_mask[:, -max_cache_length:] - - position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - } - ) - return model_inputs - - @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM._reorder_cache - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - - -@add_start_docstrings( - """ - The [`Phi3Model`] with a sequence classification head on top (linear layer). - - [`Phi3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models - (e.g. GPT-2) do. - - Since it does classification on the last token, it requires to know the position of the last token. If a - `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If - no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the - padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in - each row of the batch). - """, - PHI3_START_DOCSTRING, -) -# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Phi3, LLAMA->PHI3, self.transformer->self.model, transformer_outputs->model_outputs -class Phi3ForSequenceClassification(Phi3PreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.model = Phi3Model(config) - self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. 
Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - model_outputs = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = model_outputs[0] - logits = self.score(hidden_states) - - if input_ids is not None: - batch_size = input_ids.shape[0] - else: - batch_size = inputs_embeds.shape[0] - - if self.config.pad_token_id is None and batch_size != 1: - raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") - if self.config.pad_token_id is None: - sequence_lengths = -1 - else: - if input_ids is not None: - # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility - sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 - sequence_lengths = sequence_lengths % input_ids.shape[-1] - sequence_lengths = sequence_lengths.to(logits.device) - else: - sequence_lengths = -1 - - pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] - - loss = None - if labels is not None: - labels = labels.to(logits.device) - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(pooled_logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(pooled_logits, labels) - if not return_dict: - output = (pooled_logits,) + model_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - past_key_values=model_outputs.past_key_values, - hidden_states=model_outputs.hidden_states, - attentions=model_outputs.attentions, - ) - - -@add_start_docstrings( - """ - [`Phi3Model`] with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for - Named-Entity-Recognition (NER) tasks. 
- """, - PHI3_START_DOCSTRING, -) -# Copied from transformers.models.mpt.modeling_mpt.MptForTokenClassification with Mpt->Phi3,MPT->PHI3,self.transformer->self.model,transformer_outputs->model_outputs -class Phi3ForTokenClassification(Phi3PreTrainedModel): - def __init__(self, config: Phi3Config): - super().__init__(config) - self.num_labels = config.num_labels - - self.model = Phi3Model(config) - if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None: - classifier_dropout = config.classifier_dropout - elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None: - classifier_dropout = config.hidden_dropout - else: - classifier_dropout = 0.1 - self.dropout = nn.Dropout(classifier_dropout) - self.classifier = nn.Linear(config.hidden_size, config.num_labels) - - # Initialize weights and apply final processing - self.post_init() - - @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TokenClassifierOutput, - config_class=_CONFIG_FOR_DOC, - ) - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, - attention_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - **deprecated_arguments, - ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - model_outputs = self.model( - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = model_outputs[0] - hidden_states = self.dropout(hidden_states) - logits = self.classifier(hidden_states) - - loss = None - if labels is not None: - # move labels to correct device to enable model parallelism - labels = labels.to(logits.device) - batch_size, seq_length = labels.shape - loss_fct = CrossEntropyLoss() - loss = loss_fct( - logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length) - ) - - if not return_dict: - output = (logits,) + model_outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=model_outputs.hidden_states, - attentions=model_outputs.attentions, - ) diff --git a/mtraining/models/qwen2/mi_config/configuration_qwen2.py b/mtraining/models/qwen2/mi_config/configuration_qwen2.py deleted file mode 100644 index 78e7d61..0000000 --- a/mtraining/models/qwen2/mi_config/configuration_qwen2.py +++ /dev/null @@ -1,185 +0,0 @@ -# coding=utf-8 -# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Qwen2 model configuration""" - -from transformers.configuration_utils import PretrainedConfig -from transformers.modeling_rope_utils import rope_config_validation -from transformers.utils import logging - - -logger = logging.get_logger(__name__) - - -class Qwen2Config(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a - Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration - with the defaults will yield a similar configuration to that of - Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta). - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - - - Args: - vocab_size (`int`, *optional*, defaults to 151936): - Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`Qwen2Model`] - hidden_size (`int`, *optional*, defaults to 4096): - Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 22016): - Dimension of the MLP representations. - num_hidden_layers (`int`, *optional*, defaults to 32): - Number of hidden layers in the Transformer encoder. - num_attention_heads (`int`, *optional*, defaults to 32): - Number of attention heads for each attention layer in the Transformer encoder. - num_key_value_heads (`int`, *optional*, defaults to 32): - This is the number of key_value heads that should be used to implement Grouped Query Attention. If - `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if - `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When - converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed - by meanpooling all the original heads within that group. For more details checkout [this - paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. - hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): - The non-linear activation function (function or string) in the decoder. - max_position_embeddings (`int`, *optional*, defaults to 32768): - The maximum sequence length that this model might ever be used with. - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-06): - The epsilon used by the rms normalization layers. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). Only - relevant if `config.is_decoder=True`. 
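# Illustrative sketch, not part of the original patch: the MHA -> GQA conversion described for
# `num_key_value_heads` above, where each group of original key/value heads is mean-pooled into
# a single shared head. Toy tensors only; real conversions operate on the k/v projection weights.
import torch

num_attention_heads, num_key_value_heads, seq_len, head_dim = 8, 2, 3, 4
group_size = num_attention_heads // num_key_value_heads           # heads pooled per group

k_mha = torch.randn(1, num_attention_heads, seq_len, head_dim)    # (batch, H, seq, d)
k_gqa = k_mha.view(1, num_key_value_heads, group_size, seq_len, head_dim).mean(dim=2)
print(k_gqa.shape)                                                # torch.Size([1, 2, 3, 4])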
- tie_word_embeddings (`bool`, *optional*, defaults to `False`): - Whether the model's input and output word embeddings should be tied. - rope_theta (`float`, *optional*, defaults to 10000.0): - The base period of the RoPE embeddings. - rope_scaling (`Dict`, *optional*): - Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type - and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value - accordingly. - Expected contents: - `rope_type` (`str`): - The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', - 'llama3'], with 'default' being the original RoPE implementation. - `factor` (`float`, *optional*): - Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In - most scaling types, a `factor` of x will enable the model to handle sequences of length x * - original maximum pre-trained length. - `original_max_position_embeddings` (`int`, *optional*): - Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during - pretraining. - `attention_factor` (`float`, *optional*): - Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention - computation. If unspecified, it defaults to value recommended by the implementation, using the - `factor` field to infer the suggested value. - `beta_fast` (`float`, *optional*): - Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear - ramp function. If unspecified, it defaults to 32. - `beta_slow` (`float`, *optional*): - Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear - ramp function. If unspecified, it defaults to 1. - `short_factor` (`List[float]`, *optional*): - Only used with 'longrope'. The scaling factor to be applied to short contexts (< - `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden - size divided by the number of attention heads divided by 2 - `long_factor` (`List[float]`, *optional*): - Only used with 'longrope'. The scaling factor to be applied to long contexts (< - `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden - size divided by the number of attention heads divided by 2 - `low_freq_factor` (`float`, *optional*): - Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE - `high_freq_factor` (`float`, *optional*): - Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE - use_sliding_window (`bool`, *optional*, defaults to `False`): - Whether to use sliding window attention. - sliding_window (`int`, *optional*, defaults to 4096): - Sliding window attention (SWA) window size. If not specified, will default to `4096`. - max_window_layers (`int`, *optional*, defaults to 28): - The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. - attention_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for the attention probabilities. 
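# Illustrative sketch, not part of the original patch: example `rope_scaling` dictionaries using
# the keys documented above. The numeric values are placeholders, not recommended settings.
yarn_rope_scaling = {
    "rope_type": "yarn",
    "factor": 4.0,                                # ~4x the pre-trained context length
    "original_max_position_embeddings": 32768,
    "beta_fast": 32,
    "beta_slow": 1,
}
llama3_rope_scaling = {
    "rope_type": "llama3",
    "factor": 8.0,
    "original_max_position_embeddings": 8192,
    "low_freq_factor": 1.0,
    "high_freq_factor": 4.0,
}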
- - ```python - >>> from transformers import Qwen2Model, Qwen2Config - - >>> # Initializing a Qwen2 style configuration - >>> configuration = Qwen2Config() - - >>> # Initializing a model from the Qwen2-7B style configuration - >>> model = Qwen2Model(configuration) - - >>> # Accessing the model configuration - >>> configuration = model.config - ```""" - - model_type = "qwen2" - keys_to_ignore_at_inference = ["past_key_values"] - - def __init__( - self, - vocab_size=151936, - hidden_size=4096, - intermediate_size=22016, - num_hidden_layers=32, - num_attention_heads=32, - num_key_value_heads=32, - hidden_act="silu", - max_position_embeddings=32768, - initializer_range=0.02, - rms_norm_eps=1e-6, - use_cache=True, - tie_word_embeddings=False, - rope_theta=10000.0, - rope_scaling=None, - use_sliding_window=False, - sliding_window=4096, - max_window_layers=28, - attention_dropout=0.0, - **kwargs, - ): - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.use_sliding_window = use_sliding_window - self.sliding_window = sliding_window if use_sliding_window else None - self.max_window_layers = max_window_layers - - # for backward compatibility - if num_key_value_heads is None: - num_key_value_heads = num_attention_heads - - self.num_key_value_heads = num_key_value_heads - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.use_cache = use_cache - self.rope_theta = rope_theta - self.rope_scaling = rope_scaling - self.attention_dropout = attention_dropout - # Validate the correctness of rotary position embeddings parameters - # BC: if there is a 'type' field, move it to 'rope_type'. - if self.rope_scaling is not None and "type" in self.rope_scaling: - self.rope_scaling["rope_type"] = self.rope_scaling["type"] - rope_config_validation(self) - - super().__init__( - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) diff --git a/mtraining/models/qwen2/mi_config/modeling_qwen2.py b/mtraining/models/qwen2/mi_config/modeling_qwen2.py deleted file mode 100644 index 253215f..0000000 --- a/mtraining/models/qwen2/mi_config/modeling_qwen2.py +++ /dev/null @@ -1,1490 +0,0 @@ -# coding=utf-8 -# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""PyTorch Qwen2 model.""" - -import math -from typing import List, Optional, Tuple, Union - -import torch -import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss - -from transformers.activations import ACT2FN -from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache -from transformers.generation import GenerationMixin -from transformers.modeling_attn_mask_utils import AttentionMaskConverter -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, - QuestionAnsweringModelOutput, - SequenceClassifierOutputWithPast, - TokenClassifierOutput, -) -from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS -from transformers.modeling_utils import PreTrainedModel -from transformers.utils import ( - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, - logging, - replace_return_docstrings, -) -from .configuration_qwen2 import Qwen2Config - -assert is_flash_attn_2_available() -from flash_attn import flash_attn_with_kvcache - -from mtraining_sparse_ops import get_minference_config, minference_flash_attn_func -MINFERENCE_CONFIG = get_minference_config("Qwen2.5_3B_kv_out_v32_fit_o_best_pattern.json") - - -logger = logging.get_logger(__name__) - - -_CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B" -_CONFIG_FOR_DOC = "Qwen2Config" - - -# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Qwen2 -class Qwen2RMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - Qwen2RMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" - - -# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Qwen2 -class Qwen2RotaryEmbedding(nn.Module): - def __init__( - self, - dim=None, - max_position_embeddings=2048, - base=10000, - device=None, - scaling_factor=1.0, - rope_type="default", - config: Optional[Qwen2Config] = None, - ): - super().__init__() - # TODO (joao): remove the `if` below, only used for BC - self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`Qwen2RotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. 
All other arguments will be removed in v4.46" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings - else: - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq - - def _dynamic_frequency_update(self, position_ids, device): - """ - dynamic RoPE layers should recompute `inv_freq` in the following situations: - 1 - growing beyond the cached sequence length (allow scaling) - 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) - """ - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) - self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation - self.max_seq_len_cached = seq_len - - if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset - self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) - self.max_seq_len_cached = self.original_max_seq_len - - @torch.no_grad() - def forward(self, x, position_ids): - if "dynamic" in self.rope_type: - self._dynamic_frequency_update(position_ids, device=x.device) - - # Core RoPE block - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 (see https://github.com/huggingface/transformers/pull/29285) - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - - # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention - cos = cos * self.attention_scaling - sin = sin * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - -# Copied from transformers.models.llama.modeling_llama.rotate_half -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. 
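# Illustrative sketch, not part of the original patch: the "default" rotary frequencies and the
# cos/sin tables the rotary-embedding forward above produces, for tiny dimensions (no scaling).
import torch

head_dim, base, seq_len = 8, 10000.0, 4
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))   # (head_dim // 2,)

position_ids = torch.arange(seq_len)[None, :].float()                          # (1, seq)
freqs = position_ids[..., None] * inv_freq                                     # (1, seq, head_dim // 2)
emb = torch.cat((freqs, freqs), dim=-1)                                        # (1, seq, head_dim)
cos, sin = emb.cos(), emb.sin()
print(cos.shape, sin.shape)                                                    # torch.Size([1, 4, 8]) each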
- sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2 -class Qwen2MLP(nn.Module): - def __init__(self, config): - super().__init__() - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, hidden_state): - return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) - - -# Copied from transformers.models.llama.modeling_llama.repeat_kv -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -class Qwen2Attention(nn.Module): - """ - Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer - and "Generating Long Sequences with Sparse Transformers". - """ - - def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None): - super().__init__() - self.config = config - self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " - "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." 
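# Illustrative sketch, not part of the original patch: `repeat_kv` above expands grouped K/V
# heads so they line up with the query heads; it matches torch.repeat_interleave on the head dim.
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)

kv = torch.randn(1, 2, 3, 4)                                  # (batch, kv_heads, seq, head_dim)
assert torch.equal(repeat_kv(kv, 3), torch.repeat_interleave(kv, repeats=3, dim=1))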
- ) - - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.is_causal = True - self.attention_dropout = config.attention_dropout - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." - ) - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - - self.rotary_emb = Qwen2RotaryEmbedding(config=self.config) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." 
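# Illustrative sketch, not part of the original patch: the eager attention math in the forward
# above -- scaled dot-product scores, an additive causal mask, fp32 softmax -- on toy tensors.
import math
import torch

bsz, num_heads, q_len, head_dim = 1, 2, 4, 8
q, k, v = (torch.randn(bsz, num_heads, q_len, head_dim) for _ in range(3))

attn_weights = q @ k.transpose(2, 3) / math.sqrt(head_dim)            # (b, h, q_len, kv_len)
causal_mask = torch.triu(torch.full((q_len, q_len), float("-inf")), diagonal=1)
attn_weights = attn_weights + causal_mask                             # block future positions
attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
attn_output = attn_weights @ v                                        # (b, h, q_len, head_dim)
print(attn_output.shape)                                              # torch.Size([1, 2, 4, 8])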
- ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class Qwen2FlashAttention2(Qwen2Attention): - """ - Qwen2 flash attention module, following Qwen2 attention module. This module inherits from `Qwen2Attention` - as the weights of the module stays untouched. The only required change would be on the forward pass - where it needs to correctly call the public API of flash attention and deal with padding tokens - in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom - config.max_window_layers layers. - """ - - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 - ): - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." - ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - dropout_rate = 0.0 if not self.training else self.attention_dropout - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in float16 just to be sure everything works as expected. - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." 
- ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - # Reashape to the expected shape for Flash Attention - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - if ( - self.config.use_sliding_window - and getattr(self.config, "sliding_window", None) is not None - and self.layer_idx >= self.config.max_window_layers - ): - sliding_window = self.config.sliding_window - else: - sliding_window = None - - if query_states.shape[1] == key_states.shape[1]: # Prefilling - v_size, s_size = MINFERENCE_CONFIG[self.layer_idx] - attn_output = minference_flash_attn_func(query_states, key_states, value_states, v_size, s_size) - else: - attn_output = flash_attn_with_kvcache(query_states, key_states, value_states) - - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class Qwen2SdpaAttention(Qwen2Attention): - """ - Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from Qwen2Attention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). 
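# Illustrative sketch, not part of the original patch: the prefill/decode dispatch used by
# `Qwen2FlashAttention2` above. It only mirrors the calls visible in this file; running it
# needs the `mtraining_sparse_ops` extension from this patch and flash-attn on a CUDA device.
import torch
from flash_attn import flash_attn_with_kvcache
from mtraining_sparse_ops import get_minference_config, minference_flash_attn_func

cfg = get_minference_config("Qwen2.5_3B_kv_out_v32_fit_o_best_pattern.json")

def sparse_or_dense_attention(q, k, v, layer_idx):
    # q/k/v are in the flash-attention layout used above: (batch, seq, heads, head_dim).
    if q.shape[1] == k.shape[1]:                 # prefill: query covers the full sequence
        v_size, s_size = cfg[layer_idx]          # per-layer sizes, as unpacked in the forward above
        return minference_flash_attn_func(q, k, v, v_size, s_size)
    return flash_attn_with_kvcache(q, k, v)      # decode: short query against the KV cache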
In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." - ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and attention_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. - is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -QWEN2_ATTENTION_CLASSES = { - "eager": Qwen2Attention, - "flash_attention_2": Qwen2FlashAttention2, - "sdpa": Qwen2SdpaAttention, -} - - -class Qwen2DecoderLayer(nn.Module): - def __init__(self, config: Qwen2Config, layer_idx: int): - super().__init__() - self.hidden_size = config.hidden_size - - if config.sliding_window and config._attn_implementation != "flash_attention_2": - logger.warning_once( - f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " - "unexpected results may be encountered." 
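# Illustrative sketch, not part of the original patch: the dispatch the SDPA path above relies
# on -- `is_causal=True` with no explicit mask when q_len > 1 and no padding, which matches an
# explicit additive causal mask.
import torch
import torch.nn.functional as F

q, k, v = (torch.randn(1, 2, 4, 8) for _ in range(3))         # (batch, heads, seq, head_dim)

out_implicit = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)

mask = torch.zeros(4, 4).masked_fill(torch.triu(torch.ones(4, 4, dtype=torch.bool), 1), float("-inf"))
out_explicit = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=False)
print(torch.allclose(out_implicit, out_explicit, atol=1e-6))  # True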
- ) - self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) - - self.mlp = Qwen2MLP(config) - self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 - **kwargs, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, sequence_length)` where padding elements are indicated by 0. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence. - position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): - Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, - with `head_dim` being the embedding dimension of each attention head. - kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -QWEN2_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. 
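# Illustrative sketch, not part of the original patch: the pre-norm residual layout of
# `Qwen2DecoderLayer.forward` above, with attention and MLP replaced by identity stand-ins.
import torch
from torch import nn

class PreNormBlock(nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(hidden_size)        # stand-in for Qwen2RMSNorm
        self.post_attention_layernorm = nn.LayerNorm(hidden_size)
        self.self_attn = nn.Identity()                          # stand-in for the attention module
        self.mlp = nn.Identity()                                # stand-in for Qwen2MLP

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = hidden_states + self.self_attn(self.input_layernorm(hidden_states))
        hidden_states = hidden_states + self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states

print(PreNormBlock(16)(torch.randn(1, 4, 16)).shape)            # torch.Size([1, 4, 16])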
- - Parameters: - config ([`Qwen2Config`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - - -@add_start_docstrings( - "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.", - QWEN2_START_DOCSTRING, -) -class Qwen2PreTrainedModel(PreTrainedModel): - config_class = Qwen2Config - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["Qwen2DecoderLayer"] - _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_cache_class = True - _supports_quantized_cache = True - _supports_static_cache = True - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - -QWEN2_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see - `past_key_values`). - - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. - - [What are position IDs?](../glossary#position-ids) - past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): - Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` - returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. 
- - Two formats are allowed: - - a [`~cache_utils.Cache`] instance, see our - [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); - - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of - shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy - cache format. - - The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the - legacy cache format will be returned. - - If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't - have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` - of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, - this tensor is not affected by padding. It is used to update the cache in the correct position and to infer - the complete sequence length. -""" - - -@add_start_docstrings( - "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.", - QWEN2_START_DOCSTRING, -) -class Qwen2Model(Qwen2PreTrainedModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. 
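# Illustrative sketch, not part of the original patch: the two `past_key_values` formats the
# docstring above allows -- a `Cache` instance vs. the legacy tuple-of-tuples -- and the
# conversion `Qwen2Model.forward` performs for backward compatibility (recent transformers
# versions assumed).
import torch
from transformers.cache_utils import DynamicCache

num_layers, bsz, kv_heads, seq_len, head_dim = 2, 1, 2, 5, 8
legacy_cache = tuple(
    (torch.randn(bsz, kv_heads, seq_len, head_dim), torch.randn(bsz, kv_heads, seq_len, head_dim))
    for _ in range(num_layers)
)

cache = DynamicCache.from_legacy_cache(legacy_cache)     # tuple-of-tuples -> Cache
print(cache.get_seq_length())                            # 5
roundtrip = cache.to_legacy_cache()                      # Cache -> tuple-of-tuples
print(len(roundtrip), roundtrip[0][0].shape)             # 2 torch.Size([1, 2, 5, 8])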
Each layer is a [`Qwen2DecoderLayer`] - - Args: - config: Qwen2Config - """ - - def __init__(self, config: Qwen2Config): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList( - [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) - self._attn_implementation = config._attn_implementation - self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.rotary_emb = Qwen2RotaryEmbedding(config=config) - - self.gradient_checkpointing = False - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) - if position_ids is None: - position_ids = cache_position.unsqueeze(0) - - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions - ) - - hidden_states = inputs_embeds - - # create position embeddings to be shared across the decoder layers - position_embeddings = self.rotary_emb(hidden_states, position_ids) - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = None - - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - # Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask - def _update_causal_mask( - self, - attention_mask: torch.Tensor, - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: - return attention_mask - return None - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. 
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) - using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if ( - self.config._attn_implementation == "sdpa" - and not (using_static_cache or using_sliding_window_cache) - and not output_attentions - ): - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - sliding_window=self.config.sliding_window, - is_training=self.training, - ): - return None - - dtype, device = input_tensor.dtype, input_tensor.device - min_dtype = torch.finfo(dtype).min - sequence_length = input_tensor.shape[1] - # SlidingWindowCache or StaticCache - if using_sliding_window_cache or using_static_cache: - target_length = past_key_values.get_max_cache_shape() - # DynamicCache or no cache - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - device=device, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - config=self.config, - past_key_values=past_key_values, - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type == "cuda" - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - # Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Qwen2 - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - device: torch.device, - cache_position: torch.Tensor, - batch_size: int, - config: Qwen2Config, - past_key_values: Cache, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to plcae the 4D attention mask on. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. 
- batch_size (`torch.Tensor`): - Batch size. - config (`Qwen2Config`): - The model's configuration class - past_key_values (`Cache`): - The cache class that is being used currently to generate - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device - ) - diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - if config.sliding_window is not None: - # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also - # the check is needed to verify is current checkpoint was trained with sliding window or not - if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=device) <= ( - cache_position.reshape(-1, 1) - config.sliding_window - ) - diagonal_attend_mask |= sliding_attend_mask - causal_mask *= diagonal_attend_mask - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - if attention_mask.shape[-1] > target_length: - attention_mask = attention_mask[:, :target_length] - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - return causal_mask - - -class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] - - def __init__(self, config): - super().__init__(config) - self.model = Qwen2Model(config) - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, - **loss_kwargs, - ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. 
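# Illustrative sketch, not part of the original patch: the core of the 4D causal-mask
# construction above -- fill with the dtype minimum, zero out the positions each query may
# attend to (key index <= cache_position), then expand to (batch, 1, q_len, kv_len).
import torch

dtype = torch.float32
sequence_length, target_length = 3, 5                   # 3 new queries, 5 cached + new keys
cache_position = torch.arange(2, 2 + sequence_length)   # these queries sit at positions 2..4
min_dtype = torch.finfo(dtype).min

causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
diagonal_attend_mask = torch.arange(target_length) > cache_position.reshape(-1, 1)
causal_mask *= diagonal_attend_mask                     # keep min_dtype only on future keys
causal_mask = causal_mask[None, None, :, :]             # (1, 1, q_len, kv_len)
print((causal_mask == 0).int().squeeze())
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]])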
Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, Qwen2ForCausalLM - - >>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." - ```""" - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - ) - - hidden_states = outputs[0] - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) - - loss = None - if labels is not None: - loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - The Qwen2 Model transformer with a sequence classification head on top (linear layer). - - [`Qwen2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models - (e.g. GPT-2) do. - - Since it does classification on the last token, it requires to know the position of the last token. If a - `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If - no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the - padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in - each row of the batch). 
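# Illustrative sketch, not part of the original patch: the effect of `num_logits_to_keep` in the
# forward above -- the hidden states are sliced before the LM head, so only the trailing
# positions are projected to vocabulary logits (0 is the "keep everything" special case).
import torch
from torch import nn

hidden_states = torch.randn(1, 6, 16)                  # (batch, seq, hidden)
lm_head = nn.Linear(16, 100, bias=False)               # hidden -> vocab

logits_all = lm_head(hidden_states[:, -0:, :])         # -0: slices the full sequence -> (1, 6, 100)
logits_last = lm_head(hidden_states[:, -1:, :])        # generation only needs (1, 1, 100)
print(logits_all.shape, logits_last.shape)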
- """, - QWEN2_START_DOCSTRING, -) -class Qwen2ForSequenceClassification(Qwen2PreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.model = Qwen2Model(config) - self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = transformer_outputs[0] - logits = self.score(hidden_states) - - if input_ids is not None: - batch_size = input_ids.shape[0] - else: - batch_size = inputs_embeds.shape[0] - - if self.config.pad_token_id is None and batch_size != 1: - raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") - if self.config.pad_token_id is None: - sequence_lengths = -1 - else: - if input_ids is not None: - # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility - sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 - sequence_lengths = sequence_lengths % input_ids.shape[-1] - sequence_lengths = sequence_lengths.to(logits.device) - else: - sequence_lengths = -1 - - pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] - - loss = None - if labels is not None: - labels = labels.to(logits.device) - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(pooled_logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) - elif 
self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(pooled_logits, labels) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - -@add_start_docstrings( - """ - The Qwen2 Model transformer with a token classification head on top (a linear layer on top of the hidden-states - output) e.g. for Named-Entity-Recognition (NER) tasks. - """, - QWEN2_START_DOCSTRING, -) -# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Qwen2, LLAMA->QWEN2 -class Qwen2ForTokenClassification(Qwen2PreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.model = Qwen2Model(config) - if getattr(config, "classifier_dropout", None) is not None: - classifier_dropout = config.classifier_dropout - elif getattr(config, "hidden_dropout", None) is not None: - classifier_dropout = config.hidden_dropout - else: - classifier_dropout = 0.1 - self.dropout = nn.Dropout(classifier_dropout) - self.score = nn.Linear(config.hidden_size, config.num_labels) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TokenClassifierOutput, - config_class=_CONFIG_FOR_DOC, - ) - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, TokenClassifierOutput]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
- """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - sequence_output = outputs[0] - sequence_output = self.dropout(sequence_output) - logits = self.score(sequence_output) - - loss = None - if labels is not None: - loss = self.loss_function(logits, labels, self.config) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ -The Qwen2 Model transformer with a span classification head on top for extractive question-answering tasks like -SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). - """, - QWEN2_START_DOCSTRING, -) -# Copied from transformers.models.mistral.modeling_mistral.MistralForQuestionAnswering with Mistral->Qwen2, MISTRAL->QWEN2 -class Qwen2ForQuestionAnswering(Qwen2PreTrainedModel): - base_model_prefix = "model" - - # Copied from models.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Qwen2 - def __init__(self, config): - super().__init__(config) - self.model = Qwen2Model(config) - self.qa_outputs = nn.Linear(config.hidden_size, 2) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - start_positions: Optional[torch.LongTensor] = None, - end_positions: Optional[torch.LongTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - **kwargs, - ) -> Union[Tuple, QuestionAnsweringModelOutput]: - r""" - start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the start of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the end of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. 
- """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output = outputs[0] - - logits = self.qa_outputs(sequence_output) - start_logits, end_logits = logits.split(1, dim=-1) - start_logits = start_logits.squeeze(-1).contiguous() - end_logits = end_logits.squeeze(-1).contiguous() - - loss = None - if start_positions is not None and end_positions is not None: - loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) - - if not return_dict: - output = (start_logits, end_logits) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return QuestionAnsweringModelOutput( - loss=loss, - start_logits=start_logits, - end_logits=end_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) diff --git a/mtraining/models/qwen2/vllm_sparse_qwen2.py b/mtraining/models/qwen2/vllm_sparse_qwen2.py deleted file mode 100644 index 4f28686..0000000 --- a/mtraining/models/qwen2/vllm_sparse_qwen2.py +++ /dev/null @@ -1,465 +0,0 @@ -# coding=utf-8 -# Adapted from -# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py -# Copyright 2024 The Qwen team. -# Copyright 2023 The vLLM team. -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""Inference-only Qwen2 model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Tuple - -import torch -from torch import nn -from transformers import Qwen2Config - -from vllm.attention import Attention, AttentionMetadata -from vllm.config import CacheConfig, LoRAConfig -from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size -from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) -from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput -from vllm import _custom_ops as ops - -from .interfaces import SupportsLoRA -from .utils import is_pp_missing_parameter, make_layers - -from flash_attn import flash_attn_func -from mtraining_sparse_ops import get_minference_config, minference_flash_attn_func -MINFERENCE_CONFIG = get_minference_config() - - -class Qwen2MLP(nn.Module): - - def __init__( - self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, - ) -> None: - super().__init__() - self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, - bias=False, - quant_config=quant_config) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config) - if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") - self.act_fn = SiluAndMul() - - def forward(self, x): - gate_up, _ = self.gate_up_proj(x) - x = self.act_fn(gate_up) - x, _ = self.down_proj(x) - return x - - -class Qwen2Attention(nn.Module): - - def __init__(self, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - max_position: int = 4096 * 32, - rope_theta: float = 10000, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - rope_scaling: Optional[Tuple] = None) -> None: - super().__init__() - self.hidden_size = hidden_size - tp_size = get_tensor_model_parallel_world_size() - self.total_num_heads = num_heads - assert self.total_num_heads % tp_size == 0 - self.num_heads = self.total_num_heads // tp_size - self.total_num_kv_heads = num_kv_heads - if self.total_num_kv_heads >= tp_size: - # Number of KV heads is greater than TP size, so we partition - # the KV heads across multiple tensor parallel GPUs. - assert self.total_num_kv_heads % tp_size == 0 - else: - # Number of KV heads is less than TP size, so we replicate - # the KV heads across multiple tensor parallel GPUs. 
- assert tp_size % self.total_num_kv_heads == 0 - self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) - self.head_dim = hidden_size // self.total_num_heads - self.q_size = self.num_heads * self.head_dim - self.kv_size = self.num_kv_heads * self.head_dim - self.scaling = self.head_dim**-0.5 - self.rope_theta = rope_theta - self.layer_idx = 0 - self.max_num_tokens = max_position - - self.qkv_proj = QKVParallelLinear( - hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, - bias=True, - quant_config=quant_config, - ) - self.o_proj = RowParallelLinear( - self.total_num_heads * self.head_dim, - hidden_size, - bias=False, - quant_config=quant_config, - ) - - self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.head_dim, - max_position=max_position, - base=self.rope_theta, - rope_scaling=rope_scaling, - ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, k = self.rotary_emb(positions, q, k) - if attn_metadata.prefill_metadata and q.shape[0] < self.max_num_tokens: - # BATCH_SIZE == 1 - if kv_cache is not None: - key_cache = kv_cache[0] - value_cache = kv_cache[1] - ops.reshape_and_cache_flash( - k.view(-1, self.num_kv_heads, self.head_dim), - v.view(-1, self.num_kv_heads, self.head_dim), - key_cache, - value_cache, - attn_metadata.slot_mapping.flatten(), - self.attn.kv_cache_dtype, - 1.0, - 1.0, - ) - v_size, s_size = MINFERENCE_CONFIG[self.layer_idx] - attn_output = minference_flash_attn_func( - q.reshape((1, -1, self.num_heads, self.head_dim)), - k.reshape((1, -1, self.num_kv_heads, self.head_dim)), - v.reshape((1, -1, self.num_kv_heads, self.head_dim)), - v_size, s_size, - ).reshape(q.shape) - else: - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) - output, _ = self.o_proj(attn_output) - return output - - -class Qwen2DecoderLayer(nn.Module): - - def __init__( - self, - config: Qwen2Config, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - ) -> None: - super().__init__() - self.hidden_size = config.hidden_size - # Requires transformers > 4.32.0 - rope_theta = getattr(config, "rope_theta", 1000000) - rope_scaling = getattr(config, "rope_scaling", None) - self.self_attn = Qwen2Attention( - hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - max_position=config.max_position_embeddings, - num_kv_heads=config.num_key_value_heads, - rope_theta=rope_theta, - cache_config=cache_config, - quant_config=quant_config, - rope_scaling=rope_scaling) - self.mlp = Qwen2MLP( - hidden_size=self.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, - residual: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, torch.Tensor]: - # Self Attention - if residual is None: - residual = hidden_states - hidden_states = 
self.input_layernorm(hidden_states) - else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, - ) - - # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) - hidden_states = self.mlp(hidden_states) - return hidden_states, residual - - -class Qwen2Model(nn.Module): - - def __init__( - self, - config: Qwen2Config, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: - super().__init__() - self.config = config - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, - ) - self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, - lambda prefix: Qwen2DecoderLayer(config=config, - cache_config=cache_config, - quant_config=quant_config), - prefix=f"{prefix}.layers", - ) - for i, layer in enumerate(self.layers): - layer.self_attn.layer_idx = i - - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.embed_tokens(input_ids) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if get_pp_group().is_first_rank: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.embed_tokens(input_ids) - residual = None - else: - assert intermediate_tensors is not None - hidden_states = intermediate_tensors["hidden_states"] - residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - hidden_states, residual = layer( - positions, - hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, - residual, - ) - if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) - hidden_states, _ = self.norm(hidden_states, residual) - return hidden_states - - -class Qwen2ForCausalLM(nn.Module, SupportsLoRA): - packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], - } - - # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", - "o_proj", - "gate_up_proj", - "down_proj", - ] - embedding_modules = {} - embedding_padding_modules = [] - - def __init__( - self, - config: Qwen2Config, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, - ) -> None: - # TODO (@robertgshaw2): see if this can be moved out - if (cache_config.sliding_window is not None - and hasattr(config, "max_window_layers")): - raise ValueError("Sliding window for some but all layers is not " - "supported. This model uses sliding window " - "but `max_window_layers` = %s is less than " - "`num_hidden_layers` = %s. Please open an issue " - "to discuss this feature." 
% ( - config.max_window_layers, - config.num_hidden_layers, - )) - - super().__init__() - - self.config = config - self.lora_config = lora_config - - self.quant_config = quant_config - self.model = Qwen2Model(config, cache_config, quant_config) - - if config.tie_word_embeddings: - self.lm_head = self.model.embed_tokens - else: - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config) - - self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) - return hidden_states - - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) - return logits - - def make_empty_intermediate_tensors( - self, batch_size: int, dtype: torch.dtype, - device: torch.device) -> IntermediateTensors: - return IntermediateTensors({ - "hidden_states": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - "residual": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - }) - - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(self.named_parameters(remove_duplicate=False)) - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - if self.config.tie_word_embeddings and "lm_head.weight" in name: - continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - # Remapping the name of FP8 kv-scale. 
- name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) diff --git a/mtraining/models/sparse_ops/.gitignore b/mtraining/models/sparse_ops/.gitignore deleted file mode 100644 index f15bda0..0000000 --- a/mtraining/models/sparse_ops/.gitignore +++ /dev/null @@ -1,7 +0,0 @@ -__pycache__/ -*.egg-info/ -build/ -*.egg -configs/ -minference_attn.py -minference_sparse_index.py \ No newline at end of file diff --git a/mtraining/models/sparse_ops/mtraining_sparse_ops/__init__.py b/mtraining/models/sparse_ops/mtraining_sparse_ops/__init__.py deleted file mode 100644 index 6803d87..0000000 --- a/mtraining/models/sparse_ops/mtraining_sparse_ops/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .minference_config import get_minference_config -from .minference_attn import minference_flash_attn_func diff --git a/mtraining/models/sparse_ops/mtraining_sparse_ops/minference_config.py b/mtraining/models/sparse_ops/mtraining_sparse_ops/minference_config.py deleted file mode 100644 index 83c1682..0000000 --- a/mtraining/models/sparse_ops/mtraining_sparse_ops/minference_config.py +++ /dev/null @@ -1,23 +0,0 @@ -import os -import json - - -CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'configs') -DEFAULT_CONFIG_FILE = "Qwen2.5_3B_kv_out_v32_fit_o_best_pattern.json" -if "MI_CONFIG" in os.environ: - DEFAULT_CONFIG_FILE = os.environ["MI_CONFIG"] - - -def get_minference_config(config_file: str = DEFAULT_CONFIG_FILE): - with open(os.path.join(CONFIG_DIR, config_file)) as f: - data = json.loads(f.read()) - config = [] - for layer_data in data: - v_size_list = [None] * len(layer_data) - s_size_list = [None] * len(layer_data) - for k, v in layer_data.items(): - assert v[0] in ['vertical_and_slash', 'flex_vertical_and_slash'] - v_size_list[int(k)] = v[1] - s_size_list[int(k)] = v[2] - config.append([v_size_list, s_size_list]) - return config diff --git a/mtraining/models/sparse_ops/setup.py b/mtraining/models/sparse_ops/setup.py deleted file mode 100644 index c6b202e..0000000 --- a/mtraining/models/sparse_ops/setup.py +++ /dev/null @@ -1,27 +0,0 @@ -import os -import shutil -from setuptools import setup, find_packages - - -setup_dir_path = os.path.dirname(__file__) -mtraining_path = os.path.dirname(os.path.dirname(setup_dir_path)) -setup_dir_path = os.path.join(setup_dir_path, "mtraining_sparse_ops") -cfg_dir_path = os.path.join(mtraining_path, "ops", "minfer", "configs") -op_dir_path = os.path.join(mtraining_path, "ops", "ring_attn", "core") - -shutil.copytree(cfg_dir_path, os.path.join(setup_dir_path, "configs"), dirs_exist_ok=True) - -with open(os.path.join(op_dir_path, "minference_sparse_index.py"), "r") as f: - index_code = f.read() -with open(os.path.join(setup_dir_path, "minference_sparse_index.py"), "w") as f: - f.write(index_code) -with open(os.path.join(op_dir_path, "minference_attn.py"), "r") as f: - attn_code = f.read() -with open(os.path.join(setup_dir_path, "minference_attn.py"), "w") as f: - f.write(attn_code.replace("MTraining.ops.ring_attn.core", "mtraining_sparse_ops")) - -setup( - name="mtraining_sparse_ops", # Name of your project - version="0.1.0", - packages=find_packages(), # Automatically discover all packages -) diff --git a/mtraining/setup.py b/mtraining/setup.py index 19e0346..b1c86ad 100644 --- a/mtraining/setup.py +++ b/mtraining/setup.py @@ -1,11 +1,10 
@@ from setuptools import setup, find_packages setup( - name="MTraining", # Name of your project + name="mtraining", # Name of your project version="0.1.0", packages=find_packages(), # Automatically discover all packages install_requires=[], # List dependencies if any (or use requirements.txt) - url="https://github.com/HalberdOfPineapple/MTraining", # Repository URL if applicable classifiers=[ "Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License", diff --git a/mtraining/setup.sh b/mtraining/setup.sh index 863ee5b..957ecbd 100755 --- a/mtraining/setup.sh +++ b/mtraining/setup.sh @@ -8,7 +8,7 @@ if command -v nvidia-smi then # assume base image: amlt-sing/acpt-torch2.3.1-py3.10-cuda12.1-ubuntu22.04 $PIP install ninja cmake wheel pybind11 - $PIP install --no-cache-dir torch==2.3.1 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 + $PIP install --no-cache-dir torch==2.3.1 --index-url https://download.pytorch.org/whl/cu121 $PIP install git+https://github.com/Dao-AILab/flash-attention.git@v2.7.4.post1 $PIP install -r "${BASE_DIR}/requirements.txt" $PIP install git+https://github.com/microsoft/nnscaler.git@2368540417bc3b77b7e714d3f1a0de8a51bb66e8 @@ -17,7 +17,7 @@ then elif command -v rocm-smi then $PIP install ninja cmake wheel pybind11 - $PIP install --no-cache-dir --pre torch==2.3.1+rocm6.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.0 + $PIP install --no-cache-dir --pre torch==2.3.1+rocm6.0 --index-url https://download.pytorch.org/whl/rocm6.0 $PIP install git+https://github.com/OpenAI/triton.git@e192dba#subdirectory=python $PIP install git+https://github.com/Dao-AILab/flash-attention.git@v2.7.4.post1 $PIP install -r "${BASE_DIR}/requirements.txt" @@ -31,4 +31,5 @@ fi NNSCALER_HOME=$(python -c "import nnscaler; print(nnscaler.__path__[0])") echo "export NNSCALER_HOME=${NNSCALER_HOME}" >> ~/.profile echo "export PYTHONPATH=${NNSCALER_HOME}:\${PYTHONPATH}" >> ~/.profile -source ~/.profile \ No newline at end of file +source ~/.profile +pip install -e . 
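# [Editor's sketch -- not part of the patch.] The train.py / general.py hunks
# that follow switch from an active-param config *name* (resolved inside a
# fixed directory) to a full *path*: the file is now opened directly and read
# with f.read().splitlines(), each line being a parameter-name substring to
# keep trainable. The helper and toy module below are simplified stand-ins for
# illustration only, not the project's freeze_model_params.
import torch.nn as nn

def freeze_except(model: nn.Module, keep_active):
    # Freeze every parameter whose qualified name matches none of the
    # configured substrings; leave the matching ones trainable.
    for name, param in model.named_parameters():
        param.requires_grad = any(key in name for key in keep_active)

toy = nn.ModuleDict({"self_attn": nn.Linear(8, 8), "mlp": nn.Linear(8, 8)})
keep_active = ["self_attn"]          # e.g. lines read from an active-param config file
freeze_except(toy, keep_active)
print([n for n, p in toy.named_parameters() if p.requires_grad])
# -> ['self_attn.weight', 'self_attn.bias']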
\ No newline at end of file diff --git a/mtraining/train.py b/mtraining/train.py index 8590def..5075f54 100644 --- a/mtraining/train.py +++ b/mtraining/train.py @@ -31,14 +31,14 @@ from minference.minference_configuration import MInferenceConfig from minference.configs.model2path import BASE_DIR as SPARSE_PATTERN_CONFIG_DIR -from .attn_funcs import AttnType, overwrite_attn_implementation -from .trainer import CustomTrainer as Trainer, CustomTrainerArgs as TrainerArgs -from .models import MODEL_TO_ATTN_FUNC, MODEL_ID_TO_MODEL_CLS, MODEL_ID_TO_PREFIX +from mtraining.attn_funcs import AttnType, overwrite_attn_implementation +from mtraining.trainer import CustomTrainer as Trainer, CustomTrainerArgs as TrainerArgs +from mtraining.model_configs import MODEL_TO_ATTN_FUNC, MODEL_ID_TO_MODEL_CLS, MODEL_ID_TO_PREFIX -from .utils.expr_data import update_expr_data -from .utils.general import freeze_model_params, load_comm_profile_data -from .utils import chunk_linear_cross_entropy, get_tokenizer, aggregate_outputs_fn, get_resume_path -from .utils.paths import TRAIN_ATTN_CONFIG_DIR, update_expr_data_save_path +from mtraining.utils.expr_data import update_expr_data +from mtraining.utils.paths import update_expr_data_save_path +from mtraining.utils.general import freeze_model_params, load_comm_profile_data +from mtraining.utils import chunk_linear_cross_entropy, get_tokenizer, aggregate_outputs_fn, get_resume_path IGNORE_IDX = -100 logger = logging.getLogger(__name__) @@ -71,7 +71,7 @@ def __init__( model_id, config_path: str=None, # merged_ckpt_path: str=None, - active_param_config_name: str=None + active_param_config_path: str=None ): super().__init__() model_cls: PreTrainedModel = MODEL_ID_TO_MODEL_CLS[model_id] @@ -89,8 +89,8 @@ def __init__( config=model_config, ) - if active_param_config_name: - freeze_model_params(self.model, active_param_config_name) + if active_param_config_path: + freeze_model_params(self.model, active_param_config_path) print(f'{__class__.__name__} Self-Attention Class: {self.model.model.layers[0].self_attn.__class__.__name__}') @@ -145,7 +145,7 @@ def __init__( # We still need to attach the function object to the model # otherwise the states of the function will be lost as nnscaler will only load the model from file # but not call this procedure again - from .attn_funcs.minfer_func import MInferAttnFunc + from mtraining.attn_funcs.minfer_func import MInferAttnFunc Attention = self.model.model.layers[0].self_attn.__class__ def update_module(m): if isinstance(m, Attention): @@ -198,7 +198,7 @@ def __init__( config_path=config_path, **kwargs, ) - from minference.dist_ops.op_utils.moba_utils import MoBAConfig + from minference.ops.op_utils.moba_utils import MoBAConfig # -------------------------------------------- print(f"MoBAConfig: {moba_config_dict}") @@ -227,50 +227,50 @@ def update_module(m): } -def load_train_attn_config(train_attn_config_name: str) -> MInferenceConfig: - train_attn_config_path = os.path.join(TRAIN_ATTN_CONFIG_DIR, f'{train_attn_config_name}.yaml') - if not os.path.exists(train_attn_config_path): - print(f"{__name__} | MInference config {train_attn_config_name} not found in {train_attn_config_path}. 
Use empty minfer config") +def load_train_attn_config(train_attn_config_path: str) -> MInferenceConfig: + if train_attn_config_path is None or train_attn_config_path.lower() == 'none': + train_attn_config_path = None + + if train_attn_config_path is None: + print(f"{__name__} | Use empty Training Attention config") train_attn_config = {} - else: - print(f"{__name__} | MInference config {train_attn_config_name} found in {train_attn_config_path}") + elif os.path.exists(train_attn_config_path): + print(f"{__name__} | Training Attention config found in {train_attn_config_path}.") with open(train_attn_config_path, 'r') as f: train_attn_config = yaml.safe_load(f) print('-' * 20) print("Training Attention Config:") print(train_attn_config) print('-' * 20) + else: + raise FileNotFoundError(f"Training Attention config {train_attn_config_path} not found. Exit.") return train_attn_config def build_model_args(args, train_attn_config: MInferenceConfig) -> Dict: model_args = { 'model_id': args.model_id, 'config_path': args.model_config_path, - "active_param_config_name": args.active_param_config_name, + "active_param_config_path": args.active_param_config_path, } - if args.attn_type == AttnType.MF_MB: + if args.attn_type == AttnType.MINFER: model_args['minfer_config'] = train_attn_config - elif args.attn_type == AttnType.FLEX_PREFILL: - model_args['attn_config'] = train_attn_config elif args.attn_type == AttnType.XATTN: model_args['xattn_params'] = train_attn_config - elif args.attn_type == AttnType.MOBA or args.attn_type == AttnType.ZIGZAG_MOBA: + elif args.attn_type == AttnType.MOBA: model_args['moba_config_dict'] = train_attn_config return model_args def main(args): - update_expr_data_save_path(args.attn_save_path, args.ckpt_save_dir, args.compile_save_path) + update_expr_data_save_path(args.ckpt_save_dir, args.compile_save_path) update_expr_data(args) local_rank = int(os.environ["LOCAL_RANK"]) - if local_rank == 0: - load_comm_profile_data(args) + if local_rank == 0: load_comm_profile_data(args) init_by_attn_type(args.model_id, args.attn_type) - train_attn_config = load_train_attn_config(args.train_attn_config_name) - broadcast_strategy = 'all' + train_attn_config = load_train_attn_config(args.train_attn_config_path) # --------------------------------- # Compute config @@ -345,7 +345,6 @@ def collate_fn(samples): }, ) sampler_config = DatasetSamplerConfig( - # default class: torch.utils.data.distributed.DistributedSampler train_args={ 'shuffle': True, 'seed': args.seed, @@ -384,7 +383,6 @@ def collate_fn(samples): every_n_epochs=args.ckpt_n_epoch, every_n_train_steps=args.ckpt_n_step, save_type='deduped', - # resume_from=(args.resume_from or 'last') if args.check_resume else None, resume_from=args.resume_from, ) @@ -423,8 +421,7 @@ def collate_fn(samples): dataloader=dataloader_config, checkpoint=checkpoint_config, log=[log_config], - - broadcast_strategy=broadcast_strategy, + broadcast_strategy='all', dataset_sampler=sampler_config, transfer_config={ @@ -434,11 +431,7 @@ def collate_fn(samples): merged_ckpt_path=args.resume_merged_ckpt, ) - trainer = Trainer( - train_args=trainer_args, - save_data_steps=args.attn_save_step, - enable_prof=args.enable_prof, - ) + trainer = Trainer(train_args=trainer_args) trainer.run() def print_args(args: argparse.Namespace): @@ -465,28 +458,25 @@ def print_args(args: argparse.Namespace): grad_accu_step = args.global_batch_size // (args.micro_batch_size * scaling_factor) print(f"Scaling Factor (INFERRED):\t{scaling_factor}") print(f"Gradient Accumulation Steps 
(INFERRED):\t{grad_accu_step}") - print(f"Save Attention Data Every {args.attn_save_step} Steps") print('-' * 40) print(f"Model Config Path:\t{args.model_config_path}") print(f"Dataset path:\t{args.dataset_path}") - print(f'Training Attention Config Name:\t{args.train_attn_config_name}') + print(f'Training Attention Config Path:\t{args.train_attn_config_path}') print(f"Compile Save Path:\t{args.compile_save_path}") - print(f"Attention Save Path:\t{args.attn_save_path}") print(f"Tensorboard Log Path:\t{args.tf_log_dir}") print(f"Checkpoint Save Path:\t{args.ckpt_save_dir}") print(f"Resume from Checkpoint:\t{args.check_resume}") print(f"Path to the checkpoint to resume from:\t{args.resume_from}") print(f"Path to the merged checkpoint to resume from:\t{args.resume_merged_ckpt}") - print(f"Enable profiling: {args.enable_prof}") print(f"Trace Strategy:\t{args.trace_strategy}") if args.transfer_config_dir: print(f"Transfer Configs from another experiment:\t{args.transfer_config_dir}") print(f"Force Transfer Configs:\t{args.transfer_force}") - if args.active_param_config_name: - print(f"Active Param Config Name:\t{args.active_param_config_name}") + if args.active_param_config_path: + print(f"Active Param Config Path:\t{args.active_param_config_path}") if args.ckpt_n_step: print(f"Checkpoint Save Every {args.ckpt_n_step} Steps") @@ -518,25 +508,22 @@ def print_args(args: argparse.Namespace): parser.add_argument('--model_id', type=str, default='microsoft/Phi-3-mini-4k-instruct', help='transformers model id') parser.add_argument('--model_config_path', type=str, default=None, help='path to the model config') - parser.add_argument('-s', '--attn_save_step', type=int, default=1, help='Save attention data every n steps') - parser.add_argument('--train_attn_config_name', type=str, default=None, help='Name of Minference config file') + parser.add_argument('--train_attn_config_path', type=str, default=None, help='Name of Minference config file') parser.add_argument('--compile_save_path', type=str, default='./.nnscaler', help='path to save compiled code') - parser.add_argument('--attn_save_path', type=str, default=None, help='path to save attention data') + parser.add_argument('--tf_log_dir', type=str, default=None, help='path to save tensorboard logs') parser.add_argument('--dataset_path', type=str, default=None, help='path to the dataset') parser.add_argument('--check_resume', action='store_true', help='whether to resume from checkpoint') parser.add_argument('--resume_from', type=str, default=None, help='path to the checkpoint to resume from') parser.add_argument('--resume_merged_ckpt', type=str, default=None, help='path (dir) to the merged checkpoint to resume from') - parser.add_argument('--enable_prof', action='store_true', help='enable profiling') - parser.add_argument('--ckpt_save_dir', type=str, default=None, help='path to save checkpoints') parser.add_argument('--ckpt_n_epoch', type=int, default=1, help='save checkpoint every n epochs') parser.add_argument('--ckpt_n_step', type=int, default=0, help='save checkpoint every n steps') parser.add_argument('--transfer_config_dir', type=str, default="none", help='path to transfer configs from another experiment') parser.add_argument('--transfer_force', action='store_true', help='force transfer configs') - parser.add_argument('--active_param_config_name', type=str, default=None, help='path to the active param list') + parser.add_argument('--active_param_config_path', type=str, default=None, help='path to the active param list') parser.add_argument('-p', 
'--disable_progressbar', action='store_true', help='transformers model id',) @@ -553,9 +540,7 @@ def print_args(args: argparse.Namespace): if args.n_iter <= 0: args.n_iter = None if args.n_epochs <= 0: args.n_epochs = None - if args.train_attn_config_name is None or args.train_attn_config_name.lower() == 'none': args.train_attn_config_name = None if args.transfer_config_dir.lower() == 'none': args.transfer_config_dir = None - if args.active_param_config_name.lower() == 'none': args.active_param_config_name = None # set a new field of args 'args.orig_resume_from' to store the original resume_from value args.orig_resume_from = args.resume_from diff --git a/mtraining/trainer.py b/mtraining/trainer.py index ea45b64..ed03cc1 100644 --- a/mtraining/trainer.py +++ b/mtraining/trainer.py @@ -24,9 +24,9 @@ Trainer, _StepStat, TrainerArgs, TrainStatus, AggregatedTrainHook, TrainHook ) -from .utils.paths import EXPR_DATA_SAVE_PATH -from .utils.general import fix_model_state_dict -from .utils.custom_parallel import parallelize as custom_parallelize +from mtraining.utils.paths import EXPR_DATA_SAVE_PATH +from mtraining.utils.general import fix_model_state_dict +from mtraining.utils.custom_parallel import parallelize as custom_parallelize logger = logging.getLogger(__name__) diff --git a/mtraining/utils/general.py b/mtraining/utils/general.py index 28ec24f..955a38e 100644 --- a/mtraining/utils/general.py +++ b/mtraining/utils/general.py @@ -7,7 +7,8 @@ from nnscaler.cli.trainer_args import AggregatedOutputs from nnscaler.runtime.module import ParallelModule -from .paths import ACTIVE_PARAM_CONFIG_DIR, BASE_DIR + +from .paths import BASE_DIR import logging logger = logging.getLogger(__name__) @@ -103,11 +104,6 @@ def is_active(module_name: str, keep_active: List[str]): return True return False -def read_active_param_list(active_param_config_name: str): - print(f"Reading active param list from {active_param_config_name}...") - with open(os.path.join(ACTIVE_PARAM_CONFIG_DIR, f'{active_param_config_name}.txt'), "r") as f: - return f.read().splitlines() - def freeze_model_params_(model, keep_active: List[str], prefix=""): if dist.get_rank() == 0: print("-" * 80) @@ -129,10 +125,10 @@ def freeze_model_params_(model, keep_active: List[str], prefix=""): print("-" * 80) -def freeze_model_params(model, active_param_config_name: str, prefix=""): - print(f"active param config name: {active_param_config_name}") - keep_active = read_active_param_list(active_param_config_name) - print(f"keep active: {keep_active}") +def freeze_model_params(model, active_param_config_path: str, prefix=""): + with open(active_param_config_path, "r") as f: + keep_active = f.read().splitlines() + print(f"freeze_model_params | keep active: {keep_active}") freeze_model_params_(model, keep_active, prefix) diff --git a/mtraining/utils/paths.py b/mtraining/utils/paths.py index 8035f50..716e7d4 100644 --- a/mtraining/utils/paths.py +++ b/mtraining/utils/paths.py @@ -1,8 +1,6 @@ import os BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -TRAIN_ATTN_CONFIG_DIR = os.path.join(BASE_DIR, 'train_attn_configs') -ACTIVE_PARAM_CONFIG_DIR = os.path.join(BASE_DIR, "models", "active_param_configs") EXPR_DATA_SAVE_PATH = { From 4a2f1cb95912c2332495183e6d6006bc6239395b Mon Sep 17 00:00:00 2001 From: Wenxuan Li Date: Fri, 13 Jun 2025 09:25:05 +0000 Subject: [PATCH 04/12] Restructured operators and fix data processing logic --- .../dist_ops/minfer_dr_stripe_triton.py | 4 +- minference/dist_ops/minfer_dr_striped.py | 6 +- 
minference/dist_ops/minfer_striped.py | 6 +- minference/dist_ops/minfer_striped_triton.py | 4 +- minference/dist_ops/minfer_zigzag.py | 6 +- minference/dist_ops/xattn_zigzag.py | 6 +- minference/ops/moba.py | 2 +- minference/ops/op_utils/moba_utils.py | 2 - .../ops/op_utils/vertical_slash_utils.py | 822 +++++++++++++++ minference/ops/op_utils/xattn_utils.py | 2 - ...tn.py => pit_sparse_flash_attention_v3.py} | 160 +-- ...> pit_sparse_flash_attention_v3_triton.py} | 158 +-- minference/ops/utils.py | 988 ------------------ minference/ops/xattention_fa.py | 2 +- mtraining/attn_funcs/minfer_func.py | 4 +- .../scripts/prolong_data_prepare.sh | 23 + .../scripts/train_qwen_mini_ProLong512K.sh | 2 +- mtraining/requirements.txt | 11 +- mtraining/utils/data_utils/prolong.py | 122 +++ 19 files changed, 993 insertions(+), 1337 deletions(-) create mode 100644 minference/ops/op_utils/vertical_slash_utils.py rename minference/ops/{minference_attn.py => pit_sparse_flash_attention_v3.py} (76%) rename minference/ops/{minference_attn_triton.py => pit_sparse_flash_attention_v3_triton.py} (86%) create mode 100644 mtraining/experiments/scripts/prolong_data_prepare.sh create mode 100644 mtraining/utils/data_utils/prolong.py diff --git a/minference/dist_ops/minfer_dr_stripe_triton.py b/minference/dist_ops/minfer_dr_stripe_triton.py index c9b6617..7c8ea3c 100644 --- a/minference/dist_ops/minfer_dr_stripe_triton.py +++ b/minference/dist_ops/minfer_dr_stripe_triton.py @@ -9,8 +9,8 @@ shuffle_striped_input, recover_striped_output, get_inner_ring, get_outer_ring ) -from minference.ops.utils import build_index, convert_blockmask -from minference.ops.minference_attn_triton import block_bar_attn_fwd, block_bar_attn_bwd +from minference.ops.op_utils.vertical_slash_utils import build_index, convert_blockmask +from minference.ops.pit_sparse_flash_attention_v3_triton import block_bar_attn_fwd, block_bar_attn_bwd def minfer_dr_stripe_triton_forward_inner( process_group: dist.ProcessGroup, diff --git a/minference/dist_ops/minfer_dr_striped.py b/minference/dist_ops/minfer_dr_striped.py index c404c88..ba8e59e 100644 --- a/minference/dist_ops/minfer_dr_striped.py +++ b/minference/dist_ops/minfer_dr_striped.py @@ -9,9 +9,9 @@ get_inner_ring, get_outer_ring ) -from minference.ops.minference_attn_triton import block_bar_attn_fwd -from minference.ops.minference_attn import block_attn_bwd, bar_attn_bwd -from minference.ops.utils import build_index, extract_kv, merge_kv, convert_blockmask +from minference.ops.pit_sparse_flash_attention_v3_triton import block_bar_attn_fwd +from minference.ops.pit_sparse_flash_attention_v3 import block_attn_bwd, bar_attn_bwd +from minference.ops.op_utils.vertical_slash_utils import build_index, extract_kv, merge_kv, convert_blockmask def minfer_dr_stripe_forward_inner( diff --git a/minference/dist_ops/minfer_striped.py b/minference/dist_ops/minfer_striped.py index ffd7cb1..76b7cb3 100644 --- a/minference/dist_ops/minfer_striped.py +++ b/minference/dist_ops/minfer_striped.py @@ -9,9 +9,9 @@ RingComm, shuffle_striped_input, recover_striped_output, ) -from minference.ops.minference_attn_triton import block_bar_attn_fwd -from minference.ops.minference_attn import block_attn_bwd, bar_attn_bwd -from minference.ops.utils import build_index, convert_blockmask, extract_kv, merge_kv +from minference.ops.pit_sparse_flash_attention_v3_triton import block_bar_attn_fwd +from minference.ops.pit_sparse_flash_attention_v3 import block_attn_bwd, bar_attn_bwd +from minference.ops.op_utils.vertical_slash_utils import 
build_index, convert_blockmask, extract_kv, merge_kv if torch.version.hip is None: original_flags = sys.getdlopenflags() diff --git a/minference/dist_ops/minfer_striped_triton.py b/minference/dist_ops/minfer_striped_triton.py index 77dda4a..2946c3f 100644 --- a/minference/dist_ops/minfer_striped_triton.py +++ b/minference/dist_ops/minfer_striped_triton.py @@ -7,8 +7,8 @@ RingComm, shuffle_striped_input, recover_striped_output, ) -from minference.ops.utils import build_index, convert_blockmask -from minference.ops.minference_attn_triton import block_bar_attn_fwd, block_bar_attn_bwd +from minference.ops.op_utils.vertical_slash_utils import build_index, convert_blockmask +from minference.ops.pit_sparse_flash_attention_v3_triton import block_bar_attn_fwd, block_bar_attn_bwd def minfer_stripe_triton_forward( diff --git a/minference/dist_ops/minfer_zigzag.py b/minference/dist_ops/minfer_zigzag.py index 88292aa..d519cd1 100644 --- a/minference/dist_ops/minfer_zigzag.py +++ b/minference/dist_ops/minfer_zigzag.py @@ -8,9 +8,9 @@ RingComm, shuffle_zigzag_input, recover_zigzag_output, ) -from minference.ops.utils import build_index, convert_blockmask -from minference.ops.minference_attn_triton import block_bar_attn_fwd -from minference.ops.minference_attn import block_attn_bwd, bar_attn_bwd +from minference.ops.op_utils.vertical_slash_utils import build_index, convert_blockmask +from minference.ops.pit_sparse_flash_attention_v3_triton import block_bar_attn_fwd +from minference.ops.pit_sparse_flash_attention_v3 import block_attn_bwd, bar_attn_bwd def minfer_zigzag_forward( process_group: dist.ProcessGroup, diff --git a/minference/dist_ops/xattn_zigzag.py b/minference/dist_ops/xattn_zigzag.py index ed03719..c201abc 100644 --- a/minference/dist_ops/xattn_zigzag.py +++ b/minference/dist_ops/xattn_zigzag.py @@ -11,11 +11,11 @@ shuffle_block_mask_zigzag, ) -from minference.ops.utils import convert_blockmask from minference.ops.op_utils.xattn_utils import LN2, find_blocks_chunked -from minference.ops.minference_attn import block_attn_fwd, block_attn_bwd -from minference.ops.minference_attn_triton import triton_block_attn_fwd, triton_block_attn_bwd +from minference.ops.op_utils.vertical_slash_utils import convert_blockmask +from minference.ops.pit_sparse_flash_attention_v3 import block_attn_fwd, block_attn_bwd from minference.ops.xattention_fa import flat_group_gemm_fuse_reshape, softmax_fuse_block_sum +from minference.ops.pit_sparse_flash_attention_v3_triton import triton_block_attn_fwd, triton_block_attn_bwd def xattn_zigzag_estimate( diff --git a/minference/ops/moba.py b/minference/ops/moba.py index 7743eeb..fef02e5 100644 --- a/minference/ops/moba.py +++ b/minference/ops/moba.py @@ -10,7 +10,7 @@ _flash_attn_varlen_backward, ) -from .utils import calc_chunks +from .op_utils.moba_utils import calc_chunks def hf_to_fa(x: torch.Tensor): """ diff --git a/minference/ops/op_utils/moba_utils.py b/minference/ops/op_utils/moba_utils.py index 8cf91d1..99d73f6 100644 --- a/minference/ops/op_utils/moba_utils.py +++ b/minference/ops/op_utils/moba_utils.py @@ -1,7 +1,5 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
- -# Credits: This logger implementation is inspired by project https://github.com/zhuzilin/ring-flash-attention import torch import torch.distributed as dist diff --git a/minference/ops/op_utils/vertical_slash_utils.py b/minference/ops/op_utils/vertical_slash_utils.py new file mode 100644 index 0000000..57a8d23 --- /dev/null +++ b/minference/ops/op_utils/vertical_slash_utils.py @@ -0,0 +1,822 @@ +import os +from typing import List + +import torch +import torch.distributed as dist + +import triton +import triton.language as tl + + +@triton.jit +def _triton_extract_kv_kernel( + local_k, local_v, bar_k, bar_v, v_idx, v_cnt, + stride_lz, stride_ln, stride_lh, stride_ld, + stride_bz, stride_bn, stride_bh, stride_bd, + stride_iz, stride_ih, stride_in, + stride_cz, stride_ch, stride_cr, + step, num_tokens, num_qo_heads, num_kv_heads, + BLOCK_N: tl.constexpr, BLOCK_D: tl.constexpr, +): + start_n = tl.program_id(0) + qo_head_idx = tl.program_id(1) + batch_idx = tl.program_id(2) + kv_head_idx = qo_head_idx // (num_qo_heads // num_kv_heads) + + v_cnt_ptr = v_cnt + batch_idx * stride_cz + qo_head_idx * stride_ch + min_n = tl.load(v_cnt_ptr + step * stride_cr) + max_n = tl.load(v_cnt_ptr + (step + 1) * stride_cr) + start_n = start_n * BLOCK_N + end_n = start_n + BLOCK_N + if start_n >= max_n or end_n <= min_n: + return + + offs_d = tl.arange(0, BLOCK_D) + offs_n = start_n + tl.arange(0, BLOCK_N) + mask_n = (offs_n >= min_n) & (offs_n < max_n) + + v_idx_ptr = v_idx + batch_idx * stride_iz + qo_head_idx * stride_ih + local_k_ptr = local_k + batch_idx * stride_lz + kv_head_idx * stride_lh + offs_d[None, :] * stride_ld + local_v_ptr = local_v + batch_idx * stride_lz + kv_head_idx * stride_lh + offs_d[None, :] * stride_ld + bar_k_ptr = bar_k + batch_idx * stride_bz + qo_head_idx * stride_bh + offs_d[None, :] * stride_bd + bar_v_ptr = bar_v + batch_idx * stride_bz + qo_head_idx * stride_bh + offs_d[None, :] * stride_bd + + # idx = tl.load(v_idx_ptr + offs_n * stride_in, mask=mask_n, other=0) - step * num_tokens + idx = tl.load(v_idx_ptr + offs_n * stride_in, mask=mask_n, other=0) % num_tokens + k = tl.load(local_k_ptr + idx[:, None] * stride_ln, mask=mask_n[:, None], other=0.) + v = tl.load(local_v_ptr + idx[:, None] * stride_ln, mask=mask_n[:, None], other=0.) 
+ tl.store(bar_k_ptr + offs_n[:, None] * stride_bn, k, mask=mask_n[:, None]) + tl.store(bar_v_ptr + offs_n[:, None] * stride_bn, v, mask=mask_n[:, None]) + + +def extract_kv( + local_k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + local_v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + bar_k: torch.Tensor, # [batch_size, max_v_size, num_qo_heads, head_dim] + bar_v: torch.Tensor, # [batch_size, max_v_size, num_qo_heads, head_dim] + v_idx: torch.Tensor, # [batch_size, num_qo_heads, max_v_size] + v_cnt: torch.Tensor, # [batch_size, num_qo_heads, world_size + 1] + step: int, +): + batch_size, max_v_size, num_qo_heads, head_dim = bar_k.shape + _, num_tokens, num_kv_heads, _ = local_k.shape + block_N = 128 + block_D = head_dim + _triton_extract_kv_kernel[(triton.cdiv(max_v_size, block_N), num_qo_heads, batch_size)]( + local_k, local_v, bar_k, bar_v, v_idx, v_cnt, + local_k.stride(0), local_k.stride(1), local_k.stride(2), local_k.stride(3), + bar_k.stride(0), bar_k.stride(1), bar_k.stride(2), bar_k.stride(3), + v_idx.stride(0), v_idx.stride(1), v_idx.stride(2), + v_cnt.stride(0), v_cnt.stride(1), v_cnt.stride(2), + step, num_tokens, num_qo_heads, num_kv_heads, + BLOCK_N=block_N, BLOCK_D=block_D, + num_warps=4, num_stages=1, + ) + + +@triton.jit +def _triton_merge_kv_kernel( + local_k, local_v, bar_k, bar_v, v_idx, v_cnt, + stride_lz, stride_ln, stride_lh, stride_ld, + stride_bz, stride_bn, stride_bh, stride_bd, + stride_iz, stride_ih, stride_in, + stride_cz, stride_ch, stride_cr, + step, num_tokens, num_qo_heads, num_kv_heads, + BLOCK_N: tl.constexpr, BLOCK_D: tl.constexpr, +): + start_n = tl.program_id(0) + qo_head_idx = tl.program_id(1) + batch_idx = tl.program_id(2) + kv_head_idx = qo_head_idx // (num_qo_heads // num_kv_heads) + + v_cnt_ptr = v_cnt + batch_idx * stride_cz + qo_head_idx * stride_ch + min_n = tl.load(v_cnt_ptr + step * stride_cr) + max_n = tl.load(v_cnt_ptr + (step + 1) * stride_cr) + start_n = start_n * BLOCK_N + end_n = start_n + BLOCK_N + if start_n >= max_n or end_n <= min_n: + return + + offs_d = tl.arange(0, BLOCK_D) + offs_n = start_n + tl.arange(0, BLOCK_N) + mask_n = (offs_n >= min_n) & (offs_n < max_n) + + v_idx_ptr = v_idx + batch_idx * stride_iz + qo_head_idx * stride_ih + local_k_ptr = local_k + batch_idx * stride_lz + kv_head_idx * stride_lh + offs_d[None, :] * stride_ld + local_v_ptr = local_v + batch_idx * stride_lz + kv_head_idx * stride_lh + offs_d[None, :] * stride_ld + bar_k_ptr = bar_k + batch_idx * stride_bz + qo_head_idx * stride_bh + offs_d[None, :] * stride_bd + bar_v_ptr = bar_v + batch_idx * stride_bz + qo_head_idx * stride_bh + offs_d[None, :] * stride_bd + + # idx = tl.load(v_idx_ptr + offs_n * stride_in, mask=mask_n, other=0) - step * num_tokens + idx = tl.load(v_idx_ptr + offs_n * stride_in, mask=mask_n, other=0) % num_tokens + k = tl.load(bar_k_ptr + offs_n[:, None] * stride_bn, mask=mask_n[:, None], other=0.).to(local_k.type.element_ty) + v = tl.load(bar_v_ptr + offs_n[:, None] * stride_bn, mask=mask_n[:, None], other=0.).to(local_v.type.element_ty) + tl.atomic_add(local_k_ptr + idx[:, None] * stride_ln, k, mask=mask_n[:, None], sem="relaxed") + tl.atomic_add(local_v_ptr + idx[:, None] * stride_ln, v, mask=mask_n[:, None], sem="relaxed") + + +def merge_kv( + local_k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + local_v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + bar_k: torch.Tensor, # [batch_size, max_v_size, num_qo_heads, head_dim] + bar_v: torch.Tensor, # 
[batch_size, max_v_size, num_qo_heads, head_dim] + v_idx: torch.Tensor, # [batch_size, num_qo_heads, max_v_size] + v_cnt: torch.Tensor, # [batch_size, num_qo_heads, world_size + 1] + step: int, +): + batch_size, max_v_size, num_qo_heads, head_dim = bar_k.shape + _, num_tokens, num_kv_heads, _ = local_k.shape + block_N = 128 + block_D = head_dim + _triton_merge_kv_kernel[(triton.cdiv(max_v_size, block_N), num_qo_heads, batch_size)]( + local_k, local_v, bar_k, bar_v, v_idx, v_cnt, + local_k.stride(0), local_k.stride(1), local_k.stride(2), local_k.stride(3), + bar_k.stride(0), bar_k.stride(1), bar_k.stride(2), bar_k.stride(3), + v_idx.stride(0), v_idx.stride(1), v_idx.stride(2), + v_cnt.stride(0), v_cnt.stride(1), v_cnt.stride(2), + step, num_tokens, num_qo_heads, num_kv_heads, + BLOCK_N=block_N, BLOCK_D=block_D, + num_warps=4, num_stages=1, + ) + + +# triton.cdiv(world_size * num_blocks, BLOCK_N), num_heads, batch_size +# block_mask: [batch_size, num_heads, num_blocks_global] +@triton.jit +def _calc_block_mask_kernel( + s_idx, block_mask, + stride_sz, stride_sh, stride_sk, + stride_bz, stride_bh, stride_bn, + max_s_size, num_tokens, granularity, + BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, +): + batch_idx = tl.program_id(2) + head_idx = tl.program_id(1) + group_idx = tl.program_id(0) + + block_offs = tl.arange(0, BLOCK_N) + slash_offs = tl.arange(0, BLOCK_K) + + s_idx_ptr = s_idx + batch_idx * stride_sz + head_idx * stride_sh + block_mask_ptr = block_mask + batch_idx * stride_bz + head_idx * stride_bh + block_idx = group_idx * BLOCK_N + block_offs + + blocks = tl.zeros([BLOCK_N], dtype=tl.uint8) + for s_off in range(0, max_s_size, BLOCK_K): + s = tl.load(s_idx_ptr + (s_off + slash_offs) * stride_sk) + left = (num_tokens - granularity - s) // granularity + right = (num_tokens - 1 - s) // granularity + + # mask is generated by checking if a block's index falls between the calculated ranges + blocks |= tl.max((block_idx[None, :] >= left[:, None]) & (block_idx[None, :] <= right[:, None]), 0).to(tl.uint8) + + b_mask = (group_idx * BLOCK_N + block_offs) * granularity < num_tokens + tl.store(block_mask_ptr + (group_idx * BLOCK_N + block_offs) * stride_bn, blocks, mask=b_mask) + + +@triton.jit +def _striped_convert_indices_kernel( + last_row_mask, v_idx, v_cnt, + block_mask, bar_idx, bar_pos, bar_cnt, + stride_rz, stride_rh, stride_rn, + stride_vz, stride_vh, stride_vk, + stride_nz, stride_nh, stride_nt, + stride_bt, stride_bz, stride_bh, stride_bm, stride_bn, + stride_iz, stride_ih, stride_im, stride_ik, + stride_cz, stride_ch, stride_cm, stride_ct, + max_v_size, num_blocks, granularity, world_size, rank, + BLOCK_N: tl.constexpr, +): + batch_idx = tl.program_id(2) + head_idx = tl.program_id(1) + block_idx_q_local = tl.program_id(0) + + block_idx_q_global = block_idx_q_local * world_size + rank + + num_tokens_local = num_blocks * granularity + num_blocks_global = world_size * num_blocks + shift = num_blocks_global - 1 - block_idx_q_global + + block_offs = tl.arange(0, BLOCK_N) + + last_row_mask_ptr = last_row_mask + batch_idx * stride_rz + head_idx * stride_rh + v_idx_ptr = v_idx + batch_idx * stride_vz + head_idx * stride_vh + v_cnt_ptr = v_cnt + batch_idx * stride_nz + head_idx * stride_nh + block_mask_ptr = block_mask + batch_idx * stride_bz + head_idx * stride_bh + block_idx_q_local * stride_bm + bar_idx_ptr = bar_idx + batch_idx * stride_iz + head_idx * stride_ih + block_idx_q_local * stride_im + bar_pos_ptr = bar_pos + batch_idx * stride_iz + head_idx * stride_ih + block_idx_q_local * 
stride_im + bar_cnt_ptr = bar_cnt + batch_idx * stride_cz + head_idx * stride_ch + block_idx_q_local * stride_cm + + cnt_valid = 0 + cnt_all = 0 + v_off = 0 + v = tl.load(v_idx_ptr + cnt_all * stride_vk) + cnt_all += 1 + + tl.store(bar_cnt_ptr, cnt_valid) + bar_cnt_ptr += stride_ct + if block_idx_q_local == tl.num_programs(0) - 1: + tl.store(v_cnt_ptr, cnt_all - 1) + v_cnt_ptr += stride_nt + + for step in range(world_size): + for block_off_k in range(0, num_blocks, BLOCK_N): + block_idx_k_local = block_off_k + block_offs + block_idx_k_global = (block_off_k + block_offs) * world_size + step + mask_local = tl.load( + last_row_mask_ptr + (block_idx_k_global + shift) * stride_rn, + mask=(block_idx_k_global + shift < num_blocks_global), + other=0, + ) + tl.store( + block_mask_ptr + block_idx_k_local * stride_bn, + mask_local, + mask=(block_idx_k_local < num_blocks), + ) + block_left = v_off + block_idx_k_local * granularity + block_right = block_left + granularity + max_blocks = block_idx_q_local + 1 if step <= rank else block_idx_q_local + v_max = v_off + min(block_off_k + BLOCK_N, max_blocks) * granularity + while v < v_max and cnt_all < max_v_size: + if tl.max(((v >= block_left) & (v < block_right)) & (~mask_local), 0): + tl.store(bar_idx_ptr + cnt_valid * stride_ik, v - v_off) + tl.store(bar_pos_ptr + cnt_valid * stride_ik, cnt_all - 1) + cnt_valid += 1 + v = tl.load(v_idx_ptr + cnt_all * stride_vk) + cnt_all += 1 + block_mask_ptr += stride_bt + tl.store(bar_cnt_ptr, cnt_valid) + bar_cnt_ptr += stride_ct + v_off += num_tokens_local + if block_idx_q_local == tl.num_programs(0) - 1: + tl.store(v_cnt_ptr, cnt_all - 1) + v_cnt_ptr += stride_nt + + +@triton.jit +def _zigzag_convert_indices_kernel( + last_row_mask, v_idx, v_cnt, + block_mask, bar_idx, bar_pos, bar_cnt, + stride_rz, stride_rh, stride_rn, + stride_vz, stride_vh, stride_vk, + stride_nz, stride_nh, stride_nt, + stride_bt, stride_bz, stride_bh, stride_bm, stride_bn, + stride_iz, stride_ih, stride_im, stride_ik, + stride_cz, stride_ch, stride_cm, stride_ct, + max_v_size, num_blocks, granularity, world_size, rank, + BLOCK_N: tl.constexpr, +): + batch_idx = tl.program_id(2) + head_idx = tl.program_id(1) + block_idx_q_local = tl.program_id(0) + + if rank < world_size // 2: + revert_rank = rank * 2 + else: + revert_rank = (world_size - 1 - rank) * 2 + 1 + if block_idx_q_local < num_blocks // 2: + block_idx_q_global = revert_rank * (num_blocks // 2) + block_idx_q_local + else: + block_idx_q_global = (world_size * 2 - 1 - revert_rank) * (num_blocks // 2) + block_idx_q_local - (num_blocks // 2) + + num_blocks_global = world_size * num_blocks + shift = num_blocks_global - 1 - block_idx_q_global + + block_offs = tl.arange(0, BLOCK_N) + + last_row_mask_ptr = last_row_mask + batch_idx * stride_rz + head_idx * stride_rh + v_idx_ptr = v_idx + batch_idx * stride_vz + head_idx * stride_vh + v_cnt_ptr = v_cnt + batch_idx * stride_nz + head_idx * stride_nh + block_mask_ptr = block_mask + batch_idx * stride_bz + head_idx * stride_bh + block_idx_q_local * stride_bm + bar_idx_ptr = bar_idx + batch_idx * stride_iz + head_idx * stride_ih + block_idx_q_local * stride_im + bar_pos_ptr = bar_pos + batch_idx * stride_iz + head_idx * stride_ih + block_idx_q_local * stride_im + bar_cnt_ptr = bar_cnt + batch_idx * stride_cz + head_idx * stride_ch + block_idx_q_local * stride_cm + + cnt_valid = 0 + cnt_all = 0 + v = tl.load(v_idx_ptr + cnt_all * stride_vk) + cnt_all += 1 + + tl.store(bar_cnt_ptr, cnt_valid) + bar_cnt_ptr += stride_ct + if block_idx_q_local == 
tl.num_programs(0) - 1: + tl.store(v_cnt_ptr, cnt_all - 1) + v_cnt_ptr += stride_nt + + for step in range(world_size): + v_off = step * num_blocks * granularity + v_end = v_off + num_blocks * granularity + for block_off_k in range(0, num_blocks, BLOCK_N): + block_idx_k_local = block_off_k + block_offs + # assert BLOCK_N <= num_blocks // 2 + if block_off_k < num_blocks // 2: + v_off_global = step * (num_blocks // 2) * granularity + block_idx_k_global = step * (num_blocks // 2) + block_idx_k_local + else: + v_off_global = (world_size * 2 - 2 - step) * (num_blocks // 2) * granularity + block_idx_k_global = (world_size * 2 - 1 - step) * (num_blocks // 2) + block_idx_k_local - (num_blocks // 2) + mask_local = tl.load( + last_row_mask_ptr + (block_idx_k_global + shift) * stride_rn, + mask=(block_idx_k_global + shift < num_blocks_global), + other=0, + ) + tl.store( + block_mask_ptr + block_idx_k_local * stride_bn, + mask_local, + mask=(block_idx_k_local < num_blocks), + ) + # block_left = block_idx_k_global * granularity - v_off_global + v_off + # block_right = block_left + granularity + block_left = v_off + block_idx_k_local * granularity + block_right = block_left + granularity + v_max = (block_idx_q_global + 1) * granularity - v_off_global + v_off + while v < v_end and cnt_all <= max_v_size: + if v < v_max: + if tl.max(((v >= block_left) & (v < block_right)) & (~mask_local), 0): + tl.store(bar_idx_ptr + cnt_valid * stride_ik, v - v_off) + tl.store(bar_pos_ptr + cnt_valid * stride_ik, cnt_all - 1) + cnt_valid += 1 + v = tl.load(v_idx_ptr + cnt_all * stride_vk) + cnt_all += 1 + block_mask_ptr += stride_bt + tl.store(bar_cnt_ptr, cnt_valid) + bar_cnt_ptr += stride_ct + if block_idx_q_local == tl.num_programs(0) - 1: + tl.store(v_cnt_ptr, cnt_all - 1) + v_cnt_ptr += stride_nt + + +def convert_indices( + v_idx: torch.Tensor, # [batch_size, num_heads, max_v_size] + s_idx: torch.Tensor, # [batch_size, num_heads, max_s_size] + world_size: int, + rank: int, + num_blocks: int, + granularity: int, + num_tokens: int = None, + stripe_transform: bool = False, + zigzag_transform: bool = False, +): + num_blocks_global = world_size * num_blocks + if num_tokens is None: + # Note that for each invokation of `convert_indices`, `num_tokens` is None and becomes the **global number of tokens** + num_tokens = num_blocks_global * granularity + batch_size, num_heads, max_v_size = v_idx.shape + batch_size, num_heads, max_s_size = s_idx.shape + last_row_mask = torch.zeros((batch_size, num_heads, num_blocks_global), dtype=torch.bool, device=s_idx.device) + + BLOCK_N, BLOCK_K = 128, 128 + assert max_s_size <= BLOCK_K * BLOCK_K, f"max_s_size={max_s_size} > BLOCK_K * BLOCK_K={BLOCK_K * BLOCK_K}" + _calc_block_mask_kernel[(triton.cdiv(num_blocks_global, BLOCK_N), num_heads, batch_size)]( + s_idx, last_row_mask, + s_idx.stride(0), s_idx.stride(1), s_idx.stride(2), + last_row_mask.stride(0), last_row_mask.stride(1), last_row_mask.stride(2), + max_s_size, num_tokens, granularity, + BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, + num_warps=4, num_stages=2, + ) + + block_mask = torch.zeros((world_size, batch_size, num_heads, num_blocks, num_blocks), dtype=torch.bool, device=v_idx.device) + bar_idx = torch.zeros((batch_size, num_heads, num_blocks, max_v_size), dtype=torch.int32, device=v_idx.device) + bar_cnt = torch.empty((batch_size, num_heads, num_blocks, world_size + 1), dtype=torch.int32, device=v_idx.device) + v_cnt = torch.empty((batch_size, num_heads, world_size + 1), dtype=torch.int32, device=v_idx.device) + bar_pos = 
torch.zeros_like(bar_idx) + if zigzag_transform: + convert_indices_kernel = _zigzag_convert_indices_kernel + assert num_blocks % 2 == 0 + BLOCK_N = max(num_blocks // 2, 128) + else: + convert_indices_kernel = _striped_convert_indices_kernel + BLOCK_N = 128 + convert_indices_kernel[(num_blocks, num_heads, batch_size)]( + last_row_mask, v_idx, v_cnt, block_mask, bar_idx, bar_pos, bar_cnt, + last_row_mask.stride(0), last_row_mask.stride(1), last_row_mask.stride(2), + v_idx.stride(0), v_idx.stride(1), v_idx.stride(2), + v_cnt.stride(0), v_cnt.stride(1), v_cnt.stride(2), + block_mask.stride(0), block_mask.stride(1), block_mask.stride(2), block_mask.stride(3), block_mask.stride(4), + bar_idx.stride(0), bar_idx.stride(1), bar_idx.stride(2), bar_idx.stride(3), + bar_cnt.stride(0), bar_cnt.stride(1), bar_cnt.stride(2), bar_cnt.stride(3), + max_v_size, num_blocks, granularity, world_size, rank, BLOCK_N=BLOCK_N, + num_warps=1, num_stages=1, + ) + + return block_mask, bar_idx, bar_cnt, bar_pos, v_cnt + + +def _torch_convert_indices( + v_idx: torch.Tensor, # [batch_size, num_heads, max_v_size] + s_idx: torch.Tensor, # [batch_size, num_heads, max_s_size] + world_size: int, + rank: int, + num_blocks: int, + granularity: int, +): + batch_size, num_heads, max_v_size = v_idx.shape + num_tokens = world_size * num_blocks * granularity + block_mask = torch.zeros((world_size, batch_size, num_heads, num_blocks, num_blocks), dtype=torch.bool, device=v_idx.device) + bar_idx = torch.zeros((batch_size, num_heads, num_blocks, max_v_size), dtype=torch.int32, device=v_idx.device) + bar_cnt = torch.zeros((batch_size, num_heads, num_blocks, world_size + 1), dtype=torch.int32, device=v_idx.device) + for batch_idx in range(batch_size): + for head_idx in range(num_heads): + for block_idx_q in range(num_blocks): + block_idx_q_global = block_idx_q * world_size + rank + cnt_all, cnt_valid = 0, 0 + for step in range(world_size): + for block_idx_k in range(block_idx_q + 1): + block_idx_k_global = block_idx_k * world_size + step + s_min = max((block_idx_q_global - block_idx_k_global - 1) * granularity, 0) + s_max = (block_idx_q_global - block_idx_k_global + 1) * granularity + flag = torch.any((s_idx[batch_idx, head_idx] > s_min) & (s_idx[batch_idx, head_idx] < s_max)) + block_mask[step, batch_idx, head_idx, block_idx_q, block_idx_k] = flag + v_min = (step * num_blocks + block_idx_k) * granularity + max_blocks = block_idx_q + 1 if step <= rank else block_idx_q + v_max = (step * num_blocks + min(block_idx_k + 1, max_blocks)) * granularity + while cnt_all < max_v_size and v_idx[batch_idx, head_idx, cnt_all] < v_min: + cnt_all += 1 + while cnt_all < max_v_size and v_idx[batch_idx, head_idx, cnt_all] < v_max: + if not flag: + bar_idx[batch_idx, head_idx, block_idx_q, cnt_valid] = \ + v_idx[batch_idx, head_idx, cnt_all] - step * num_blocks * granularity + cnt_valid += 1 + cnt_all += 1 + bar_cnt[batch_idx, head_idx, block_idx_q, step + 1] = cnt_valid + return block_mask, bar_idx, bar_cnt + + + +def sum_all_diagonal_matrix(mat: torch.Tensor): + b, h, m, n = mat.shape + + # Pads the matrix on left and right (on the last dimension) + mat_padded = torch.nn.functional.pad(mat, (m, m), "constant", 0.) 
# shape: [b, h, m, 2 * m + n] + # Change the strides + mat_strided = mat_padded.as_strided((b, h, m, m + n), (m * (2 * m + n) * h, m * (2 * m + n), 2 * m + n + 1, 1)) + # Sums the resulting matrix's columns + sum_diags = torch.sum(mat_strided, 2) # shape: [b, h, m + n] + return sum_diags[:, :, 1:].contiguous() + +def calc_index( + q: torch.Tensor, + k: torch.Tensor, + v_size: List[int], + s_size: List[int], + last_q_size: int = 64, + sink_tokens: int = 30, + sliding_window: int = 100, + group: dist.group = None, + stripe_transform: bool = False, + zigzag_transform: bool = False, + granularity: int = 128, +): + # TODO: adapt naturely striped inputs + # TODO: flex-prefill (top-P) + # TODO: reduce bubble + # TODO: support total_num_tokens % world_size != 0 + batch_size, num_tokens, num_qo_heads, head_dim = q.shape + num_kv_heads = k.shape[2] + rank = dist.get_rank(group) + world_size = dist.get_world_size(group) + + if all([type(x) is list for x in v_size]) and all([type(x) is list for x in s_size]): + flex_prefill = True + v_p = [x[0] for x in v_size] + v_size = [x[1] for x in v_size] + s_p = [x[0] for x in s_size] + s_size = [x[1] for x in s_size] + else: + flex_prefill = False + assert all([type(x) is int for x in v_size]) and all([type(x) is int for x in s_size]) + + max_v_size = min(triton.cdiv(max(v_size), 128), num_tokens // 128) * 128 + max_s_size = min(triton.cdiv(max(s_size), 128), num_tokens // 128) * 128 + + last_rank = world_size - 1 + if rank == last_rank: + last_q = q[:, -last_q_size:, :, :].detach().clone().reshape((batch_size, last_q_size, num_kv_heads, -1, head_dim)) + else: + last_q = torch.zeros((batch_size, last_q_size, num_kv_heads, num_qo_heads // num_kv_heads, head_dim), device=q.device, dtype=q.dtype) + + if os.getenv("COMM_DEBUG", False): + # For debugging purposes, print the rank and tensor shapes + rank = dist.get_rank(group) + print(f"Rank {rank} | calc_index | before invoking broadcast last_q from rank={last_rank}", flush=True) + dist.broadcast(last_q, src=last_rank, group=group, async_op=False) + + qk = torch.einsum('bmghd, bngd -> bghmn', last_q, k) * (k.shape[-1] ** -0.5) + qk = qk.reshape((batch_size, num_qo_heads, last_q_size, num_tokens)) + + if rank == last_rank: + # Causal Mask, requires num_tokens // world_size >= last_q + arange = torch.arange(last_q_size, device=k.device) + mask = arange[None, None, :, None] >= arange[None, None, None, :] + qk[:, :, :, -last_q_size:] = torch.where(mask, qk[:, :, :, -last_q_size:], -torch.inf) + if flex_prefill: # qk = torch.softmax(qk, dim=-1) / last_q_size + qk_max = torch.max(qk, dim=-1, keepdim=True).values + qk_max_list = [torch.empty_like(qk_max) for _ in range(world_size)] + dist.all_gather(qk_max_list, qk_max, group=group, async_op=False) + qk_max = torch.max(torch.stack(qk_max_list), dim=0).values + qk = torch.exp(qk - qk_max) + qk_sum = torch.sum(qk, dim=-1, keepdim=True) + qk_sum_list = [torch.empty_like(qk_sum) for _ in range(world_size)] + dist.all_gather(qk_sum_list, qk_sum, group=group, async_op=False) + qk_sum = torch.sum(torch.stack(qk_sum_list), dim=0) + qk /= (qk_sum * last_q_size) + + v_gather_rank = 0 + vertical = qk.sum(-2, keepdim=False) # [B, H, N_LOCAL] + if rank == 0 and not flex_prefill: + vertical[..., :sink_tokens] = torch.inf + if rank == v_gather_rank: + gathered_vertical = [torch.empty_like(vertical) for _ in range(world_size)] + else: + gathered_vertical = None + if os.getenv("COMM_DEBUG", False): + # For debugging purposes, print the rank and tensor shapes + rank = dist.get_rank(group) 
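# --- editorial sketch (not part of the patch) ---
# What `sum_all_diagonal_matrix` above computes, checked against an explicit
# torch.diagonal loop on toy shapes: entry t of the output is the sum of the diagonal
# with offset t - (m - 1), i.e. the m + n - 1 diagonals of each [m, n] score matrix
# ordered from the bottom-left corner to the top-right corner.
# `_reference_diagonal_sums` is a hypothetical helper used only for this check.
import torch

def _reference_diagonal_sums(mat: torch.Tensor) -> torch.Tensor:
    b, h, m, n = mat.shape
    diags = [
        torch.diagonal(mat, offset=off, dim1=-2, dim2=-1).sum(-1)
        for off in range(-(m - 1), n)
    ]
    return torch.stack(diags, dim=-1)  # [b, h, m + n - 1]

mat = torch.randn(1, 2, 4, 6)
torch.testing.assert_close(sum_all_diagonal_matrix(mat), _reference_diagonal_sums(mat))
# --- end editorial sketch ---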
+ print(f"Rank {rank} | calc_index | before invoking gather vertical to {v_gather_rank}", flush=True) + dist.gather(vertical, gathered_vertical, dst=v_gather_rank, group=group, async_op=False) + + if rank == v_gather_rank: + vertical: torch.Tensor = torch.cat(gathered_vertical, dim=-1) + if stripe_transform: + vertical = vertical.reshape((batch_size, num_qo_heads, -1, world_size, granularity)) + vertical = vertical.swapaxes(2, 3) + vertical = vertical.reshape((batch_size, num_qo_heads, -1)) + elif zigzag_transform: + vertical = vertical.reshape((batch_size, num_qo_heads, 2, world_size, -1)) + chunks = [] + for step in range(world_size): + chunks.append(vertical[:, :, 0, step]) + chunks.append(vertical[:, :, 1, world_size - 1 - step]) + vertical = torch.concat(chunks, dim=2).reshape((batch_size, num_qo_heads, -1)) + + v_topk = torch.topk(vertical, max_v_size, -1, sorted=True) + v_indices = v_topk.indices.to(torch.int32) + if flex_prefill: + v_cumsum = v_topk.values.cumsum_(dim=-1) + v_size = (v_cumsum < torch.tensor(v_p, device=k.device)[None, :, None]).sum(dim=-1, keepdim=True) + else: + v_size = torch.tensor(v_size, device=k.device)[None, :, None] + v_arange = torch.arange(max_v_size, device=k.device) + v_indices.masked_fill_(v_arange[None, None, :] >= v_size, num_tokens * world_size) + v_indices = v_indices.sort(dim=-1, descending=False).values + else: + v_indices = torch.empty((batch_size, num_qo_heads, max_v_size), dtype=torch.int32, device=k.device) + if os.getenv("COMM_DEBUG", False): + # For debugging purposes, print the rank and tensor shapes + rank = dist.get_rank(group) + print(f"Rank {rank} | calc_index | before invoking broadcast v_indices from rank={v_gather_rank}", flush=True) + dist.broadcast(v_indices, src=v_gather_rank, group=group, async_op=False) # async + + s_gather_rank = 0 + slash = sum_all_diagonal_matrix(qk) # shape: [B, H, N_LOCAL + LAST_Q_SIZE - 1] + if rank == world_size - 1 and not flex_prefill: + # -> index starting from the left bottom corner to right upper corner + # (sliding_window) from -(last_q_size-1) is the sliding window close to the main diagonal + slash[..., -(last_q_size - 1 + sliding_window):] = torch.inf + + + if rank == s_gather_rank: + gathered_slash = [torch.empty_like(slash) for _ in range(world_size)] + else: + gathered_slash = None + + if os.getenv("COMM_DEBUG", False): + # For debugging purposes, print the rank and tensor shapes + rank = dist.get_rank(group) + print(f"Rank {rank} | calc_index | before invoking gather slash to rank=0", flush=True) + dist.gather(slash, gathered_slash, dst=s_gather_rank, group=group, async_op=False) + + if rank == s_gather_rank: + slash = gathered_slash[0] + for next_slash in gathered_slash[1:]: + slash[..., -last_q_size + 1:] += next_slash[..., :last_q_size - 1] + slash = torch.cat((slash, next_slash[..., last_q_size - 1:]), dim=-1) + + # slash presents the sum of attention from 0-th to (num_tokens_global - last_q_size - 1), where 0 represents the diagonal at bottom left corner + slash = slash[..., :-last_q_size + 1] + s_topk = torch.topk(slash, max_s_size, -1, sorted=True) + + # s_indices contain indices starting from the right upper corner to left bottom corner + s_indices = (num_tokens * world_size - 1) - s_topk.indices.to(torch.int32) + if flex_prefill: + s_cumsum = s_topk.values.cumsum_(dim=-1) + s_size = (s_cumsum < torch.tensor(s_p, device=k.device)[None, :, None]).sum(dim=-1, keepdim=True) + else: + s_size = torch.tensor(s_size, device=k.device)[None, :, None] + s_arange = torch.arange(max_s_size, 
device=k.device) + s_indices.masked_fill_(s_arange[None, None, :] >= s_size, -1) + s_indices = s_indices.sort(dim=-1, descending=True).values + else: + s_indices = torch.empty((batch_size, num_qo_heads, max_s_size), dtype=torch.int32, device=k.device) + dist.broadcast(s_indices, src=s_gather_rank, group=group, async_op=False) + + return v_indices.to(torch.int32), s_indices.to(torch.int32) + +def calc_index_local( + q: torch.Tensor, + k: torch.Tensor, + v_size: List[int], + s_size: List[int], + last_q_size: int = 64, + sink_tokens: int = 30, + sliding_window: int = 100, + group: dist.group = None, + stripe_transform: bool = False, + zigzag_transform: bool = False, + granularity: int = 128, +): + batch_size, num_tokens, num_qo_heads, head_dim = q.shape + num_kv_heads = k.shape[2] + + if all([type(x) is list for x in v_size]) and all([type(x) is list for x in s_size]): + flex_prefill = True + v_p = [x[0] for x in v_size] + v_size = [x[1] for x in v_size] + s_p = [x[0] for x in s_size] + s_size = [x[1] for x in s_size] + else: + flex_prefill = False + assert all([type(x) is int for x in v_size]) and all([type(x) is int for x in s_size]) + + qk = torch.einsum( + f'bmghd, bngd -> bghmn', + q[:, -last_q_size:, :, :].reshape((batch_size, last_q_size, num_kv_heads, -1, head_dim)), + k, + ).reshape((batch_size, num_qo_heads, last_q_size, num_tokens)) * (head_dim ** -0.5) + + arange = torch.arange(last_q_size, device=k.device) + mask = arange[None, None, :, None] >= arange[None, None, None, :] + qk[:, :, :, -last_q_size:] = torch.where(mask, qk[:, :, :, -last_q_size:], -torch.inf) + if flex_prefill: + qk = torch.softmax(qk, dim=-1) / last_q_size + + max_v_size = min(max(v_size), num_tokens) + max_v_size = triton.cdiv(max_v_size, 128) * 128 + vertical = qk.sum(-2, keepdim=False) + if not flex_prefill: + vertical[..., :sink_tokens] = torch.inf + if stripe_transform: + vertical = vertical.reshape((batch_size, num_qo_heads, -1, dist.get_world_size(group), granularity)) + vertical = vertical.swapaxes(2, 3) + vertical = vertical.reshape((batch_size, num_qo_heads, -1)) + elif zigzag_transform: + vertical = vertical.reshape((batch_size, num_qo_heads, 2, dist.get_world_size(group), -1)) + chunks = [] + for step in range(dist.get_world_size(group)): + chunks.append(vertical[:, :, 0, step]) + chunks.append(vertical[:, :, 1, dist.get_world_size(group) - 1 - step]) + vertical = torch.concat(chunks, dim=2).reshape((batch_size, num_qo_heads, -1)) + v_topk = torch.topk(vertical, max_v_size, -1, sorted=True) + v_indices = v_topk.indices + if flex_prefill: + v_cumsum = v_topk.values.cumsum_(dim=-1) + v_size = (v_cumsum < torch.tensor(v_p, device=k.device)[None, :, None]).sum(dim=-1, keepdim=True) + else: + v_size = torch.tensor(v_size, device=k.device)[None, :, None] + + max_s_size = min(max(s_size), num_tokens) + max_s_size = triton.cdiv(max_s_size, 128) * 128 + slash = sum_all_diagonal_matrix(qk)[..., :-last_q_size + 1] + if not flex_prefill: + slash[..., -sliding_window:] = torch.inf + s_topk = torch.topk(slash, max_s_size, -1, sorted=True) + s_indices = (num_tokens - 1) - s_topk.indices + if flex_prefill: + s_cumsum = s_topk.values.cumsum_(dim=-1) + s_size = (s_cumsum < torch.tensor(s_p, device=k.device)[None, :, None]).sum(dim=-1, keepdim=True) + else: + s_size = torch.tensor(s_size, device=k.device)[None, :, None] + + v_arange = torch.arange(max_v_size, device=k.device) + v_idx = v_indices.to(torch.int32).reshape((batch_size, num_qo_heads, -1)) + v_idx.masked_fill_(v_arange[None, None, :] >= v_size, 2147483647) 
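# --- editorial aside (not part of the patch) ---
# The masked_fill_/sort pair around this point keeps a per-head budget of vertical
# indices inside a fixed-size top-k buffer: slots beyond the budget are overwritten
# with a large sentinel and the buffer is sorted ascending, so the valid indices come
# first and in increasing order.  A toy run, assuming one head with a budget of 3 in a
# top-k buffer of length 5:
import torch

topk_idx = torch.tensor([97, 4, 52, 18, 73], dtype=torch.int32)  # hypothetical top-k order
budget = 3
arange = torch.arange(topk_idx.numel())
topk_idx = topk_idx.masked_fill(arange >= budget, 2147483647)    # invalidate overflow slots
topk_idx = topk_idx.sort(descending=False).values
print(topk_idx)  # -> [4, 52, 97, 2147483647, 2147483647]
# --- end editorial aside ---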
+ v_idx = v_idx.sort(dim=-1, descending=False).values + + s_arange = torch.arange(max_s_size, device=k.device) + s_idx = s_indices.to(torch.int32).reshape((batch_size, num_qo_heads, -1)) + s_idx.masked_fill_(s_arange[None, None, :] >= s_size, -1) + s_idx = s_idx.sort(dim=-1, descending=True).values + + return v_idx, s_idx + +def build_index_local( + q: torch.Tensor, + k: torch.Tensor, + v_size: List[int], + s_size: List[int], + num_tokens: int, + granularity: int, + world_size: int = 1, + rank: int = 0, +): + if type(v_size) is list: + assert len(v_size) == q.shape[2] + assert len(s_size) == q.shape[2] + v_idx, s_idx = calc_index_local(q, k, v_size, s_size, last_q_size=64) + else: + v_idx, s_idx = v_size, s_size + + num_blocks = triton.cdiv(num_tokens, granularity) + block_mask, bar_idx, bar_cnt, _, _ = convert_indices(v_idx, s_idx, world_size, rank, num_blocks, granularity) + block_mask = block_mask[rank] + return block_mask, bar_idx, bar_cnt + +def build_index( + q: torch.Tensor, + k: torch.Tensor, + v_size: List[int], + s_size: List[int], + num_tokens: int, # num_tokens_local + granularity: int, + stripe_transform: bool = True, + zigzag_transform: bool = False, + group: dist.group = None, +): + """ + Input: (all inputs correspond to the local part for each rank) + q: shape [batch_size, num_tokens_local, num_qo_heads, head_dim] + k: shape [batch_size, num_tokens_local, num_kv_heads, head_dim] + v_size: shape [num_qo_heads] + s_size: shape [num_qo_heads] + num_tokens: number of tokens in the local part of QK + Returns: + block_mask: shape [world_size, batch_size, num_heads, num_blocks, num_blocks] + bar_idx: shape [batch_size, num_heads, num_blocks, max_v_size] + bar_cnt: shape [batch_size, num_heads, num_blocks, world_size + 1], each entry is the cumulative number of selected bars corresponding a rank + """ + rank = dist.get_rank(group) + world_size = dist.get_world_size(group) + + if isinstance(v_size, list): + v_idx, s_idx = calc_index( + q, k, v_size, s_size, last_q_size=64, group=group, + stripe_transform=stripe_transform, + zigzag_transform=zigzag_transform, + granularity=granularity + ) + else: + v_idx, s_idx = v_size, s_size + + num_blocks = triton.cdiv(num_tokens, granularity) # num_blocks_local + + # Note that block_mask is a 5D tensor with shape [world_size, batch_size, num_heads, num_blocks, num_blocks] + # with each block_mask[i] is to a mask corresponding the num_tokens_local x num_tokens_local matmul for each step + block_mask, bar_idx, bar_cnt, bar_pos, v_cnt = convert_indices( + v_idx, s_idx, world_size, rank, num_blocks, granularity, + stripe_transform=stripe_transform, + zigzag_transform=zigzag_transform, + ) + return block_mask, bar_idx, bar_cnt, bar_pos, v_idx, v_cnt + + +def _build_mask_local( + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v_size: List[int], + s_size: List[int], + num_tokens: int, + granularity: int, + world_size: int = 1, + rank: int = 0, +): + with torch.no_grad(): + block_mask, bar_idx, bar_cnt = build_index_local(q, k, v_size, s_size, num_tokens, granularity, world_size, rank) + batch_size, num_tokens, num_heads, head_dim = q.shape + num_blocks = block_mask.shape[-1] + num_tokens_pad = num_blocks * granularity + # Block Mask + mask = block_mask.unsqueeze(3).unsqueeze(5).repeat((1, 1, 1, granularity, 1, granularity)) + mask = mask.reshape((batch_size, num_heads, num_tokens_pad, num_tokens_pad)) + # Bar Mask + for batch_idx in range(batch_size): + for 
head_idx in range(num_heads): + for row_idx in range(num_blocks): + row_u = row_idx * granularity + row_d = row_u + granularity + bar_l = bar_cnt[batch_idx, head_idx, row_idx, rank] + bar_r = bar_cnt[batch_idx, head_idx, row_idx, rank + 1] + for col_idx in bar_idx[batch_idx, head_idx, row_idx, bar_l:bar_r]: + mask[batch_idx, head_idx, row_u:row_d, col_idx] = True + # Causal Mask + arange = torch.arange(0, num_tokens_pad, dtype=torch.int32, device=q.device) + mask.masked_fill_(arange[None, None, :, None] < arange[None, None, None, :], False) + return mask[:, :, :num_tokens, :num_tokens] + + +def convert_blockmask( + blockmask: torch.Tensor, # [world_size, batch_size, num_heads, num_blocks, num_blocks] + block_size_M: int, + block_size_N: int, +): + ratio = block_size_M // block_size_N + original_shape = blockmask.shape + blockmask = blockmask.to(dtype=torch.uint8) + blockmask = blockmask.unsqueeze(-1).tile([1] * len(original_shape) + [ratio]).reshape((*original_shape[:-1], -1)) + + # now block_mask is [world_size, batch_size, num_heads, num_blocks, num_blocks * ratio] + nonzero_val, nonzero_idx = blockmask.sort(dim=-1, stable=True, descending=True) + + nonzero_rowcnt = blockmask.sum(dim=-1, dtype=torch.int32) + return nonzero_idx.contiguous().to(dtype=torch.int32), nonzero_rowcnt.contiguous() + diff --git a/minference/ops/op_utils/xattn_utils.py b/minference/ops/op_utils/xattn_utils.py index f80d3e7..9484c8b 100644 --- a/minference/ops/op_utils/xattn_utils.py +++ b/minference/ops/op_utils/xattn_utils.py @@ -1,8 +1,6 @@ -import math import torch import triton import triton.language as tl -import torch.nn.functional as F import torch.distributed as dist LN2 = 1 / 1.4426950408889634 diff --git a/minference/ops/minference_attn.py b/minference/ops/pit_sparse_flash_attention_v3.py similarity index 76% rename from minference/ops/minference_attn.py rename to minference/ops/pit_sparse_flash_attention_v3.py index 8396921..b198b82 100644 --- a/minference/ops/minference_attn.py +++ b/minference/ops/pit_sparse_flash_attention_v3.py @@ -24,7 +24,7 @@ sys.setdlopenflags(original_flags) # NOTE: Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_blockmask.h: add head_idx to blockmask_ptr -from .utils import build_index_local +from .op_utils.vertical_slash_utils import build_index_local def block_attn_fwd( @@ -234,164 +234,6 @@ def _triton_bar_attn_fwd_kernel( tl.store(o_ptrs, acc.to(Out.type.element_ty), mask=m_mask[:, None]) tl.store(lse_ptrs, s, mask=(m_mask & overflow_mask)) -def stable_sigmoid(x: torch.Tensor): - return torch.where( - x >= 0, - 1 / (1 + torch.exp(-x)), - torch.exp(x) / (1 + torch.exp(x)) - ) - -def naive_bar_attn_fwd( - q, k, v, - sm_scale, - bar_cnt, bar_idx, - out, softmax_lse, - step, granularity, BLOCK_N=64 - ): - """ - Naive PyTorch implementation of the Triton bar attention forward kernel. - - Args: - q: Query tensor of shape [B, num_qo_heads, num_tokens, head_dim] - k: Key tensor of shape [B, num_kv_heads, num_tokens, head_dim] - v: Value tensor of shape [B, num_kv_heads, num_tokens, head_dim] - sm_scale: A scalar (float) softmax scale. - bar_cnt: Tensor of shape [B, num_qo_heads, num_blocks, world_size+1] - where each block (row) holds bar boundary indices. - bar_idx: Tensor of shape [B, num_qo_heads, num_blocks, nnz_v] - containing indices of keys (columns) for each block. - out: Output tensor of shape [B, num_qo_heads, num_tokens, head_dim]. - This is assumed to have a previous value to merge with. 
- softmax_lse: Tensor of shape [B, num_qo_heads, num_tokens] containing - the previous log-sum-exp values. - step: integer step indicating which pair of boundaries to use in bar_cnt. - granularity: BLOCK_M, i.e. the number of query tokens processed per block. - BLOCK_N: Block size for the key dimension (default: 64) - - This function updates `out` and `softmax_lse` in-place. - """ - # Get dimensions from q. - B, num_tokens, num_qo_heads, head_dim = q.shape - - # Determine number of query blocks (each corresponding to a row in bar_cnt/bar_idx). - num_blocks = math.ceil(num_tokens / granularity) - - # Compute the ratio for mapping query-head to key/value head. - head_ratio = num_qo_heads // k.shape[2] # since k.shape[1] is num_kv_heads - - # Precompute scale for q: note that 1.44269504 = log2(e) - qk_scale = sm_scale * 1.44269504 - - ln2 = 0.69314718 # constant for converting exp2 to exp - - # Loop over batch and query-head - for b in range(B): - for qh in range(num_qo_heads): - # corresponding key/value head index - kvh = qh // head_ratio - - # Loop over query blocks (rows) - for block in range(num_blocks): - start_m = block * granularity - end_m = min(start_m + granularity, num_tokens) - block_size = end_m - start_m - - # Get bar boundaries for this block & step: - # bar_cnt is assumed to store cumulative indices per block. - bar_l = bar_cnt[b, qh, block, step].item() # starting index (inclusive) - bar_r = bar_cnt[b, qh, block, step + 1].item() # ending index (exclusive) - if bar_l >= bar_r: - continue # nothing to do in this block - - # Initialize accumulators per query token in the block. - # m_i tracks the running maximum (in "log2" domain). - m_i = torch.full((block_size,), -float('inf'), device=q.device, dtype=torch.float32) - # l_i tracks the running sum-of-weights. - l_i = torch.zeros(block_size, device=q.device, dtype=torch.float32) - # acc accumulates the weighted sum of values. - acc = torch.zeros((block_size, head_dim), device=q.device, dtype=torch.float32) - - # Load and scale the q block. - # Shape: [block_size, head_dim] - q_block = q[b, start_m:end_m, qh, :] * qk_scale - - # Loop over key indices in steps of BLOCK_N - for n_start in range(bar_l, bar_r, BLOCK_N): - n_end = min(n_start + BLOCK_N, bar_r) - - # Load column indices from bar_idx. - # bar_idx shape: [nnz_v] for this block. - cols = bar_idx[b, qh, block, n_start:n_end] - cols = cols.long() - - k_selected = k[b, cols, kvh, :] # shape: [n_valid, head_dim] - v_selected = v[b, cols, kvh, :] # shape: [n_valid, head_dim] - - # Compute scaled dot product: [block_size, head_dim] x [head_dim, n_valid] - # Result: [block_size, n_valid] - qk = torch.matmul(q_block, k_selected.T) - - # Numerically stable softmax update in the log2 domain. - # m_i_new = max(m_i, max(qk, dim=1)) - cur_max, _ = qk.max(dim=1) - m_i_new = torch.max(m_i, cur_max) - - alpha = torch.exp2((m_i - m_i_new)) - p = torch.exp2((qk - m_i_new.unsqueeze(1))) - - # Update acc and l_i. - # Scale previous acc by alpha. - acc = acc * alpha.unsqueeze(1) + torch.matmul(p.to(q.dtype), v_selected) - - l_i = l_i * alpha + p.sum(dim=1) - - # Update m_i to the new maximum. - m_i = m_i_new - - # check zeros in l_i, if any, print out the indices - if (l_i == 0).any(): - zero_indices = torch.nonzero(l_i == 0).squeeze() - print(f"Rank {dist.get_rank()} | Zeros in l_i (step = {step}, batch_idx={b}, head_idx={qh}, block={block}): {zero_indices}") - - # Finalize the block output. - # Compute weighted output. 
- acc_1 = acc / l_i.unsqueeze(1) - s_1 = m_i * ln2 + torch.log(l_i) - # check positive infinity in s_1, if any, print out the indices - if torch.isinf(s_1).any() and ( torch.isinf(s_1) & (s_1 > 0) ).any(): - mask = torch.isinf(s_1) & (s_1 > 0) - print(f"Rank {dist.get_rank()} | Positive infinity in s_1 (step = {step}, batch_idx={b}, head_idx={qh}, block={block}): {torch.nonzero(mask).squeeze()}") - # check negative infinity in s_1, if any, print out the indices - if torch.isinf(s_1).any() and ( torch.isinf(s_1) & (s_1 < 0) ).any(): - mask = torch.isinf(s_1) & (s_1 < 0) - print(f"Rank {dist.get_rank()} | Negative infinity in s_1 (step = {step}, batch_idx={b}, head_idx={qh}, block={block}): {torch.nonzero(mask).squeeze()}") - - # Load previous stored values (accumulated output and LSE). - old_out = out[b, start_m:end_m, qh, :].to(acc_1.dtype) - old_lse = softmax_lse[b, qh, start_m:end_m] - # check positive infinity in old_lse, if any, print out the indices - if torch.isinf(old_lse).any() and ( torch.isinf(old_lse) & (old_lse > 0) ).any(): - mask = torch.isinf(old_lse) & (old_lse > 0) - print(f"Rank {dist.get_rank()} | Positive infinity in old_lse (step = {step}, batch_idx={b}, head_idx={qh}, block={block}): {torch.nonzero(mask).squeeze()}") - - # ------------------------------------------------- - # Logsigmoid solution - # out - old_out, block_out - acc1, lse - old_lse, block_lse - s_1 - new_out = old_out - F.sigmoid(s_1 - old_lse).unsqueeze(1) * (old_out - acc_1) - new_lse = s_1 - F.logsigmoid(s_1 - old_lse) - if torch.isinf(new_lse).any() and ( torch.isinf(new_lse) & (new_lse > 0) ).any(): - mask = torch.isinf(new_lse) & (new_lse > 0) - print(f"Rank {dist.get_rank()} | Positive infinity in new_lse (step = {step}, batch_idx={b}, head_idx={qh}, block={block}): {torch.nonzero(mask).squeeze()}") - - pos_inf_indices = torch.nonzero(mask).squeeze() - print(f"Rank {dist.get_rank()} | Values of (old_lse - s_1) resulting in pos-inf in theta (step = {step}, batch_idx={b}, head_idx={qh}, block={block}): {(old_lse - s_1)[pos_inf_indices]}") - print(f"Rank {dist.get_rank()} | Values of (old_lse) resulting in pos-inf in theta (step = {step}, batch_idx={b}, head_idx={qh}, block={block}): {(old_lse)[pos_inf_indices]}") - print(f"Rank {dist.get_rank()} | Values of (s_1) resulting in pos-inf in theta (step = {step}, batch_idx={b}, head_idx={qh}, block={block}): {(s_1)[pos_inf_indices]}") - - # Store back into out and softmax_lse. 
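# --- editorial aside (not part of the patch) ---
# The sigmoid/logsigmoid update in the (removed) naive reference above is an algebraic
# rewrite of the usual log-sum-exp merge of two partial attention results computed over
# disjoint key sets.  A quick numerical check of the identity on scalar toy values:
import torch
import torch.nn.functional as F

old_out, old_lse = torch.tensor(0.3), torch.tensor(1.2)    # running result and its LSE
blk_out, blk_lse = torch.tensor(-0.7), torch.tensor(0.4)   # new block's result and LSE

# reference merge: reweight each side by its softmax denominator
w_old, w_blk = torch.exp(old_lse), torch.exp(blk_lse)
ref_out = (w_old * old_out + w_blk * blk_out) / (w_old + w_blk)
ref_lse = torch.log(w_old + w_blk)

new_out = old_out - torch.sigmoid(blk_lse - old_lse) * (old_out - blk_out)
new_lse = blk_lse - F.logsigmoid(blk_lse - old_lse)
torch.testing.assert_close(new_out, ref_out)
torch.testing.assert_close(new_lse, ref_lse)
# --- end editorial aside ---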
- out[b, start_m:end_m, qh, :] = new_out.to(out.dtype) - softmax_lse[b, qh, start_m:end_m] = new_lse - return out, softmax_lse def bar_attn_fwd( diff --git a/minference/ops/minference_attn_triton.py b/minference/ops/pit_sparse_flash_attention_v3_triton.py similarity index 86% rename from minference/ops/minference_attn_triton.py rename to minference/ops/pit_sparse_flash_attention_v3_triton.py index 0e76ba3..068085b 100644 --- a/minference/ops/minference_attn_triton.py +++ b/minference/ops/pit_sparse_flash_attention_v3_triton.py @@ -1,19 +1,10 @@ -import os -import sys -import math -import ctypes - import torch -import torch.nn.functional as F -import torch.distributed as dist - import triton import triton.language as tl - from typing import List, Tuple -from .utils import ( + +from .op_utils.vertical_slash_utils import ( build_index_local, _build_mask_local, convert_blockmask, - calc_index_local, convert_indices ) @@ -839,28 +830,6 @@ def block_bar_attn_bwd( return dq, dk.to(dq.dtype), dv.to(dq.dtype) -def build_index_local( - q: torch.Tensor, - k: torch.Tensor, - v_size: List[int], - s_size: List[int], - num_tokens: int, - granularity: int, - world_size: int = 1, - rank: int = 0, -): - if type(v_size) is list: - assert len(v_size) == q.shape[2] - assert len(s_size) == q.shape[2] - v_idx, s_idx = calc_index_local(q, k, v_size, s_size, last_q_size=64) - else: - v_idx, s_idx = v_size, s_size - num_blocks = triton.cdiv(num_tokens, granularity) - block_mask, bar_idx, bar_cnt = convert_indices(v_idx, s_idx, world_size, rank, num_blocks, granularity) - block_mask = block_mask[rank] - return block_mask, bar_idx, bar_cnt - - class MInferenceAttnTritonFunc(torch.autograd.Function): @staticmethod def forward( @@ -1105,126 +1074,3 @@ def _torch_sparse_attn_kvpacked_func( return_attn_probs, ) - -def profile(func, inputs, num_warmups=10, num_iters=10): - torch.cuda.synchronize() - for _ in range(num_warmups): - func(*inputs) - torch.cuda.synchronize() - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - for _ in range(num_iters): - func(*inputs) - end.record() - torch.cuda.synchronize() - latency = start.elapsed_time(end) / num_iters - return latency - - -def print_compute_sparsity( - q: torch.Tensor, - k: torch.Tensor, - batch_size: int, - num_tokens: int, - num_qo_heads: int, - v_size: List[int], - s_size: List[int], - sparsity: float, - granularity: int = 128, -): - block_mask, bar_idx, bar_cnt = build_index_local(q, k, v_size, s_size, num_tokens, granularity) - num_blocks = block_mask.shape[-1] - causal_blocks = batch_size * num_qo_heads * num_blocks * (num_blocks + 1) / 2 - avg_blocks = block_mask.sum(dim=-1, dtype=torch.float32).mean().item() - block_ratio = block_mask.sum(dtype=torch.float32).item() / causal_blocks - avg_bars = bar_cnt[..., -1].mean(dtype=torch.float32).item() - bar_ratio = avg_bars / (causal_blocks / (num_qo_heads * num_blocks * num_blocks) * num_tokens) - compute_sparsity = 1 - block_ratio - bar_ratio - print(f"Max {max(v_size)} V Lines => {avg_bars:.2f} / {num_tokens} = {100 * bar_ratio:.1f}% Bar Ratio") - print(f"Max {max(s_size)} S Lines => {avg_blocks:.2f} / {num_blocks} = {100 * block_ratio:.1f}% Block Ratio") - print(f"Mask Sparsity = {100 * sparsity:.1f}%, Compute Sparsity = {100 * compute_sparsity:.1f}% ({(1 - compute_sparsity):.3f})") - - -def test_minference_attn( - batch_size: int, - num_tokens: int, - num_qo_heads: int, - num_kv_heads: int, - head_dim: int, - sparsity: float = 0.0, - granularity: int = 128, - 
check_results: bool = False, - profile_latency: bool = False, - dtype: torch.dtype = torch.bfloat16, - device: torch.device = 'cuda', - seed: int = 2025, -): - assert not (num_tokens > 8192 and check_results) - torch.manual_seed(seed) - q = torch.randn((batch_size, num_tokens, num_qo_heads, head_dim), requires_grad=True, dtype=dtype, device=device) - kv = torch.randn((batch_size, num_tokens, 2, num_kv_heads, head_dim), requires_grad=True, dtype=dtype, device=device) - grad = torch.randn((batch_size, num_tokens, num_qo_heads, head_dim), requires_grad=False, dtype=dtype, device=device) - v_size = [int((1 - sparsity) * 0.5 * num_tokens)] * num_qo_heads - s_size = [int((1 - sparsity) * 0.5 * num_tokens)] * num_qo_heads - print(f"[B, Hq, Hk, N, D] = [{batch_size}, {num_qo_heads}, {num_kv_heads}, {num_tokens}, {head_dim}]") - print_compute_sparsity(q, kv[:, :, 0], batch_size, num_tokens, num_qo_heads, v_size, s_size, sparsity, granularity) - - def call_attn(attn, inputs, grad=None, backward=False): - o, lse, _ = attn(**inputs) - if backward: - q.grad = None - kv.grad = None - o.backward(grad) - dq = q.grad.clone() - dkv = kv.grad.clone() - return o, lse, dq, dkv - return o, lse - - sparse_inputs = { - 'q': q, 'kv': kv, 'v_size': v_size, 's_size': s_size, - 'granularity': granularity, 'return_attn_probs': True, - } - dense_inputs = { - 'q': q, 'kv': kv, - 'return_attn_probs': True, 'causal': True, - } - - if check_results: - o, lse, dq, dkv = call_attn(minference_flash_attn_triton_kvpacked_func, sparse_inputs, grad=grad, backward=True) - o_ref, lse_ref, dq_ref, dkv_ref = call_attn(_torch_sparse_attn_kvpacked_func, sparse_inputs, grad=grad, backward=True) - # import ipdb; ipdb.set_trace() - htol, stol = { torch.float16: (1e-2, 1e-3), torch.bfloat16: (5e-2, 1e-2) }[dtype] - torch.testing.assert_close(o, o_ref, atol=htol, rtol=htol) - torch.testing.assert_close(lse, lse_ref, atol=stol, rtol=stol) - torch.testing.assert_close(dq, dq_ref, atol=htol, rtol=htol) - torch.testing.assert_close(dkv, dkv_ref, atol=htol, rtol=htol) - - if profile_latency: - from flash_attn import flash_attn_kvpacked_func - flash_latency = profile(call_attn, [flash_attn_kvpacked_func, dense_inputs, grad, True]) - flash_fwd_latency = profile(call_attn, [flash_attn_kvpacked_func, dense_inputs, None, False]) - flash_bwd_latency = flash_latency - flash_fwd_latency - minfer_latency = profile(call_attn, [minference_flash_attn_triton_kvpacked_func, sparse_inputs, grad, True]) - minfer_fwd_latency = profile(call_attn, [minference_flash_attn_triton_kvpacked_func, sparse_inputs, None, False]) - minfer_idx_latency = profile(build_index_local, [q, kv[:, :, 0], v_size, s_size, num_tokens, granularity]) - minfer_bwd_latency = minfer_latency - minfer_fwd_latency - minfer_fwd_latency = minfer_fwd_latency - minfer_idx_latency - import pandas as pd - df = pd.DataFrame( - data=[ - [minfer_idx_latency, minfer_fwd_latency, minfer_bwd_latency], - [0, flash_fwd_latency, flash_bwd_latency], - [None, minfer_fwd_latency / flash_fwd_latency, minfer_bwd_latency / flash_bwd_latency] - ], - index=['MInfer', 'Flash', 'Ratio'], - columns=['Index', 'Forward', 'Backward'], - ).round(2) - print("-" * 64) - print(df) - - -if __name__ == '__main__': - print("=" * 64) - test_minference_attn(1, 131072, 4, 2, 128, sparsity=0.998, check_results=False, profile_latency=True) - \ No newline at end of file diff --git a/minference/ops/utils.py b/minference/ops/utils.py index 68503ef..ebe4d7e 100644 --- a/minference/ops/utils.py +++ b/minference/ops/utils.py @@ -1,994 +1,6 
@@ -import os -import numpy as np -from typing import List -from functools import lru_cache - import torch -import torch.distributed as dist - -import triton -import triton.language as tl - def set_seed(seed=42): torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) - - -@triton.jit -def _triton_extract_kv_kernel( - local_k, local_v, bar_k, bar_v, v_idx, v_cnt, - stride_lz, stride_ln, stride_lh, stride_ld, - stride_bz, stride_bn, stride_bh, stride_bd, - stride_iz, stride_ih, stride_in, - stride_cz, stride_ch, stride_cr, - step, num_tokens, num_qo_heads, num_kv_heads, - BLOCK_N: tl.constexpr, BLOCK_D: tl.constexpr, -): - start_n = tl.program_id(0) - qo_head_idx = tl.program_id(1) - batch_idx = tl.program_id(2) - kv_head_idx = qo_head_idx // (num_qo_heads // num_kv_heads) - - v_cnt_ptr = v_cnt + batch_idx * stride_cz + qo_head_idx * stride_ch - min_n = tl.load(v_cnt_ptr + step * stride_cr) - max_n = tl.load(v_cnt_ptr + (step + 1) * stride_cr) - start_n = start_n * BLOCK_N - end_n = start_n + BLOCK_N - if start_n >= max_n or end_n <= min_n: - return - - offs_d = tl.arange(0, BLOCK_D) - offs_n = start_n + tl.arange(0, BLOCK_N) - mask_n = (offs_n >= min_n) & (offs_n < max_n) - - v_idx_ptr = v_idx + batch_idx * stride_iz + qo_head_idx * stride_ih - local_k_ptr = local_k + batch_idx * stride_lz + kv_head_idx * stride_lh + offs_d[None, :] * stride_ld - local_v_ptr = local_v + batch_idx * stride_lz + kv_head_idx * stride_lh + offs_d[None, :] * stride_ld - bar_k_ptr = bar_k + batch_idx * stride_bz + qo_head_idx * stride_bh + offs_d[None, :] * stride_bd - bar_v_ptr = bar_v + batch_idx * stride_bz + qo_head_idx * stride_bh + offs_d[None, :] * stride_bd - - # idx = tl.load(v_idx_ptr + offs_n * stride_in, mask=mask_n, other=0) - step * num_tokens - idx = tl.load(v_idx_ptr + offs_n * stride_in, mask=mask_n, other=0) % num_tokens - k = tl.load(local_k_ptr + idx[:, None] * stride_ln, mask=mask_n[:, None], other=0.) - v = tl.load(local_v_ptr + idx[:, None] * stride_ln, mask=mask_n[:, None], other=0.) 
- tl.store(bar_k_ptr + offs_n[:, None] * stride_bn, k, mask=mask_n[:, None]) - tl.store(bar_v_ptr + offs_n[:, None] * stride_bn, v, mask=mask_n[:, None]) - - -def extract_kv( - local_k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] - local_v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] - bar_k: torch.Tensor, # [batch_size, max_v_size, num_qo_heads, head_dim] - bar_v: torch.Tensor, # [batch_size, max_v_size, num_qo_heads, head_dim] - v_idx: torch.Tensor, # [batch_size, num_qo_heads, max_v_size] - v_cnt: torch.Tensor, # [batch_size, num_qo_heads, world_size + 1] - step: int, -): - batch_size, max_v_size, num_qo_heads, head_dim = bar_k.shape - _, num_tokens, num_kv_heads, _ = local_k.shape - block_N = 128 - block_D = head_dim - _triton_extract_kv_kernel[(triton.cdiv(max_v_size, block_N), num_qo_heads, batch_size)]( - local_k, local_v, bar_k, bar_v, v_idx, v_cnt, - local_k.stride(0), local_k.stride(1), local_k.stride(2), local_k.stride(3), - bar_k.stride(0), bar_k.stride(1), bar_k.stride(2), bar_k.stride(3), - v_idx.stride(0), v_idx.stride(1), v_idx.stride(2), - v_cnt.stride(0), v_cnt.stride(1), v_cnt.stride(2), - step, num_tokens, num_qo_heads, num_kv_heads, - BLOCK_N=block_N, BLOCK_D=block_D, - num_warps=4, num_stages=1, - ) - - -@triton.jit -def _triton_merge_kv_kernel( - local_k, local_v, bar_k, bar_v, v_idx, v_cnt, - stride_lz, stride_ln, stride_lh, stride_ld, - stride_bz, stride_bn, stride_bh, stride_bd, - stride_iz, stride_ih, stride_in, - stride_cz, stride_ch, stride_cr, - step, num_tokens, num_qo_heads, num_kv_heads, - BLOCK_N: tl.constexpr, BLOCK_D: tl.constexpr, -): - start_n = tl.program_id(0) - qo_head_idx = tl.program_id(1) - batch_idx = tl.program_id(2) - kv_head_idx = qo_head_idx // (num_qo_heads // num_kv_heads) - - v_cnt_ptr = v_cnt + batch_idx * stride_cz + qo_head_idx * stride_ch - min_n = tl.load(v_cnt_ptr + step * stride_cr) - max_n = tl.load(v_cnt_ptr + (step + 1) * stride_cr) - start_n = start_n * BLOCK_N - end_n = start_n + BLOCK_N - if start_n >= max_n or end_n <= min_n: - return - - offs_d = tl.arange(0, BLOCK_D) - offs_n = start_n + tl.arange(0, BLOCK_N) - mask_n = (offs_n >= min_n) & (offs_n < max_n) - - v_idx_ptr = v_idx + batch_idx * stride_iz + qo_head_idx * stride_ih - local_k_ptr = local_k + batch_idx * stride_lz + kv_head_idx * stride_lh + offs_d[None, :] * stride_ld - local_v_ptr = local_v + batch_idx * stride_lz + kv_head_idx * stride_lh + offs_d[None, :] * stride_ld - bar_k_ptr = bar_k + batch_idx * stride_bz + qo_head_idx * stride_bh + offs_d[None, :] * stride_bd - bar_v_ptr = bar_v + batch_idx * stride_bz + qo_head_idx * stride_bh + offs_d[None, :] * stride_bd - - # idx = tl.load(v_idx_ptr + offs_n * stride_in, mask=mask_n, other=0) - step * num_tokens - idx = tl.load(v_idx_ptr + offs_n * stride_in, mask=mask_n, other=0) % num_tokens - k = tl.load(bar_k_ptr + offs_n[:, None] * stride_bn, mask=mask_n[:, None], other=0.).to(local_k.type.element_ty) - v = tl.load(bar_v_ptr + offs_n[:, None] * stride_bn, mask=mask_n[:, None], other=0.).to(local_v.type.element_ty) - tl.atomic_add(local_k_ptr + idx[:, None] * stride_ln, k, mask=mask_n[:, None], sem="relaxed") - tl.atomic_add(local_v_ptr + idx[:, None] * stride_ln, v, mask=mask_n[:, None], sem="relaxed") - - -def merge_kv( - local_k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] - local_v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] - bar_k: torch.Tensor, # [batch_size, max_v_size, num_qo_heads, head_dim] - bar_v: torch.Tensor, # 
[batch_size, max_v_size, num_qo_heads, head_dim] - v_idx: torch.Tensor, # [batch_size, num_qo_heads, max_v_size] - v_cnt: torch.Tensor, # [batch_size, num_qo_heads, world_size + 1] - step: int, -): - batch_size, max_v_size, num_qo_heads, head_dim = bar_k.shape - _, num_tokens, num_kv_heads, _ = local_k.shape - block_N = 128 - block_D = head_dim - _triton_merge_kv_kernel[(triton.cdiv(max_v_size, block_N), num_qo_heads, batch_size)]( - local_k, local_v, bar_k, bar_v, v_idx, v_cnt, - local_k.stride(0), local_k.stride(1), local_k.stride(2), local_k.stride(3), - bar_k.stride(0), bar_k.stride(1), bar_k.stride(2), bar_k.stride(3), - v_idx.stride(0), v_idx.stride(1), v_idx.stride(2), - v_cnt.stride(0), v_cnt.stride(1), v_cnt.stride(2), - step, num_tokens, num_qo_heads, num_kv_heads, - BLOCK_N=block_N, BLOCK_D=block_D, - num_warps=4, num_stages=1, - ) - - -# triton.cdiv(world_size * num_blocks, BLOCK_N), num_heads, batch_size -# block_mask: [batch_size, num_heads, num_blocks_global] -@triton.jit -def _calc_block_mask_kernel( - s_idx, block_mask, - stride_sz, stride_sh, stride_sk, - stride_bz, stride_bh, stride_bn, - max_s_size, num_tokens, granularity, - BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, -): - batch_idx = tl.program_id(2) - head_idx = tl.program_id(1) - group_idx = tl.program_id(0) - - block_offs = tl.arange(0, BLOCK_N) - slash_offs = tl.arange(0, BLOCK_K) - - s_idx_ptr = s_idx + batch_idx * stride_sz + head_idx * stride_sh - block_mask_ptr = block_mask + batch_idx * stride_bz + head_idx * stride_bh - block_idx = group_idx * BLOCK_N + block_offs - - blocks = tl.zeros([BLOCK_N], dtype=tl.uint8) - for s_off in range(0, max_s_size, BLOCK_K): - s = tl.load(s_idx_ptr + (s_off + slash_offs) * stride_sk) - left = (num_tokens - granularity - s) // granularity - right = (num_tokens - 1 - s) // granularity - - # mask is generated by checking if a block's index falls between the calculated ranges - blocks |= tl.max((block_idx[None, :] >= left[:, None]) & (block_idx[None, :] <= right[:, None]), 0).to(tl.uint8) - - b_mask = (group_idx * BLOCK_N + block_offs) * granularity < num_tokens - tl.store(block_mask_ptr + (group_idx * BLOCK_N + block_offs) * stride_bn, blocks, mask=b_mask) - - -@triton.jit -def _striped_convert_indices_kernel( - last_row_mask, v_idx, v_cnt, - block_mask, bar_idx, bar_pos, bar_cnt, - stride_rz, stride_rh, stride_rn, - stride_vz, stride_vh, stride_vk, - stride_nz, stride_nh, stride_nt, - stride_bt, stride_bz, stride_bh, stride_bm, stride_bn, - stride_iz, stride_ih, stride_im, stride_ik, - stride_cz, stride_ch, stride_cm, stride_ct, - max_v_size, num_blocks, granularity, world_size, rank, - BLOCK_N: tl.constexpr, -): - batch_idx = tl.program_id(2) - head_idx = tl.program_id(1) - block_idx_q_local = tl.program_id(0) - - block_idx_q_global = block_idx_q_local * world_size + rank - - num_tokens_local = num_blocks * granularity - num_blocks_global = world_size * num_blocks - shift = num_blocks_global - 1 - block_idx_q_global - - block_offs = tl.arange(0, BLOCK_N) - - last_row_mask_ptr = last_row_mask + batch_idx * stride_rz + head_idx * stride_rh - v_idx_ptr = v_idx + batch_idx * stride_vz + head_idx * stride_vh - v_cnt_ptr = v_cnt + batch_idx * stride_nz + head_idx * stride_nh - block_mask_ptr = block_mask + batch_idx * stride_bz + head_idx * stride_bh + block_idx_q_local * stride_bm - bar_idx_ptr = bar_idx + batch_idx * stride_iz + head_idx * stride_ih + block_idx_q_local * stride_im - bar_pos_ptr = bar_pos + batch_idx * stride_iz + head_idx * stride_ih + block_idx_q_local * 
stride_im - bar_cnt_ptr = bar_cnt + batch_idx * stride_cz + head_idx * stride_ch + block_idx_q_local * stride_cm - - cnt_valid = 0 - cnt_all = 0 - v_off = 0 - v = tl.load(v_idx_ptr + cnt_all * stride_vk) - cnt_all += 1 - - tl.store(bar_cnt_ptr, cnt_valid) - bar_cnt_ptr += stride_ct - if block_idx_q_local == tl.num_programs(0) - 1: - tl.store(v_cnt_ptr, cnt_all - 1) - v_cnt_ptr += stride_nt - - for step in range(world_size): - for block_off_k in range(0, num_blocks, BLOCK_N): - block_idx_k_local = block_off_k + block_offs - block_idx_k_global = (block_off_k + block_offs) * world_size + step - mask_local = tl.load( - last_row_mask_ptr + (block_idx_k_global + shift) * stride_rn, - mask=(block_idx_k_global + shift < num_blocks_global), - other=0, - ) - tl.store( - block_mask_ptr + block_idx_k_local * stride_bn, - mask_local, - mask=(block_idx_k_local < num_blocks), - ) - block_left = v_off + block_idx_k_local * granularity - block_right = block_left + granularity - max_blocks = block_idx_q_local + 1 if step <= rank else block_idx_q_local - v_max = v_off + min(block_off_k + BLOCK_N, max_blocks) * granularity - while v < v_max and cnt_all < max_v_size: - if tl.max(((v >= block_left) & (v < block_right)) & (~mask_local), 0): - tl.store(bar_idx_ptr + cnt_valid * stride_ik, v - v_off) - tl.store(bar_pos_ptr + cnt_valid * stride_ik, cnt_all - 1) - cnt_valid += 1 - v = tl.load(v_idx_ptr + cnt_all * stride_vk) - cnt_all += 1 - block_mask_ptr += stride_bt - tl.store(bar_cnt_ptr, cnt_valid) - bar_cnt_ptr += stride_ct - v_off += num_tokens_local - if block_idx_q_local == tl.num_programs(0) - 1: - tl.store(v_cnt_ptr, cnt_all - 1) - v_cnt_ptr += stride_nt - - -@triton.jit -def _zigzag_convert_indices_kernel( - last_row_mask, v_idx, v_cnt, - block_mask, bar_idx, bar_pos, bar_cnt, - stride_rz, stride_rh, stride_rn, - stride_vz, stride_vh, stride_vk, - stride_nz, stride_nh, stride_nt, - stride_bt, stride_bz, stride_bh, stride_bm, stride_bn, - stride_iz, stride_ih, stride_im, stride_ik, - stride_cz, stride_ch, stride_cm, stride_ct, - max_v_size, num_blocks, granularity, world_size, rank, - BLOCK_N: tl.constexpr, -): - batch_idx = tl.program_id(2) - head_idx = tl.program_id(1) - block_idx_q_local = tl.program_id(0) - - if rank < world_size // 2: - revert_rank = rank * 2 - else: - revert_rank = (world_size - 1 - rank) * 2 + 1 - if block_idx_q_local < num_blocks // 2: - block_idx_q_global = revert_rank * (num_blocks // 2) + block_idx_q_local - else: - block_idx_q_global = (world_size * 2 - 1 - revert_rank) * (num_blocks // 2) + block_idx_q_local - (num_blocks // 2) - - num_blocks_global = world_size * num_blocks - shift = num_blocks_global - 1 - block_idx_q_global - - block_offs = tl.arange(0, BLOCK_N) - - last_row_mask_ptr = last_row_mask + batch_idx * stride_rz + head_idx * stride_rh - v_idx_ptr = v_idx + batch_idx * stride_vz + head_idx * stride_vh - v_cnt_ptr = v_cnt + batch_idx * stride_nz + head_idx * stride_nh - block_mask_ptr = block_mask + batch_idx * stride_bz + head_idx * stride_bh + block_idx_q_local * stride_bm - bar_idx_ptr = bar_idx + batch_idx * stride_iz + head_idx * stride_ih + block_idx_q_local * stride_im - bar_pos_ptr = bar_pos + batch_idx * stride_iz + head_idx * stride_ih + block_idx_q_local * stride_im - bar_cnt_ptr = bar_cnt + batch_idx * stride_cz + head_idx * stride_ch + block_idx_q_local * stride_cm - - cnt_valid = 0 - cnt_all = 0 - v = tl.load(v_idx_ptr + cnt_all * stride_vk) - cnt_all += 1 - - tl.store(bar_cnt_ptr, cnt_valid) - bar_cnt_ptr += stride_ct - if block_idx_q_local == 
tl.num_programs(0) - 1: - tl.store(v_cnt_ptr, cnt_all - 1) - v_cnt_ptr += stride_nt - - for step in range(world_size): - v_off = step * num_blocks * granularity - v_end = v_off + num_blocks * granularity - for block_off_k in range(0, num_blocks, BLOCK_N): - block_idx_k_local = block_off_k + block_offs - # assert BLOCK_N <= num_blocks // 2 - if block_off_k < num_blocks // 2: - v_off_global = step * (num_blocks // 2) * granularity - block_idx_k_global = step * (num_blocks // 2) + block_idx_k_local - else: - v_off_global = (world_size * 2 - 2 - step) * (num_blocks // 2) * granularity - block_idx_k_global = (world_size * 2 - 1 - step) * (num_blocks // 2) + block_idx_k_local - (num_blocks // 2) - mask_local = tl.load( - last_row_mask_ptr + (block_idx_k_global + shift) * stride_rn, - mask=(block_idx_k_global + shift < num_blocks_global), - other=0, - ) - tl.store( - block_mask_ptr + block_idx_k_local * stride_bn, - mask_local, - mask=(block_idx_k_local < num_blocks), - ) - # block_left = block_idx_k_global * granularity - v_off_global + v_off - # block_right = block_left + granularity - block_left = v_off + block_idx_k_local * granularity - block_right = block_left + granularity - v_max = (block_idx_q_global + 1) * granularity - v_off_global + v_off - while v < v_end and cnt_all <= max_v_size: - if v < v_max: - if tl.max(((v >= block_left) & (v < block_right)) & (~mask_local), 0): - tl.store(bar_idx_ptr + cnt_valid * stride_ik, v - v_off) - tl.store(bar_pos_ptr + cnt_valid * stride_ik, cnt_all - 1) - cnt_valid += 1 - v = tl.load(v_idx_ptr + cnt_all * stride_vk) - cnt_all += 1 - block_mask_ptr += stride_bt - tl.store(bar_cnt_ptr, cnt_valid) - bar_cnt_ptr += stride_ct - if block_idx_q_local == tl.num_programs(0) - 1: - tl.store(v_cnt_ptr, cnt_all - 1) - v_cnt_ptr += stride_nt - - -def convert_indices( - v_idx: torch.Tensor, # [batch_size, num_heads, max_v_size] - s_idx: torch.Tensor, # [batch_size, num_heads, max_s_size] - world_size: int, - rank: int, - num_blocks: int, - granularity: int, - num_tokens: int = None, - stripe_transform: bool = False, - zigzag_transform: bool = False, -): - num_blocks_global = world_size * num_blocks - if num_tokens is None: - # Note that for each invokation of `convert_indices`, `num_tokens` is None and becomes the **global number of tokens** - num_tokens = num_blocks_global * granularity - batch_size, num_heads, max_v_size = v_idx.shape - batch_size, num_heads, max_s_size = s_idx.shape - last_row_mask = torch.zeros((batch_size, num_heads, num_blocks_global), dtype=torch.bool, device=s_idx.device) - - BLOCK_N, BLOCK_K = 128, 128 - assert max_s_size <= BLOCK_K * BLOCK_K, f"max_s_size={max_s_size} > BLOCK_K * BLOCK_K={BLOCK_K * BLOCK_K}" - _calc_block_mask_kernel[(triton.cdiv(num_blocks_global, BLOCK_N), num_heads, batch_size)]( - s_idx, last_row_mask, - s_idx.stride(0), s_idx.stride(1), s_idx.stride(2), - last_row_mask.stride(0), last_row_mask.stride(1), last_row_mask.stride(2), - max_s_size, num_tokens, granularity, - BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, - num_warps=4, num_stages=2, - ) - - block_mask = torch.zeros((world_size, batch_size, num_heads, num_blocks, num_blocks), dtype=torch.bool, device=v_idx.device) - bar_idx = torch.zeros((batch_size, num_heads, num_blocks, max_v_size), dtype=torch.int32, device=v_idx.device) - bar_cnt = torch.empty((batch_size, num_heads, num_blocks, world_size + 1), dtype=torch.int32, device=v_idx.device) - v_cnt = torch.empty((batch_size, num_heads, world_size + 1), dtype=torch.int32, device=v_idx.device) - bar_pos = 
torch.zeros_like(bar_idx) - if zigzag_transform: - convert_indices_kernel = _zigzag_convert_indices_kernel - assert num_blocks % 2 == 0 - BLOCK_N = max(num_blocks // 2, 128) - else: - convert_indices_kernel = _striped_convert_indices_kernel - BLOCK_N = 128 - convert_indices_kernel[(num_blocks, num_heads, batch_size)]( - last_row_mask, v_idx, v_cnt, block_mask, bar_idx, bar_pos, bar_cnt, - last_row_mask.stride(0), last_row_mask.stride(1), last_row_mask.stride(2), - v_idx.stride(0), v_idx.stride(1), v_idx.stride(2), - v_cnt.stride(0), v_cnt.stride(1), v_cnt.stride(2), - block_mask.stride(0), block_mask.stride(1), block_mask.stride(2), block_mask.stride(3), block_mask.stride(4), - bar_idx.stride(0), bar_idx.stride(1), bar_idx.stride(2), bar_idx.stride(3), - bar_cnt.stride(0), bar_cnt.stride(1), bar_cnt.stride(2), bar_cnt.stride(3), - max_v_size, num_blocks, granularity, world_size, rank, BLOCK_N=BLOCK_N, - num_warps=1, num_stages=1, - ) - # if zigzag_transform: - # if rank == 0: - # import ipdb; ipdb.set_trace() - # torch.save(block_mask, f'./output/data/block_mask_{rank}.pt') - # torch.save(bar_idx, f'./output/data/bar_idx_{rank}.pt') - # torch.save(bar_cnt, f'./output/data/bar_cnt_{rank}.pt') - # elif rank == 0: - # torch.save(block_mask, f'./output/data/block_mask.pt') - # torch.save(bar_idx, f'./output/data/bar_idx.pt') - # torch.save(bar_cnt, f'./output/data/bar_cnt.pt') - # bar_cnt = torch.zeros_like(bar_cnt) - return block_mask, bar_idx, bar_cnt, bar_pos, v_cnt - - -def _torch_convert_indices( - v_idx: torch.Tensor, # [batch_size, num_heads, max_v_size] - s_idx: torch.Tensor, # [batch_size, num_heads, max_s_size] - world_size: int, - rank: int, - num_blocks: int, - granularity: int, -): - batch_size, num_heads, max_v_size = v_idx.shape - num_tokens = world_size * num_blocks * granularity - block_mask = torch.zeros((world_size, batch_size, num_heads, num_blocks, num_blocks), dtype=torch.bool, device=v_idx.device) - bar_idx = torch.zeros((batch_size, num_heads, num_blocks, max_v_size), dtype=torch.int32, device=v_idx.device) - bar_cnt = torch.zeros((batch_size, num_heads, num_blocks, world_size + 1), dtype=torch.int32, device=v_idx.device) - for batch_idx in range(batch_size): - for head_idx in range(num_heads): - for block_idx_q in range(num_blocks): - block_idx_q_global = block_idx_q * world_size + rank - cnt_all, cnt_valid = 0, 0 - for step in range(world_size): - for block_idx_k in range(block_idx_q + 1): - block_idx_k_global = block_idx_k * world_size + step - s_min = max((block_idx_q_global - block_idx_k_global - 1) * granularity, 0) - s_max = (block_idx_q_global - block_idx_k_global + 1) * granularity - flag = torch.any((s_idx[batch_idx, head_idx] > s_min) & (s_idx[batch_idx, head_idx] < s_max)) - block_mask[step, batch_idx, head_idx, block_idx_q, block_idx_k] = flag - v_min = (step * num_blocks + block_idx_k) * granularity - max_blocks = block_idx_q + 1 if step <= rank else block_idx_q - v_max = (step * num_blocks + min(block_idx_k + 1, max_blocks)) * granularity - while cnt_all < max_v_size and v_idx[batch_idx, head_idx, cnt_all] < v_min: - cnt_all += 1 - while cnt_all < max_v_size and v_idx[batch_idx, head_idx, cnt_all] < v_max: - if not flag: - bar_idx[batch_idx, head_idx, block_idx_q, cnt_valid] = \ - v_idx[batch_idx, head_idx, cnt_all] - step * num_blocks * granularity - cnt_valid += 1 - cnt_all += 1 - bar_cnt[batch_idx, head_idx, block_idx_q, step + 1] = cnt_valid - return block_mask, bar_idx, bar_cnt - - - -def sum_all_diagonal_matrix(mat: torch.Tensor): - b, h, m, n = 
mat.shape - - # Pads the matrix on left and right (on the last dimension) - mat_padded = torch.nn.functional.pad(mat, (m, m), "constant", 0.) # shape: [b, h, m, 2 * m + n] - # Change the strides - mat_strided = mat_padded.as_strided((b, h, m, m + n), (m * (2 * m + n) * h, m * (2 * m + n), 2 * m + n + 1, 1)) - # Sums the resulting matrix's columns - sum_diags = torch.sum(mat_strided, 2) # shape: [b, h, m + n] - return sum_diags[:, :, 1:].contiguous() - -def calc_index( - q: torch.Tensor, - k: torch.Tensor, - v_size: List[int], - s_size: List[int], - last_q_size: int = 64, - sink_tokens: int = 30, - sliding_window: int = 100, - group: dist.group = None, - stripe_transform: bool = False, - zigzag_transform: bool = False, - granularity: int = 128, -): - # TODO: adapt naturely striped inputs - # TODO: flex-prefill (top-P) - # TODO: reduce bubble - # TODO: support total_num_tokens % world_size != 0 - batch_size, num_tokens, num_qo_heads, head_dim = q.shape - num_kv_heads = k.shape[2] - rank = dist.get_rank(group) - world_size = dist.get_world_size(group) - - if all([type(x) is list for x in v_size]) and all([type(x) is list for x in s_size]): - flex_prefill = True - v_p = [x[0] for x in v_size] - v_size = [x[1] for x in v_size] - s_p = [x[0] for x in s_size] - s_size = [x[1] for x in s_size] - else: - flex_prefill = False - assert all([type(x) is int for x in v_size]) and all([type(x) is int for x in s_size]) - - max_v_size = min(triton.cdiv(max(v_size), 128), num_tokens // 128) * 128 - max_s_size = min(triton.cdiv(max(s_size), 128), num_tokens // 128) * 128 - - last_rank = world_size - 1 - if rank == last_rank: - last_q = q[:, -last_q_size:, :, :].detach().clone().reshape((batch_size, last_q_size, num_kv_heads, -1, head_dim)) - else: - last_q = torch.zeros((batch_size, last_q_size, num_kv_heads, num_qo_heads // num_kv_heads, head_dim), device=q.device, dtype=q.dtype) - - if os.getenv("COMM_DEBUG", False): - # For debugging purposes, print the rank and tensor shapes - rank = dist.get_rank(group) - print(f"Rank {rank} | calc_index | before invoking broadcast last_q from rank={last_rank}", flush=True) - dist.broadcast(last_q, src=last_rank, group=group, async_op=False) - - qk = torch.einsum('bmghd, bngd -> bghmn', last_q, k) * (k.shape[-1] ** -0.5) - qk = qk.reshape((batch_size, num_qo_heads, last_q_size, num_tokens)) - - if rank == last_rank: - # Causal Mask, requires num_tokens // world_size >= last_q - arange = torch.arange(last_q_size, device=k.device) - mask = arange[None, None, :, None] >= arange[None, None, None, :] - qk[:, :, :, -last_q_size:] = torch.where(mask, qk[:, :, :, -last_q_size:], -torch.inf) - if flex_prefill: # qk = torch.softmax(qk, dim=-1) / last_q_size - qk_max = torch.max(qk, dim=-1, keepdim=True).values - qk_max_list = [torch.empty_like(qk_max) for _ in range(world_size)] - dist.all_gather(qk_max_list, qk_max, group=group, async_op=False) - qk_max = torch.max(torch.stack(qk_max_list), dim=0).values - qk = torch.exp(qk - qk_max) - qk_sum = torch.sum(qk, dim=-1, keepdim=True) - qk_sum_list = [torch.empty_like(qk_sum) for _ in range(world_size)] - dist.all_gather(qk_sum_list, qk_sum, group=group, async_op=False) - qk_sum = torch.sum(torch.stack(qk_sum_list), dim=0) - qk /= (qk_sum * last_q_size) - - v_gather_rank = 0 - vertical = qk.sum(-2, keepdim=False) # [B, H, N_LOCAL] - if rank == 0 and not flex_prefill: - vertical[..., :sink_tokens] = torch.inf - if rank == v_gather_rank: - gathered_vertical = [torch.empty_like(vertical) for _ in range(world_size)] - else: - 
gathered_vertical = None - if os.getenv("COMM_DEBUG", False): - # For debugging purposes, print the rank and tensor shapes - rank = dist.get_rank(group) - print(f"Rank {rank} | calc_index | before invoking gather vertical to {v_gather_rank}", flush=True) - dist.gather(vertical, gathered_vertical, dst=v_gather_rank, group=group, async_op=False) - - if rank == v_gather_rank: - vertical: torch.Tensor = torch.cat(gathered_vertical, dim=-1) - if stripe_transform: - vertical = vertical.reshape((batch_size, num_qo_heads, -1, world_size, granularity)) - vertical = vertical.swapaxes(2, 3) - vertical = vertical.reshape((batch_size, num_qo_heads, -1)) - elif zigzag_transform: - vertical = vertical.reshape((batch_size, num_qo_heads, 2, world_size, -1)) - chunks = [] - for step in range(world_size): - chunks.append(vertical[:, :, 0, step]) - chunks.append(vertical[:, :, 1, world_size - 1 - step]) - vertical = torch.concat(chunks, dim=2).reshape((batch_size, num_qo_heads, -1)) - - v_topk = torch.topk(vertical, max_v_size, -1, sorted=True) - v_indices = v_topk.indices.to(torch.int32) - if flex_prefill: - v_cumsum = v_topk.values.cumsum_(dim=-1) - v_size = (v_cumsum < torch.tensor(v_p, device=k.device)[None, :, None]).sum(dim=-1, keepdim=True) - else: - v_size = torch.tensor(v_size, device=k.device)[None, :, None] - v_arange = torch.arange(max_v_size, device=k.device) - v_indices.masked_fill_(v_arange[None, None, :] >= v_size, num_tokens * world_size) - v_indices = v_indices.sort(dim=-1, descending=False).values - else: - v_indices = torch.empty((batch_size, num_qo_heads, max_v_size), dtype=torch.int32, device=k.device) - if os.getenv("COMM_DEBUG", False): - # For debugging purposes, print the rank and tensor shapes - rank = dist.get_rank(group) - print(f"Rank {rank} | calc_index | before invoking broadcast v_indices from rank={v_gather_rank}", flush=True) - dist.broadcast(v_indices, src=v_gather_rank, group=group, async_op=False) # async - - s_gather_rank = 0 - slash = sum_all_diagonal_matrix(qk) # shape: [B, H, N_LOCAL + LAST_Q_SIZE - 1] - if rank == world_size - 1 and not flex_prefill: - # -> index starting from the left bottom corner to right upper corner - # (sliding_window) from -(last_q_size-1) is the sliding window close to the main diagonal - slash[..., -(last_q_size - 1 + sliding_window):] = torch.inf - - - if rank == s_gather_rank: - gathered_slash = [torch.empty_like(slash) for _ in range(world_size)] - else: - gathered_slash = None - - if os.getenv("COMM_DEBUG", False): - # For debugging purposes, print the rank and tensor shapes - rank = dist.get_rank(group) - print(f"Rank {rank} | calc_index | before invoking gather slash to rank=0", flush=True) - dist.gather(slash, gathered_slash, dst=s_gather_rank, group=group, async_op=False) - - if rank == s_gather_rank: - slash = gathered_slash[0] - for next_slash in gathered_slash[1:]: - slash[..., -last_q_size + 1:] += next_slash[..., :last_q_size - 1] - slash = torch.cat((slash, next_slash[..., last_q_size - 1:]), dim=-1) - - # slash presents the sum of attention from 0-th to (num_tokens_global - last_q_size - 1), where 0 represents the diagonal at bottom left corner - slash = slash[..., :-last_q_size + 1] - s_topk = torch.topk(slash, max_s_size, -1, sorted=True) - - # s_indices contain indices starting from the right upper corner to left bottom corner - s_indices = (num_tokens * world_size - 1) - s_topk.indices.to(torch.int32) - if flex_prefill: - s_cumsum = s_topk.values.cumsum_(dim=-1) - s_size = (s_cumsum < torch.tensor(s_p, 
device=k.device)[None, :, None]).sum(dim=-1, keepdim=True) - else: - s_size = torch.tensor(s_size, device=k.device)[None, :, None] - s_arange = torch.arange(max_s_size, device=k.device) - s_indices.masked_fill_(s_arange[None, None, :] >= s_size, -1) - s_indices = s_indices.sort(dim=-1, descending=True).values - else: - s_indices = torch.empty((batch_size, num_qo_heads, max_s_size), dtype=torch.int32, device=k.device) - if os.getenv("COMM_DEBUG", False): - # For debugging purposes, print the rank and tensor shapes - rank = dist.get_rank(group) - print(f"Rank {rank} | calc_index | before invoking broadcast s_indices from rank={s_gather_rank}", flush=True) - dist.broadcast(s_indices, src=s_gather_rank, group=group, async_op=False) - - return v_indices.to(torch.int32), s_indices.to(torch.int32) - -def calc_index_local( - q: torch.Tensor, - k: torch.Tensor, - v_size: List[int], - s_size: List[int], - last_q_size: int = 64, - sink_tokens: int = 30, - sliding_window: int = 100, - group: dist.group = None, - stripe_transform: bool = False, - zigzag_transform: bool = False, - granularity: int = 128, -): - batch_size, num_tokens, num_qo_heads, head_dim = q.shape - num_kv_heads = k.shape[2] - - if all([type(x) is list for x in v_size]) and all([type(x) is list for x in s_size]): - flex_prefill = True - v_p = [x[0] for x in v_size] - v_size = [x[1] for x in v_size] - s_p = [x[0] for x in s_size] - s_size = [x[1] for x in s_size] - else: - flex_prefill = False - assert all([type(x) is int for x in v_size]) and all([type(x) is int for x in s_size]) - - qk = torch.einsum( - f'bmghd, bngd -> bghmn', - q[:, -last_q_size:, :, :].reshape((batch_size, last_q_size, num_kv_heads, -1, head_dim)), - k, - ).reshape((batch_size, num_qo_heads, last_q_size, num_tokens)) * (head_dim ** -0.5) - - arange = torch.arange(last_q_size, device=k.device) - mask = arange[None, None, :, None] >= arange[None, None, None, :] - qk[:, :, :, -last_q_size:] = torch.where(mask, qk[:, :, :, -last_q_size:], -torch.inf) - if flex_prefill: - qk = torch.softmax(qk, dim=-1) / last_q_size - - max_v_size = min(max(v_size), num_tokens) - max_v_size = triton.cdiv(max_v_size, 128) * 128 - vertical = qk.sum(-2, keepdim=False) - if not flex_prefill: - vertical[..., :sink_tokens] = torch.inf - if stripe_transform: - vertical = vertical.reshape((batch_size, num_qo_heads, -1, dist.get_world_size(group), granularity)) - vertical = vertical.swapaxes(2, 3) - vertical = vertical.reshape((batch_size, num_qo_heads, -1)) - elif zigzag_transform: - vertical = vertical.reshape((batch_size, num_qo_heads, 2, dist.get_world_size(group), -1)) - chunks = [] - for step in range(dist.get_world_size(group)): - chunks.append(vertical[:, :, 0, step]) - chunks.append(vertical[:, :, 1, dist.get_world_size(group) - 1 - step]) - vertical = torch.concat(chunks, dim=2).reshape((batch_size, num_qo_heads, -1)) - v_topk = torch.topk(vertical, max_v_size, -1, sorted=True) - v_indices = v_topk.indices - if flex_prefill: - v_cumsum = v_topk.values.cumsum_(dim=-1) - v_size = (v_cumsum < torch.tensor(v_p, device=k.device)[None, :, None]).sum(dim=-1, keepdim=True) - else: - v_size = torch.tensor(v_size, device=k.device)[None, :, None] - - max_s_size = min(max(s_size), num_tokens) - max_s_size = triton.cdiv(max_s_size, 128) * 128 - slash = sum_all_diagonal_matrix(qk)[..., :-last_q_size + 1] - if not flex_prefill: - slash[..., -sliding_window:] = torch.inf - s_topk = torch.topk(slash, max_s_size, -1, sorted=True) - s_indices = (num_tokens - 1) - s_topk.indices - if flex_prefill: - 
s_cumsum = s_topk.values.cumsum_(dim=-1) - s_size = (s_cumsum < torch.tensor(s_p, device=k.device)[None, :, None]).sum(dim=-1, keepdim=True) - else: - s_size = torch.tensor(s_size, device=k.device)[None, :, None] - - v_arange = torch.arange(max_v_size, device=k.device) - v_idx = v_indices.to(torch.int32).reshape((batch_size, num_qo_heads, -1)) - v_idx.masked_fill_(v_arange[None, None, :] >= v_size, 2147483647) - v_idx = v_idx.sort(dim=-1, descending=False).values - - s_arange = torch.arange(max_s_size, device=k.device) - s_idx = s_indices.to(torch.int32).reshape((batch_size, num_qo_heads, -1)) - s_idx.masked_fill_(s_arange[None, None, :] >= s_size, -1) - s_idx = s_idx.sort(dim=-1, descending=True).values - - return v_idx, s_idx - - -def profile(func, inputs, num_warmups=100, num_iters=100): - torch.cuda.synchronize() - for _ in range(num_warmups): - func(*inputs) - torch.cuda.synchronize() - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - for _ in range(num_iters): - func(*inputs) - end.record() - torch.cuda.synchronize() - latency = start.elapsed_time(end) / num_iters - return latency - -def build_index_local( - q: torch.Tensor, - k: torch.Tensor, - v_size: List[int], - s_size: List[int], - num_tokens: int, - granularity: int, - world_size: int = 1, - rank: int = 0, -): - if type(v_size) is list: - assert len(v_size) == q.shape[2] - assert len(s_size) == q.shape[2] - v_idx, s_idx = calc_index_local(q, k, v_size, s_size, last_q_size=64) - else: - v_idx, s_idx = v_size, s_size - - num_blocks = triton.cdiv(num_tokens, granularity) - block_mask, bar_idx, bar_cnt, _, _ = convert_indices(v_idx, s_idx, world_size, rank, num_blocks, granularity) - block_mask = block_mask[rank] - return block_mask, bar_idx, bar_cnt - -def build_index( - q: torch.Tensor, - k: torch.Tensor, - v_size: List[int], - s_size: List[int], - num_tokens: int, # num_tokens_local - granularity: int, - stripe_transform: bool = True, - zigzag_transform: bool = False, - group: dist.group = None, -): - """ - Input: (all inputs correspond to the local part for each rank) - q: shape [batch_size, num_tokens_local, num_qo_heads, head_dim] - k: shape [batch_size, num_tokens_local, num_kv_heads, head_dim] - v_size: shape [num_qo_heads] - s_size: shape [num_qo_heads] - num_tokens: number of tokens in the local part of QK - Returns: - block_mask: shape [world_size, batch_size, num_heads, num_blocks, num_blocks] - bar_idx: shape [batch_size, num_heads, num_blocks, max_v_size] - bar_cnt: shape [batch_size, num_heads, num_blocks, world_size + 1], each entry is the cumulative number of selected bars corresponding a rank - """ - rank = dist.get_rank(group) - world_size = dist.get_world_size(group) - - if isinstance(v_size, list): - v_idx, s_idx = calc_index( - q, k, v_size, s_size, last_q_size=64, group=group, - stripe_transform=stripe_transform, - zigzag_transform=zigzag_transform, - granularity=granularity - ) - else: - v_idx, s_idx = v_size, s_size - - num_blocks = triton.cdiv(num_tokens, granularity) # num_blocks_local - - # Note that block_mask is a 5D tensor with shape [world_size, batch_size, num_heads, num_blocks, num_blocks] - # with each block_mask[i] is to a mask corresponding the num_tokens_local x num_tokens_local matmul for each step - block_mask, bar_idx, bar_cnt, bar_pos, v_cnt = convert_indices( - v_idx, s_idx, world_size, rank, num_blocks, granularity, - stripe_transform=stripe_transform, - zigzag_transform=zigzag_transform, - ) - return block_mask, bar_idx, 
bar_cnt, bar_pos, v_idx, v_cnt - - -def _build_mask_local( - q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] - k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] - v_size: List[int], - s_size: List[int], - num_tokens: int, - granularity: int, - world_size: int = 1, - rank: int = 0, -): - with torch.no_grad(): - block_mask, bar_idx, bar_cnt = build_index_local(q, k, v_size, s_size, num_tokens, granularity, world_size, rank) - batch_size, num_tokens, num_heads, head_dim = q.shape - num_blocks = block_mask.shape[-1] - num_tokens_pad = num_blocks * granularity - # Block Mask - mask = block_mask.unsqueeze(3).unsqueeze(5).repeat((1, 1, 1, granularity, 1, granularity)) - mask = mask.reshape((batch_size, num_heads, num_tokens_pad, num_tokens_pad)) - # Bar Mask - for batch_idx in range(batch_size): - for head_idx in range(num_heads): - for row_idx in range(num_blocks): - row_u = row_idx * granularity - row_d = row_u + granularity - bar_l = bar_cnt[batch_idx, head_idx, row_idx, rank] - bar_r = bar_cnt[batch_idx, head_idx, row_idx, rank + 1] - for col_idx in bar_idx[batch_idx, head_idx, row_idx, bar_l:bar_r]: - mask[batch_idx, head_idx, row_u:row_d, col_idx] = True - # Causal Mask - arange = torch.arange(0, num_tokens_pad, dtype=torch.int32, device=q.device) - mask.masked_fill_(arange[None, None, :, None] < arange[None, None, None, :], False) - return mask[:, :, :num_tokens, :num_tokens] - - -def convert_blockmask( - blockmask: torch.Tensor, # [world_size, batch_size, num_heads, num_blocks, num_blocks] - block_size_M: int, - block_size_N: int, -): - ratio = block_size_M // block_size_N - original_shape = blockmask.shape - blockmask = blockmask.to(dtype=torch.uint8) - blockmask = blockmask.unsqueeze(-1).tile([1] * len(original_shape) + [ratio]).reshape((*original_shape[:-1], -1)) - - # now block_mask is [world_size, batch_size, num_heads, num_blocks, num_blocks * ratio] - nonzero_val, nonzero_idx = blockmask.sort(dim=-1, stable=True, descending=True) - - nonzero_rowcnt = blockmask.sum(dim=-1, dtype=torch.int32) - return nonzero_idx.contiguous().to(dtype=torch.int32), nonzero_rowcnt.contiguous() - -def compute_sr_flops( - block_mask_offset: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, num_blocks] - bar_cnt_offset: torch.Tensor, # [batch_size, num_qo_heads, num_blocks] - granularity: int, - q_len: int, - head_dim: int, - shift: bool, - fwd: bool=True, -): - num_blocks = triton.cdiv(q_len, granularity) - bh = block_mask_offset.shape[0] * block_mask_offset.shape[1] - - if not shift: - total_num_blocks = bh * num_blocks * (num_blocks + 1) / 2 - else: - total_num_blocks = bh * num_blocks * (num_blocks - 1) / 2 - - block_ratio = block_mask_offset.sum(dtype=torch.float32).item() / total_num_blocks - bar_ratio = bar_cnt_offset.sum(dtype=torch.float32).item() / (granularity * total_num_blocks) - sparsity_ratio = 1 - block_ratio - bar_ratio - - block_flops = block_mask_offset.sum(dtype=torch.float32).item() * (granularity * granularity) * head_dim * 2 * 2 - bar_flops = bar_cnt_offset.sum(dtype=torch.float32).item() * granularity * head_dim * 2 * 2 - flops = block_flops + bar_flops - - if not fwd: - flops, block_flops, bar_flops = 2.5 * flops, 2.5 * block_flops, 2.5 * bar_flops - return block_ratio, bar_ratio, sparsity_ratio, block_flops, bar_flops, flops - -def compute_sr_by_heads( - block_mask_offset: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, num_blocks] - bar_cnt_offset: torch.Tensor, # [batch_size, num_qo_heads, num_blocks] - granularity: int, - 
q_len: int, - head_dim: int, - shift: bool, - fwd: bool=True, -): - num_heads = block_mask_offset.shape[1] - num_blocks = triton.cdiv(q_len, granularity) - if not shift: - total_num_blocks = num_blocks * (num_blocks + 1) / 2 - else: - total_num_blocks = num_blocks * (num_blocks - 1) / 2 - total_num_blocks_by_heads = torch.tensor([total_num_blocks for _ in range(num_heads)], dtype=torch.float32).to(block_mask_offset.device) - - block_ratio_by_heads = block_mask_offset.sum(dim=-1).sum(dim=-1).sum(0, dtype=torch.float32) / total_num_blocks_by_heads - bar_ratio_by_heads = bar_cnt_offset.sum(dim=-1).sum(0, dtype=torch.float32) / total_num_blocks_by_heads / granularity - sparsity_ratio_by_heads = 1 - block_ratio_by_heads - bar_ratio_by_heads - - return sparsity_ratio_by_heads - -def get_compute_sparsity( - q: torch.Tensor, - k: torch.Tensor, - batch_size: int, - num_tokens: int, - num_qo_heads: int, - v_size: List[int], - s_size: List[int], - granularity: int = 128, -): - block_mask, bar_idx, bar_cnt = build_index_local(q, k, v_size, s_size, num_tokens, granularity) - num_blocks = block_mask.shape[-1] - causal_blocks = batch_size * num_qo_heads * num_blocks * (num_blocks + 1) / 2 - block_ratio = block_mask.sum(dtype=torch.float32).item() / causal_blocks - - avg_bars = bar_cnt[..., -1].mean(dtype=torch.float32).item() - bar_ratio = avg_bars / (causal_blocks / (num_qo_heads * num_blocks * num_blocks) * num_tokens) - - compute_sparsity = 1 - block_ratio - bar_ratio - - return compute_sparsity - - - -@lru_cache(maxsize=16) -def calc_chunks(cu_seqlen, moba_chunk_size): - """calc chunks that needs moba attention""" - - # batch_sizes[batch_idx] = batch size ( seqlen ) of batch idx - batch_sizes = cu_seqlen[1:] - cu_seqlen[:-1] - - # batch_num_chunk[batch_idx] = how many chunk in batch idx - batch_num_chunk = (batch_sizes + (moba_chunk_size - 1)) // moba_chunk_size - - # cu_num_chunk[batch_idx] = first chunk id of this batch - cu_num_chunk = torch.ones( - batch_num_chunk.numel() + 1, - device=cu_seqlen.device, - dtype=batch_num_chunk.dtype, - ) - cu_num_chunk[1:] = batch_num_chunk.cumsum(dim=0) - - # total chunk ( for all batch ) - num_chunk = cu_num_chunk[-1] - - # chunk_sizes[chunk_idx] = chunk_size of chunk idx - chunk_sizes = torch.full( - (num_chunk + 1,), moba_chunk_size, dtype=torch.int32, device=cu_seqlen.device - ) - chunk_sizes[0] = 0 # for calc cu chunk - batch_last_chunk_size = batch_sizes - (batch_num_chunk - 1) * moba_chunk_size - chunk_sizes[cu_num_chunk[1:]] = batch_last_chunk_size - - # cu_chunk[chunk_idx] = the start chunk offset of chunk idx - cu_chunk = chunk_sizes.cumsum(dim=-1, dtype=torch.int32) - - # chunk_to_batch[chunk_idx] = batch idx of the chunk idx - chunk_to_batch = torch.zeros( - (num_chunk,), dtype=torch.int32, device=cu_seqlen.device - ) - chunk_to_batch[cu_num_chunk[1:-1]] = 1 - chunk_to_batch = chunk_to_batch.cumsum(dim=0, dtype=torch.int32) - - """ filter chunks that need moba attn """ - - # filter chunks ( remove last chunk of each batch ) - # filtered_chunk_indices: chunk index list that excludes the last chunk of each batch - chunk_to_remove = cu_num_chunk[1:] - 1 - chunk_to_remain = torch.ones( - (num_chunk, ), dtype=torch.bool, device=cu_seqlen.device - ) - chunk_to_remain[chunk_to_remove] = False - filtered_chunk_indices = chunk_to_remain.nonzero(as_tuple=True)[0] - num_filtered_chunk = len(filtered_chunk_indices) - - return ( - cu_chunk, - filtered_chunk_indices, - num_filtered_chunk, - filtered_chunk_indices, - chunk_to_batch, - ) \ No newline at end of file 
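For reference, the `sum_all_diagonal_matrix` helper in the file removed above sums the diagonals of the [batch, heads, last_q_size, num_tokens] score matrix, ordered from the bottom-left corner to the top-right corner; the slash indices are then taken from the top-scoring diagonals. A minimal naive sketch of the same computation is given below as a reading aid; the name `sum_all_diagonal_matrix_ref` and the toy shapes are illustrative assumptions, not part of this patch.

import torch

def sum_all_diagonal_matrix_ref(mat: torch.Tensor) -> torch.Tensor:
    # mat: [b, h, m, n] scores of the last m query tokens against n key tokens.
    # Returns [b, h, m + n - 1] diagonal sums, where index 0 is the diagonal
    # holding only the bottom-left element (i = m - 1, j = 0) and index
    # m + n - 2 is the diagonal holding only the top-right element.
    b, h, m, n = mat.shape
    out = torch.zeros(b, h, m + n - 1, dtype=mat.dtype, device=mat.device)
    for i in range(m):
        for j in range(n):
            out[:, :, (m - 1 - i) + j] += mat[:, :, i, j]
    return out

# Should match the strided implementation up to floating-point accumulation order, e.g.:
#   mat = torch.randn(1, 2, 64, 1024)
#   torch.testing.assert_close(sum_all_diagonal_matrix(mat), sum_all_diagonal_matrix_ref(mat))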
diff --git a/minference/ops/xattention_fa.py b/minference/ops/xattention_fa.py index d565adb..28790a2 100644 --- a/minference/ops/xattention_fa.py +++ b/minference/ops/xattention_fa.py @@ -5,7 +5,7 @@ import torch from typing import List, Tuple, Dict, Any -from minference.ops.minference_attn import block_attn_fwd, block_attn_bwd +from minference.ops.pit_sparse_flash_attention_v3 import block_attn_fwd, block_attn_bwd from .op_utils.xattn_utils import ( LN2, find_blocks_chunked, flat_group_gemm_fuse_reshape, softmax_fuse_block_sum ) diff --git a/mtraining/attn_funcs/minfer_func.py b/mtraining/attn_funcs/minfer_func.py index 451c239..1e58660 100644 --- a/mtraining/attn_funcs/minfer_func.py +++ b/mtraining/attn_funcs/minfer_func.py @@ -18,8 +18,8 @@ from nnscaler.ir import IRTensor from nnscaler.ir.operator import IRFwOperation -from minference.ops.minference_attn import minference_flash_attn_func -from minference.ops.minference_attn_triton import minference_flash_attn_triton_func +from minference.ops.pit_sparse_flash_attention_v3 import minference_flash_attn_func +from minference.ops.pit_sparse_flash_attention_v3_triton import minference_flash_attn_triton_func from minference.dist_ops import ( minfer_stripe_func, minfer_stripe_triton_func, minfer_zigzag_func, minfer_dr_stripe_func, minfer_dr_stripe_triton_func, diff --git a/mtraining/experiments/scripts/prolong_data_prepare.sh b/mtraining/experiments/scripts/prolong_data_prepare.sh new file mode 100644 index 0000000..2c4ec7a --- /dev/null +++ b/mtraining/experiments/scripts/prolong_data_prepare.sh @@ -0,0 +1,23 @@ +#!/usr/bin/bash + +# ------------------------------------------ +# Download data +# Prerequisite: sudo apt-get install git-lfs && git lfs install +RAW_DATASET_DIR="/path/to/datasets" +git clone https://huggingface.co/datasets/princeton-nlp/prolong-data-512K $RAW_DATASET_DIR/long-context-524288 +cd $RAW_DATASET_DIR/long-context-524288 +git lfs fetch +git lfs checkout + + +# ------------------------------------------ +# Data Processing +cd /path/to/mtraining +MODEL_ID="Qwen/Qwen2.5-3B" +PROCESSED_DATA_DIR="/path/to/processed_dataset" +torchrun --nproc_per_node=4\ + utils/data_utils/prolong.py \ + --model_id $MODEL_ID \ + --dataset_mix fixed_524288 \ + --dataset_path $RAW_DATASET_DIR/long-context-524288 \ + --save_path $PROCESSED_DATA_DIR diff --git a/mtraining/experiments/scripts/train_qwen_mini_ProLong512K.sh b/mtraining/experiments/scripts/train_qwen_mini_ProLong512K.sh index 0075484..a5acf3d 100755 --- a/mtraining/experiments/scripts/train_qwen_mini_ProLong512K.sh +++ b/mtraining/experiments/scripts/train_qwen_mini_ProLong512K.sh @@ -58,7 +58,7 @@ export NUM_EPOCH=0 export CKPT_SAVE_STEP=5 export CKPT_SAVE_EPOCH=0 -export CHECK_RESUME=1 +export CHECK_RESUME=0 if [ "$CHECK_RESUME" -eq 1 ]; then CHECK_RESUME="--check_resume" else diff --git a/mtraining/requirements.txt b/mtraining/requirements.txt index d852338..350ed6f 100644 --- a/mtraining/requirements.txt +++ b/mtraining/requirements.txt @@ -1,18 +1,11 @@ transformers==4.48.0 datasets==2.20.0 tensorboard -scikit-learn -matplotlib -seaborn jieba rouge nltk rouge_score evaluate -triton==3.0.0 -mosaicml-cli==0.5.34 -mosaicml-streaming==0.8.1 -sentencepiece==0.1.99 -tiktoken==0.7.0 -zstandard==0.22.0 +# For Data Preparation +mosaicml-streaming==0.8.1 \ No newline at end of file diff --git a/mtraining/utils/data_utils/prolong.py b/mtraining/utils/data_utils/prolong.py new file mode 100644 index 0000000..83672b9 --- /dev/null +++ b/mtraining/utils/data_utils/prolong.py @@ -0,0 +1,122 
@@ +import os +import logging +import argparse + +from tqdm import tqdm +from typing import Dict +from streaming import StreamingDataset +from transformers import PreTrainedTokenizer +from datasets import Dataset, concatenate_datasets + +from mtraining.utils.general import get_tokenizer +# ------------------------------------------------ + +logger = logging.getLogger(__name__) + +LLAMA3_MODEL_ID = "meta-llama/Meta-Llama-3-8B" +LLAMA_TOKENZIER = None + +def tokenize(sample: Dict[str, str], tokenizer: PreTrainedTokenizer, seq_len: int=524288): + text = sample['text'] + for token_k, token_v in LLAMA_TOKENZIER.special_tokens_map.items(): + if token_k in tokenizer.special_tokens_map: + text = text.replace(token_v, tokenizer.special_tokens_map[token_k]) + + input_ids = tokenizer.encode( + text, + add_special_tokens=False, + truncation=True, + max_length=seq_len + ) + return {"input_ids": input_ids, "length": len(input_ids)} + + +DOMAINS = [ + "thestackv1_concat_by_repo-524288@0.15", + "thestackv1_concat_by_repo-65536@0.15", + "book-524288@0.05", + "book-65536@0.25", + "fineweb-edu@0.1", + "fineweb-2023-50@0.1", + "stackexchange@0.04", + "dolmawiki@0.04", + "tuluv2@0.03", + "arxiv@0.03", + "openwebmath@0.03", + "textbooks@0.03", +] +FIXED_512K = [ + "thestackv1_concat_by_repo-524288", + "book-524288" +] + +DOMAIN_MIX_DICT = { + "full": DOMAINS, + "fixed_524288": FIXED_512K +} + +def main(args): + global LLAMA_TOKENZIER + + seq_len = args.sequence_length + LLAMA_TOKENZIER = get_tokenizer(LLAMA3_MODEL_ID) + model_tokenizer = get_tokenizer(args.model_id) + if model_tokenizer.bos_token is None: + model_tokenizer.bos_token = "<|endoftext|>" + + domains = DOMAIN_MIX_DICT[args.dataset_mix] + dataset_paths = [os.path.join(args.dataset_path, domain) for domain in domains] + tokenized_datasets = [] + for idx, dataset_path in enumerate(dataset_paths): + print('-' * 50) + print(f"Processing {domains[idx]} from {dataset_path}...") + texts = [] + dataset = StreamingDataset( + local=dataset_path, + remote=None, + shuffle=False, + batch_size=1, + ) + + for ix, sample in tqdm(enumerate(dataset)): + if ix % args.sample_interval != 0: continue + + sample_input_ids = sample["input_ids"] + sample_splits = [sample_input_ids[i:i+seq_len] for i in range(0, len(sample_input_ids), seq_len)] + for sample_split in sample_splits: + # De-tokenization + text = LLAMA_TOKENZIER.decode(sample_split) + texts.append(text) + + hf_dataset = Dataset.from_dict( + { + "text": texts + } + ) + + tokenized_dataset = hf_dataset.map( + tokenize, + remove_columns=hf_dataset.column_names, + num_proc=64, + fn_kwargs={'tokenizer': model_tokenizer, 'seq_len': seq_len} + ) + tokenized_datasets.append(tokenized_dataset) + + print('-' * 50) + print(f"Concatenating and Saving tokenized datasets to {args.save_path}...") + concat_dataset = concatenate_datasets(tokenized_datasets) + filtered_concat_dataset = concat_dataset.filter(lambda x: x['length'] == seq_len, num_proc=128) + filtered_concat_dataset.save_to_disk(args.save_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model_id", type=str, default="microsoft/Phi-3-mini-4k-instruct") + parser.add_argument('--dataset_mix', type=str, default="fixed_524288") + parser.add_argument('--dataset_path', type=str) + parser.add_argument('--save_path', type=str) + parser.add_argument('--sequence_length', type=int, default=524288) + parser.add_argument("--sample_interval", type=int, default=1) + args = parser.parse_args() + + main(args) From 
aef7399adb3d225637680bc128f5eb5a748dd4b9 Mon Sep 17 00:00:00 2001 From: Wenxuan Li Date: Mon, 16 Jun 2025 07:25:55 +0000 Subject: [PATCH 05/12] move triton condition to common utils --- minference/dist_ops/minfer_striped.py | 2 ++ minference/dist_ops/xattn_zigzag.py | 4 +--- minference/ops/pit_sparse_flash_attention_v3.py | 2 ++ minference/ops/utils.py | 4 ++++ mtraining/attn_funcs/minfer_func.py | 9 +++++---- 5 files changed, 14 insertions(+), 7 deletions(-) diff --git a/minference/dist_ops/minfer_striped.py b/minference/dist_ops/minfer_striped.py index 76b7cb3..91bb272 100644 --- a/minference/dist_ops/minfer_striped.py +++ b/minference/dist_ops/minfer_striped.py @@ -20,6 +20,8 @@ import block_sparse_attn_cuda # type: ignore from block_sparse_attn.block_sparse_attn_interface import convert_blockmask_row_reverse, convert_blockmask_col_reverse # type: ignore # NOTE: Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_blockmask.h: add head_idx to blockmask_ptr + except ModuleNotFoundError as e: + print(f"[Warning] Failed to import block_sparse_attn_cuda: {e}") finally: # Restore original flags for future imports sys.setdlopenflags(original_flags) diff --git a/minference/dist_ops/xattn_zigzag.py b/minference/dist_ops/xattn_zigzag.py index c201abc..bb6a80e 100644 --- a/minference/dist_ops/xattn_zigzag.py +++ b/minference/dist_ops/xattn_zigzag.py @@ -11,6 +11,7 @@ shuffle_block_mask_zigzag, ) +from minference.ops.utils import use_triton from minference.ops.op_utils.xattn_utils import LN2, find_blocks_chunked from minference.ops.op_utils.vertical_slash_utils import convert_blockmask from minference.ops.pit_sparse_flash_attention_v3 import block_attn_fwd, block_attn_bwd @@ -161,9 +162,6 @@ def xattn_zigzag_estimate( simple_masks = torch.cat(simple_mask_list, dim=-2) # (batch_size, head_num, q_local_block_num, k_global_block_num) return attn_sums, simple_masks -def use_triton(): - return torch.version.hip is not None or os.getenv("FORCE_TRITON", "0") == "1" - def xattn_zigzag_forward( process_group: dist.ProcessGroup, q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] diff --git a/minference/ops/pit_sparse_flash_attention_v3.py b/minference/ops/pit_sparse_flash_attention_v3.py index b198b82..d044f6d 100644 --- a/minference/ops/pit_sparse_flash_attention_v3.py +++ b/minference/ops/pit_sparse_flash_attention_v3.py @@ -19,6 +19,8 @@ import block_sparse_attn_cuda from block_sparse_attn.block_sparse_attn_interface import convert_blockmask_row_reverse, convert_blockmask_col_reverse # NOTE: Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_blockmask.h: add head_idx to blockmask_ptr + except ModuleNotFoundError as e: + print(f"[Warning] Failed to import block_sparse_attn_cuda: {e}") finally: # Restore original flags for future imports sys.setdlopenflags(original_flags) diff --git a/minference/ops/utils.py b/minference/ops/utils.py index ebe4d7e..917e527 100644 --- a/minference/ops/utils.py +++ b/minference/ops/utils.py @@ -1,6 +1,10 @@ +import os import torch def set_seed(seed=42): torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) + +def use_triton(): + return torch.version.hip is not None or os.getenv("FORCE_TRITON", "0") == "1" \ No newline at end of file diff --git a/mtraining/attn_funcs/minfer_func.py b/mtraining/attn_funcs/minfer_func.py index 1e58660..32cac5c 100644 --- a/mtraining/attn_funcs/minfer_func.py +++ b/mtraining/attn_funcs/minfer_func.py @@ -18,6 +18,7 @@ from nnscaler.ir import IRTensor from nnscaler.ir.operator import 
IRFwOperation +from minference.ops.utils import use_triton from minference.ops.pit_sparse_flash_attention_v3 import minference_flash_attn_func from minference.ops.pit_sparse_flash_attention_v3_triton import minference_flash_attn_triton_func from minference.dist_ops import ( @@ -45,7 +46,7 @@ def minfer_op( ): v_sizes = [pattern_dict[head_indices[idx].item()][1] for idx in range(query_states.size(1))] s_sizes = [pattern_dict[head_indices[idx].item()][2] for idx in range(query_states.size(1))] - if torch.version.hip is None: + if not use_triton(): attn_output = minference_flash_attn_func( query_states.transpose(1, 2).contiguous(), key_states.transpose(1, 2).contiguous(), @@ -107,7 +108,7 @@ def minfer_stripe_op( v_sizes = [pattern_dict[head_indices[idx].item()][1] for idx in range(query_states.size(1))] s_sizes = [pattern_dict[head_indices[idx].item()][2] for idx in range(query_states.size(1))] - if torch.version.hip is None: + if not use_triton(): attn_output = minfer_stripe_func( query_states.transpose(1, 2).contiguous(), key_states.transpose(1, 2).contiguous(), @@ -172,7 +173,7 @@ def minfer_zigzag_op( v_sizes = [pattern_dict[head_indices[idx].item()][1] for idx in range(query_states.size(1))] s_sizes = [pattern_dict[head_indices[idx].item()][2] for idx in range(query_states.size(1))] - if torch.version.hip is None: + if not use_triton(): attn_output = minfer_zigzag_func( query_states.transpose(1, 2).contiguous(), key_states.transpose(1, 2).contiguous(), @@ -223,7 +224,7 @@ def minfer_dr_stripe_op( v_sizes = [pattern_dict[head_indices[idx].item()][1] for idx in range(query_states.size(1))] s_sizes = [pattern_dict[head_indices[idx].item()][2] for idx in range(query_states.size(1))] - if torch.version.hip is None: + if not use_triton(): attn_output = minfer_dr_stripe_func( query_states.transpose(1, 2).contiguous(), key_states.transpose(1, 2).contiguous(), From b52dcf5247fdc56781da7c5f9cee153cf8f3ca6d Mon Sep 17 00:00:00 2001 From: Wenxuan Li Date: Wed, 18 Jun 2025 08:06:05 +0000 Subject: [PATCH 06/12] Add testing feature draft --- minference/dist_ops/test/minfer_ring_test.py | 223 ++++++++++++++++ .../dist_ops/test/minfer_ring_test_raw.py | 246 ++++++++++++++++++ .../dist_ops/test/xattn_ring_tes_raw.py | 241 +++++++++++++++++ minference/dist_ops/utils.py | 72 ++--- .../ops/op_utils/vertical_slash_utils.py | 17 -- minference/ops/op_utils/xattn_utils.py | 4 +- minference/ops/utils.py | 28 +- 7 files changed, 762 insertions(+), 69 deletions(-) create mode 100644 minference/dist_ops/test/minfer_ring_test.py create mode 100644 minference/dist_ops/test/minfer_ring_test_raw.py create mode 100644 minference/dist_ops/test/xattn_ring_tes_raw.py diff --git a/minference/dist_ops/test/minfer_ring_test.py b/minference/dist_ops/test/minfer_ring_test.py new file mode 100644 index 0000000..f7b3d21 --- /dev/null +++ b/minference/dist_ops/test/minfer_ring_test.py @@ -0,0 +1,223 @@ +# tests/test_minference_sparse_attention.py +""" +Distributed correctness tests for Minference sparse-attention kernels. + +Run with: + pytest -q -s tests/test_minference_sparse_attention.py +or manually choose GPUs, e.g. + CUDA_VISIBLE_DEVICES=0,1 pytest -q -s … + +The test spawns one process per GPU with torch.multiprocessing, so it does +**not** require `pytest-xdist`. It will be skipped automatically if you have +fewer than two visible CUDA devices. 
+""" +from __future__ import annotations + +import os +import random +from types import SimpleNamespace +from typing import Callable + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +from minference.ops.utils import set_seed +from minference.dist_ops.minfer_zigzag import minfer_zigzag_func +from minference.dist_ops.minfer_striped import minfer_stripe_func +from minference.dist_ops.minfer_dr_striped import minfer_dr_stripe_func +from minference.ops.pit_sparse_flash_attention_v3 import minference_flash_attn_func + +# ------------- constants ------------------------------------------------------ +_ATOL = 1e-2 +_RTOL = 1e-2 +_WORLD_SIZE = 2 + +_ATTENTION_IMPLS: dict[str, Callable] = { + "minfer_zigzag": minfer_zigzag_func, + "minfer_stripe": minfer_stripe_func, + "minfer_dr_stripe": minfer_dr_stripe_func, +} + +# ------------- helpers -------------------------------------------------------- +def _init_process_group(rank: int, world_size: int, port: str) -> None: + """Initialise NCCL backend for the current worker.""" + os.environ.update( + { + "MASTER_ADDR": "127.0.0.1", + "MASTER_PORT": port, + "RANK": str(rank), + "WORLD_SIZE": str(world_size), + } + ) + dist.init_process_group("nccl", rank=rank, world_size=world_size) + + +def _run_worker( + rank: int, + world_size: int, + port: str, + cfg: SimpleNamespace, + attn_op_name: str, +) -> None: + """Worker function executed in every spawned GPU process.""" + _init_process_group(rank, world_size, port) + + # Short-hand variables + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + dtype = torch.bfloat16 + set_seed(2025 + rank) + + attn_op: Callable = _ATTENTION_IMPLS[attn_op_name] + + # ----------------- generate identical tensors on every rank -------------- + if rank == 0: + rand_or_one = ( + torch.randn if not cfg.ones else lambda s, **k: torch.ones(*s, **k) + ) + q = rand_or_one( + (cfg.batch_size, cfg.seq_len, cfg.num_qo_heads, cfg.head_dim), + dtype=dtype, + device=device, + ) + k = rand_or_one( + (cfg.batch_size, cfg.seq_len, cfg.num_kv_heads, cfg.head_dim), + dtype=dtype, + device=device, + ) + v = rand_or_one( + (cfg.batch_size, cfg.seq_len, cfg.num_kv_heads, cfg.head_dim), + dtype=dtype, + device=device, + ) + dout = rand_or_one( + (cfg.batch_size, cfg.seq_len, cfg.num_qo_heads, cfg.head_dim), + dtype=dtype, + device=device, + ) + else: + # placeholders that will be overwritten by broadcast + shape_q = (cfg.batch_size, cfg.seq_len, cfg.num_qo_heads, cfg.head_dim) + shape_kv = (cfg.batch_size, cfg.seq_len, cfg.num_kv_heads, cfg.head_dim) + q = torch.empty(shape_q, device=device, dtype=dtype) + k = torch.empty(shape_kv, device=device, dtype=dtype) + v = torch.empty(shape_kv, device=device, dtype=dtype) + dout = torch.empty(shape_q, device=device, dtype=dtype) + + # Make every rank see the same data + dist.broadcast(q, src=0) + dist.broadcast(k, src=0) + dist.broadcast(v, src=0) + dist.broadcast(dout, src=0) + + # ----------------- slice local context ----------------------------------- + local_ctx = cfg.seq_len // world_size + sl = slice(rank * local_ctx, (rank + 1) * local_ctx) + + q_local = q[:, sl].clone().detach().requires_grad_() + k_local = k[:, sl].clone().detach().requires_grad_() + v_local = v[:, sl].clone().detach().requires_grad_() + dout_local = dout[:, sl].clone() + + # ----------------- forward / backward on the candidate kernel ------------ + out_local = attn_op( + q_local, + k_local, + v_local, + cfg.v_size, + cfg.s_size, + layer_idx=0, + ) + 
torch.autograd.backward(out_local, dout_local) + + # ----------------- gather outputs & grads for reference comparison ------- + out_gather = [torch.empty_like(out_local) for _ in range(world_size)] + dist.all_gather(out_gather, out_local) + final_out = torch.cat(out_gather, dim=1) + + grads = [] + for g in (q_local.grad, k_local.grad, v_local.grad): + tmp = [torch.empty_like(g) for _ in range(world_size)] + dist.all_gather(tmp, g) + grads.append(torch.cat(tmp, dim=1)) + + # ----------------- reference: dense Flash-Attention ---------------------- + if rank == 0: + q_ref = q.detach().clone().requires_grad_() + k_ref = k.detach().clone().requires_grad_() + v_ref = v.detach().clone().requires_grad_() + + out_ref = minference_flash_attn_func( + q_ref, + k_ref, + v_ref, + cfg.v_size, + cfg.s_size, + causal=True, + ) + torch.autograd.backward(out_ref, dout) + ref_grads = (q_ref.grad, k_ref.grad, v_ref.grad) + + # ----------------- assertions ---------------------------------------- + torch.testing.assert_close( + final_out, out_ref, atol=_ATOL, rtol=_RTOL, msg="forward mismatch" + ) + for got, ref, name in zip( + grads, + ref_grads, + ("Q-grad", "K-grad", "V-grad"), + ): + torch.testing.assert_close(got, ref, atol=_ATOL, rtol=_RTOL, msg=name) + + dist.destroy_process_group() + + +# ------------- pytest entry-point -------------------------------------------- +@pytest.mark.skipif(torch.cuda.device_count() < _WORLD_SIZE, reason="Not enough GPUs") +@pytest.mark.parametrize("seq_len", [512, 4096, 8192]) +@pytest.mark.parametrize("batch_sz", [1]) +@pytest.mark.parametrize("head_dim", [32, 64]) +@pytest.mark.parametrize("sparsity", [0.9, 0.95, 1.]) +@pytest.mark.parametrize("ones", [False, True]) +@pytest.mark.parametrize("num_qo_heads", [2, 4, 4]) +@pytest.mark.parametrize("num_kv_heads", [2, 1, 4]) +@pytest.mark.parametrize("attn_op_name", + ["minfer_zigzag", "minfer_stripe", "minfer_dr_stripe"] +) +def test_sparse_attention_kernels( + seq_len: int, + batch_sz: int, + head_dim: int, + sparsity: float, + ones: bool, + num_qo_heads: int, + num_kv_heads: int, + attn_op_name: str, +): + """ + Compare every sparse kernel against the dense Flash-Attention reference on + both forward pass and input-gradient w.r.t Q/K/V. + """ + port = str(random.randint(12000, 20000)) + + cfg = SimpleNamespace( + batch_size=batch_sz, + seq_len=seq_len, + head_dim=head_dim, + sparsity=sparsity, + ones=ones, + num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + ) + # derived sizes used by both candidate and reference kernels + cfg.v_size = [int((1 - cfg.sparsity) * 0.1 * cfg.seq_len)] * cfg.num_qo_heads + cfg.s_size = [int((1 - cfg.sparsity) * 0.2 * cfg.seq_len)] * cfg.num_qo_heads + + mp.spawn( + _run_worker, + args=(_WORLD_SIZE, port, cfg, attn_op_name), + nprocs=_WORLD_SIZE, + join=True, + ) diff --git a/minference/dist_ops/test/minfer_ring_test_raw.py b/minference/dist_ops/test/minfer_ring_test_raw.py new file mode 100644 index 0000000..f12d478 --- /dev/null +++ b/minference/dist_ops/test/minfer_ring_test_raw.py @@ -0,0 +1,246 @@ +# tests/test_minference_sparse_attention.py +""" +Distributed correctness tests for Minference sparse-attention kernels. + +Run with: + pytest -q -s tests/test_minference_sparse_attention.py +or manually choose GPUs, e.g. + CUDA_VISIBLE_DEVICES=0,1 pytest -q -s … + +The test spawns one process per GPU with torch.multiprocessing, so it does +**not** require `pytest-xdist`. It will be skipped automatically if you have +fewer than two visible CUDA devices. 
+""" +from __future__ import annotations + +import os +import random +from types import SimpleNamespace +from typing import Callable + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +from minference.ops.utils import set_seed, check_correctness_by_row +from minference.dist_ops.minfer_zigzag import minfer_zigzag_func +from minference.dist_ops.minfer_striped import minfer_stripe_func +from minference.dist_ops.minfer_dr_striped import minfer_dr_stripe_func +from minference.ops.pit_sparse_flash_attention_v3 import minference_flash_attn_func + +# ------------- constants ------------------------------------------------------ +_ATOL = 1e-2 +_RTOL = 1e-2 +_WORLD_SIZE = 2 + +_ATTENTION_IMPLS: dict[str, Callable] = { + "minfer_zigzag": minfer_zigzag_func, + "minfer_stripe": minfer_stripe_func, + "minfer_dr_stripe": minfer_dr_stripe_func, +} + +# ------------- helpers -------------------------------------------------------- +def _init_process_group(rank: int, world_size: int, port: str) -> None: + """Initialise NCCL backend for the current worker.""" + os.environ.update( + { + "MASTER_ADDR": "127.0.0.1", + "MASTER_PORT": port, + "RANK": str(rank), + "WORLD_SIZE": str(world_size), + "LOCAL_RANK": str(rank % min(world_size, torch.cuda.device_count())), + "LOCAL_WORLD_SIZE": str(min(world_size, torch.cuda.device_count())), + } + ) + dist.init_process_group("nccl", rank=rank, world_size=world_size) + + + +def _run_worker( + rank: int, + world_size: int, + port: str, + cfg: SimpleNamespace, + attn_op_name: str, +) -> None: + """Worker function executed in every spawned GPU process.""" + _init_process_group(rank, world_size, port) + + # Short-hand variables + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + dtype = torch.bfloat16 + set_seed(2025 + rank) + + attn_op: Callable = _ATTENTION_IMPLS[attn_op_name] + + # ----------------- generate identical tensors on every rank -------------- + if rank == 0: + rand_or_one = ( + torch.randn if not cfg.ones else lambda s, **k: torch.ones(*s, **k) + ) + q = rand_or_one( + (cfg.batch_size, cfg.seq_len, cfg.num_qo_heads, cfg.head_dim), + dtype=dtype, + device=device, + ) + k = rand_or_one( + (cfg.batch_size, cfg.seq_len, cfg.num_kv_heads, cfg.head_dim), + dtype=dtype, + device=device, + ) + v = rand_or_one( + (cfg.batch_size, cfg.seq_len, cfg.num_kv_heads, cfg.head_dim), + dtype=dtype, + device=device, + ) + dout = rand_or_one( + (cfg.batch_size, cfg.seq_len, cfg.num_qo_heads, cfg.head_dim), + dtype=dtype, + device=device, + ) + else: + # placeholders that will be overwritten by broadcast + shape_q = (cfg.batch_size, cfg.seq_len, cfg.num_qo_heads, cfg.head_dim) + shape_kv = (cfg.batch_size, cfg.seq_len, cfg.num_kv_heads, cfg.head_dim) + q = torch.empty(shape_q, device=device, dtype=dtype) + k = torch.empty(shape_kv, device=device, dtype=dtype) + v = torch.empty(shape_kv, device=device, dtype=dtype) + dout = torch.empty(shape_q, device=device, dtype=dtype) + + # Make every rank see the same data + dist.broadcast(q, src=0) + dist.broadcast(k, src=0) + dist.broadcast(v, src=0) + dist.broadcast(dout, src=0) + + # ----------------- slice local context ----------------------------------- + local_ctx = cfg.seq_len // world_size + sl = slice(rank * local_ctx, (rank + 1) * local_ctx) + + q_local = q[:, sl].clone().detach().requires_grad_() + k_local = k[:, sl].clone().detach().requires_grad_() + v_local = v[:, sl].clone().detach().requires_grad_() + dout_local = dout[:, sl].clone() + + # ----------------- forward / 
backward on the candidate kernel ------------ + out_local = attn_op( + q_local, + k_local, + v_local, + cfg.v_size, + cfg.s_size, + layer_idx=0, + ) + torch.autograd.backward(out_local, dout_local) + + # ----------------- gather outputs & grads for reference comparison ------- + out_gather = [torch.empty_like(out_local) for _ in range(world_size)] + dist.all_gather(out_gather, out_local) + final_out = torch.cat(out_gather, dim=1) + + grads = [] + for g in (q_local.grad, k_local.grad, v_local.grad): + tmp = [torch.empty_like(g) for _ in range(world_size)] + dist.all_gather(tmp, g) + grads.append(torch.cat(tmp, dim=1)) + + # ----------------- reference: dense Flash-Attention ---------------------- + if rank == 0: + q_ref = q.detach().clone().requires_grad_() + k_ref = k.detach().clone().requires_grad_() + v_ref = v.detach().clone().requires_grad_() + + out_ref = minference_flash_attn_func( + q_ref, + k_ref, + v_ref, + cfg.v_size, + cfg.s_size, + causal=True, + ) + torch.autograd.backward(out_ref, dout) + ref_grads = (q_ref.grad, k_ref.grad, v_ref.grad) + + # ----------------- assertions ---------------------------------------- + check_correctness_by_row( + cfg.seq_len, final_out, out_ref, "forward output", ATOL=_ATOL, RTOL=_RTOL + ) + check_correctness_by_row( + cfg.seq_len, grads[0], ref_grads[0], "Q-grad", ATOL=_ATOL, RTOL=_RTOL + ) + check_correctness_by_row( + cfg.seq_len, grads[1], ref_grads[1], "K-grad", + ATOL=_ATOL, RTOL=_RTOL + ) + check_correctness_by_row( + cfg.seq_len, grads[2], ref_grads[2], "V-grad", + ATOL=_ATOL, RTOL=_RTOL + ) + + torch.testing.assert_close( + final_out, out_ref, atol=_ATOL, rtol=_RTOL, msg="forward mismatch" + ) + # for got, ref, name in zip( + # grads, + # ref_grads, + # ("Q-grad", "K-grad", "V-grad"), + # ): + # torch.testing.assert_close(got, ref, ATOL=_ATOL, RTOL=_RTOL, msg=name) + + dist.destroy_process_group() + + +# ------------- pytest entry-point -------------------------------------------- +def test_sparse_attention_kernels( + seq_len: int, + batch_sz: int, + head_dim: int, + sparsity: float, + ones: bool, + num_qo_heads: int, + num_kv_heads: int, + attn_op_name: str, +): + """ + Compare every sparse kernel against the dense Flash-Attention reference on + both forward pass and input-gradient w.r.t Q/K/V. 
+ """ + port = str(random.randint(12000, 20000)) + + cfg = SimpleNamespace( + batch_size=batch_sz, + seq_len=seq_len, + head_dim=head_dim, + sparsity=sparsity, + ones=ones, + num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + ) + # derived sizes used by both candidate and reference kernels + cfg.v_size = [int((1 - cfg.sparsity) * 0.1 * cfg.seq_len)] * cfg.num_qo_heads + cfg.s_size = [int((1 - cfg.sparsity) * 0.2 * cfg.seq_len)] * cfg.num_qo_heads + + print(f"=" * 80) + print(f"Testing {attn_op_name} with configuration:\n{cfg}") + print(f"=" * 80) + mp.spawn( + _run_worker, + args=(_WORLD_SIZE, port, cfg, attn_op_name), + nprocs=_WORLD_SIZE, + join=True, + ) + +if __name__ == "__main__": + # Run the test with default parameters + + test_sparse_attention_kernels( + seq_len=4096, + batch_sz=1, + head_dim=64, + sparsity=0.9, + ones=True, + num_qo_heads=2, + num_kv_heads=2, + attn_op_name="minfer_zigzag", + ) \ No newline at end of file diff --git a/minference/dist_ops/test/xattn_ring_tes_raw.py b/minference/dist_ops/test/xattn_ring_tes_raw.py new file mode 100644 index 0000000..30f485e --- /dev/null +++ b/minference/dist_ops/test/xattn_ring_tes_raw.py @@ -0,0 +1,241 @@ +# tests/test_minference_sparse_attention.py +""" +Distributed correctness tests for Minference sparse-attention kernels. + +Run with: + pytest -q -s tests/test_minference_sparse_attention.py +or manually choose GPUs, e.g. + CUDA_VISIBLE_DEVICES=0,1 pytest -q -s … + +The test spawns one process per GPU with torch.multiprocessing, so it does +**not** require `pytest-xdist`. It will be skipped automatically if you have +fewer than two visible CUDA devices. +""" +from __future__ import annotations + +import os +import random +from types import SimpleNamespace +from typing import Callable + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +from minference.ops.utils import set_seed, check_correctness_by_row +from minference.dist_ops.xattn_zigzag import xattn_zigzag_func +from minference.ops.xattention_fa import xattn_flash_attn_func + +# ------------- constants ------------------------------------------------------ +_ATOL = 1e-1 +_RTOL = 1e-1 +_WORLD_SIZE = 2 + +# ------------- helpers -------------------------------------------------------- +def _init_process_group(rank: int, world_size: int, port: str) -> None: + """Initialise NCCL backend for the current worker.""" + os.environ.update( + { + "MASTER_ADDR": "127.0.0.1", + "MASTER_PORT": port, + "RANK": str(rank), + "WORLD_SIZE": str(world_size), + "LOCAL_RANK": str(rank % min(world_size, torch.cuda.device_count())), + "LOCAL_WORLD_SIZE": str(min(world_size, torch.cuda.device_count())), + } + ) + dist.init_process_group("nccl", rank=rank, world_size=world_size) + + + +def _run_worker( + rank: int, + world_size: int, + port: str, + cfg: SimpleNamespace, +) -> None: + """Worker function executed in every spawned GPU process.""" + _init_process_group(rank, world_size, port) + + # Short-hand variables + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + dtype = torch.bfloat16 + set_seed(2025 + rank) + + # ----------------- generate identical tensors on every rank -------------- + if rank == 0: + rand_or_one = ( + torch.randn if not cfg.ones else lambda s, **k: torch.ones(*s, **k) + ) + q = rand_or_one( + (cfg.batch_size, cfg.seq_len, cfg.num_qo_heads, cfg.head_dim), + dtype=dtype, + device=device, + ) + k = rand_or_one( + (cfg.batch_size, cfg.seq_len, cfg.num_kv_heads, cfg.head_dim), + dtype=dtype, + device=device, + ) + v 
= rand_or_one( + (cfg.batch_size, cfg.seq_len, cfg.num_kv_heads, cfg.head_dim), + dtype=dtype, + device=device, + ) + dout = rand_or_one( + (cfg.batch_size, cfg.seq_len, cfg.num_qo_heads, cfg.head_dim), + dtype=dtype, + device=device, + ) + else: + # placeholders that will be overwritten by broadcast + shape_q = (cfg.batch_size, cfg.seq_len, cfg.num_qo_heads, cfg.head_dim) + shape_kv = (cfg.batch_size, cfg.seq_len, cfg.num_kv_heads, cfg.head_dim) + q = torch.empty(shape_q, device=device, dtype=dtype) + k = torch.empty(shape_kv, device=device, dtype=dtype) + v = torch.empty(shape_kv, device=device, dtype=dtype) + dout = torch.empty(shape_q, device=device, dtype=dtype) + + # Make every rank see the same data + dist.broadcast(q, src=0) + dist.broadcast(k, src=0) + dist.broadcast(v, src=0) + dist.broadcast(dout, src=0) + + # ----------------- slice local context ----------------------------------- + local_ctx = cfg.seq_len // world_size + sl = slice(rank * local_ctx, (rank + 1) * local_ctx) + + q_local = q[:, sl].clone().detach().requires_grad_() + k_local = k[:, sl].clone().detach().requires_grad_() + v_local = v[:, sl].clone().detach().requires_grad_() + dout_local = dout[:, sl].clone() + + # ----------------- forward / backward on the candidate kernel ------------ + out_local = xattn_zigzag_func( + q_local, k_local, v_local, + layer_idx=0, + xattn_params=cfg.xattn_params, + granularity=128, + ) + print(f"Rank {rank} | out_local shape: {out_local.shape}") + torch.autograd.backward(out_local, dout_local) + + # ----------------- gather outputs & grads for reference comparison ------- + out_gather = [torch.empty_like(out_local) for _ in range(world_size)] + dist.all_gather(out_gather, out_local) + final_out = torch.cat(out_gather, dim=1) + + grads = [] + for g in (q_local.grad, k_local.grad, v_local.grad): + tmp = [torch.empty_like(g) for _ in range(world_size)] + dist.all_gather(tmp, g) + grads.append(torch.cat(tmp, dim=1)) + + # --------------------------------------- + if rank == 0: + q_ref = q.detach().clone().requires_grad_() + k_ref = k.detach().clone().requires_grad_() + v_ref = v.detach().clone().requires_grad_() + + single_machine_params = cfg.xattn_params.copy() + single_machine_params["chunk_size"] = cfg.seq_len // _WORLD_SIZE + out_ref = xattn_flash_attn_func( + q_ref, k_ref, v_ref, + head_indices=list(range(cfg.num_qo_heads)), + xattn_params=single_machine_params, + granularity=128, + ) + torch.autograd.backward(out_ref, dout) + ref_grads = (q_ref.grad, k_ref.grad, v_ref.grad) + + # ----------------- assertions ---------------------------------------- + check_correctness_by_row( + cfg.seq_len, final_out, out_ref, "forward output", ATOL=_ATOL, RTOL=_RTOL + ) + check_correctness_by_row( + cfg.seq_len, grads[0], ref_grads[0], "Q-grad", ATOL=_ATOL, RTOL=_RTOL + ) + check_correctness_by_row( + cfg.seq_len, grads[1], ref_grads[1], "K-grad", + ATOL=_ATOL, RTOL=_RTOL + ) + check_correctness_by_row( + cfg.seq_len, grads[2], ref_grads[2], "V-grad", + ATOL=_ATOL, RTOL=_RTOL + ) + + # torch.testing.assert_close( + # final_out, out_ref, atol=_ATOL, rtol=_RTOL, msg="forward mismatch" + # ) + # for got, ref, name in zip( + # grads, + # ref_grads, + # ("Q-grad", "K-grad", "V-grad"), + # ): + # torch.testing.assert_close(got, ref, ATOL=_ATOL, RTOL=_RTOL, msg=name) + + dist.destroy_process_group() + + +# ------------- pytest entry-point -------------------------------------------- +def test_xattention_kernels( + seq_len: int = 4096, + batch_sz: int = 1, + head_dim: int = 64, + ones: bool = True, + 
num_qo_heads: int = 2, + num_kv_heads: int = 2, + + stride: int = 16, + threshold: float = 0.9, +): + """ + Compare every sparse kernel against the dense Flash-Attention reference on + both forward pass and input-gradient w.r.t Q/K/V. + """ + port = str(random.randint(12000, 20000)) + xattn_params = { + "stride": stride, + "norm": 1, + "softmax": True, + "threshold": threshold, + "select_mode": "inverse", + "use_triton": True, + "causal": True, + "kdb": 1, + "keep_sink": False, + "keep_recent": False + } + cfg = SimpleNamespace( + batch_size=batch_sz, + seq_len=seq_len, + head_dim=head_dim, + ones=ones, + num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + xattn_params=xattn_params, + ) + + mp.spawn( + _run_worker, + args=(_WORLD_SIZE, port, cfg), + nprocs=_WORLD_SIZE, + join=True, + ) + +if __name__ == "__main__": + # Run the test with default parameters + + test_xattention_kernels( + seq_len=16384, + batch_sz=1, + head_dim=64, + ones=False, + num_qo_heads=2, + num_kv_heads=2, + + stride=16, + threshold=0.9, + ) \ No newline at end of file diff --git a/minference/dist_ops/utils.py b/minference/dist_ops/utils.py index 79282da..213e403 100644 --- a/minference/dist_ops/utils.py +++ b/minference/dist_ops/utils.py @@ -19,6 +19,7 @@ PROCESS_GROUPS: Dict[str, dist.ProcessGroup] = {} + @cache def _get_default_args(func): spec = inspect.getfullargspec(func) @@ -188,17 +189,16 @@ def __init__( zigzag: bool = False, ring_list: Optional[list] = None, ): + self._process_group = process_group self._ops: List[P2POp] = [] - self.rank = dist.get_rank(process_group) - self.world_size = dist.get_world_size(process_group) + self.rank = dist.get_rank(self._process_group) + self.world_size = dist.get_world_size(self._process_group) self._reqs = None - self.process_group = process_group if ring_list is not None: curr_idx = ring_list.index(self.rank) self.send_rank = ring_list[(curr_idx + 1) % len(ring_list)] self.recv_rank = ring_list[(curr_idx - 1 + len(ring_list)) % len(ring_list)] - self.send_first = curr_idx % 2 == 0 elif zigzag: parts = self.world_size // 2 self.ring_list = [] @@ -208,45 +208,9 @@ def __init__( offset = ((dist.get_rank() // self.world_size) * self.world_size) self.send_rank = self.ring_list[(self.revert_rank + 1) % self.world_size] + offset self.recv_rank = self.ring_list[(self.revert_rank - 1) % self.world_size] + offset - self.send_first = self.revert_rank % 2 == 0 else: self.send_rank = (self.rank + 1) % self.world_size self.recv_rank = (self.rank - 1) % self.world_size - self.send_first = self.rank % 2 == 0 - - if len(PROCESS_GROUPS) == 0: - self.init_process_groups() - - - if self.send_rank in get_inner_ring(process_group): - outer_rank = get_outer_ring(process_group).index(self.rank) - self._send_group = PROCESS_GROUPS[f'inner-{outer_rank}-{int(self.send_first)}'] - else: - self._send_group = PROCESS_GROUPS[f'outer-{int(self.send_first)}'] - - if self.recv_rank in get_inner_ring(process_group): - outer_rank = get_outer_ring(process_group).index(self.rank) - self._recv_group = PROCESS_GROUPS[f'inner-{outer_rank}-{int(1 - self.send_first)}'] - else: - self._recv_group = PROCESS_GROUPS[f'outer-{int(1 - self.send_first)}'] - - self._send_group = PROCESS_GROUPS[f'inner-0-0'] - self._recv_group = PROCESS_GROUPS[f'inner-0-0'] - - def init_process_groups(self): - global PROCESS_GROUPS - num_nodes = int(os.environ.get("NUM_NODES", 1)) - fast_nccl_options = dist.ProcessGroupNCCL.Options(is_high_priority_stream=True) - # fast_nccl_options.config.max_ctas = 2147483647 - # 
fast_nccl_options.config.min_ctas = 128 - for node_idx in range(num_nodes): - PROCESS_GROUPS[f'inner-{node_idx}-0'] = dist.new_group(pg_options=fast_nccl_options, use_local_synchronization=True) - PROCESS_GROUPS[f'inner-{node_idx}-1'] = dist.new_group(pg_options=fast_nccl_options, use_local_synchronization=True) - slow_nccl_options = dist.ProcessGroupNCCL.Options(is_high_priority_stream=True) - slow_nccl_options.config.max_ctas = 1 - slow_nccl_options.config.min_ctas = 1 - PROCESS_GROUPS['outer-0'] = dist.new_group(pg_options=slow_nccl_options, use_local_synchronization=True) - PROCESS_GROUPS['outer-1'] = dist.new_group(pg_options=slow_nccl_options, use_local_synchronization=True) def send_recv( self, @@ -260,15 +224,24 @@ def send_recv( else: res = recv_tensor - if self.send_first: - self._reqs.append(dist.isend(to_send, self.send_rank, group=self.process_group)) - self._reqs.append(dist.irecv(res, self.recv_rank, group=self.process_group)) - else: - self._reqs.append(dist.irecv(res, self.recv_rank, group=self.process_group)) - self._reqs.append(dist.isend(to_send, self.send_rank, group=self.process_group)) + send_op = dist.P2POp( + dist.isend, to_send, self.send_rank, group=self._process_group, + tag=2 * (step_idx * (self.rank + 1)) + fwd, + ) + recv_op = dist.P2POp( + dist.irecv, res, self.recv_rank, group=self._process_group, + tag=2 * (step_idx * (self.rank + 1)) + fwd, + ) + self._ops.append(send_op) + self._ops.append(recv_op) return res + def commit(self): + if self._reqs is not None: + raise RuntimeError("commit called twice") + self._reqs = dist.batch_isend_irecv(self._ops) + def wait(self): if self._reqs is None: raise RuntimeError("wait called before commit") @@ -277,6 +250,7 @@ def wait(self): self._reqs = None self._ops = [] + def send_recv_kv( self, k: torch.Tensor, @@ -284,8 +258,8 @@ def send_recv_kv( k_buffer: Optional[torch.Tensor] = None, v_buffer: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: - self._reqs = [] next_k, next_v = self.send_recv(k, k_buffer), self.send_recv(v, v_buffer) + self.commit() return next_k, next_v def send_recv_kv_offsets( @@ -297,12 +271,12 @@ def send_recv_kv_offsets( v_buffer: Optional[torch.Tensor] = None, kv_seq_offsets_buffer: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: - self._reqs = [] next_k, next_v = self.send_recv(k, k_buffer), self.send_recv(v, v_buffer) next_kv_seq_offsets = self.send_recv(kv_seq_offsets, kv_seq_offsets_buffer) + + self.commit() return next_k, next_v, next_kv_seq_offsets - def shuffle_zigzag_input(to_send: torch.Tensor, dim: int = 1, process_group: dist.ProcessGroup = None): diff --git a/minference/ops/op_utils/vertical_slash_utils.py b/minference/ops/op_utils/vertical_slash_utils.py index 57a8d23..ba04cd5 100644 --- a/minference/ops/op_utils/vertical_slash_utils.py +++ b/minference/ops/op_utils/vertical_slash_utils.py @@ -496,10 +496,6 @@ def calc_index( else: last_q = torch.zeros((batch_size, last_q_size, num_kv_heads, num_qo_heads // num_kv_heads, head_dim), device=q.device, dtype=q.dtype) - if os.getenv("COMM_DEBUG", False): - # For debugging purposes, print the rank and tensor shapes - rank = dist.get_rank(group) - print(f"Rank {rank} | calc_index | before invoking broadcast last_q from rank={last_rank}", flush=True) dist.broadcast(last_q, src=last_rank, group=group, async_op=False) qk = torch.einsum('bmghd, bngd -> bghmn', last_q, k) * (k.shape[-1] ** -0.5) @@ -530,10 +526,6 @@ def calc_index( gathered_vertical = [torch.empty_like(vertical) for _ in 
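The RingComm changes in the minference/dist_ops/utils.py hunk above replace per-step isend/irecv pairs with P2POps that send_recv() only queues and commit() launches together through dist.batch_isend_irecv(), with wait() draining the returned requests. Below is a stand-alone sketch of that queue/commit/wait pattern; it is not part of the patch, assumes an NCCL process group launched via torchrun, and uses illustrative names.

# ring_exchange_sketch.py -- run with: torchrun --nproc_per_node=2 ring_exchange_sketch.py
import torch
import torch.distributed as dist

def ring_exchange(t: torch.Tensor) -> torch.Tensor:
    """Send `t` to the next rank and receive the previous rank's tensor."""
    rank, world = dist.get_rank(), dist.get_world_size()
    send_rank, recv_rank = (rank + 1) % world, (rank - 1) % world
    recv_buf = torch.empty_like(t)
    ops = [
        dist.P2POp(dist.isend, t, send_rank),      # queued, not launched yet (cf. send_recv)
        dist.P2POp(dist.irecv, recv_buf, recv_rank),
    ]
    reqs = dist.batch_isend_irecv(ops)             # launch both at once (cf. commit)
    for req in reqs:                               # block until finished (cf. wait)
        req.wait()
    return recv_buf

if __name__ == "__main__":
    dist.init_process_group("nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    x = torch.full((4,), float(dist.get_rank()), device="cuda")
    print(dist.get_rank(), ring_exchange(x))
    dist.destroy_process_group()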
range(world_size)] else: gathered_vertical = None - if os.getenv("COMM_DEBUG", False): - # For debugging purposes, print the rank and tensor shapes - rank = dist.get_rank(group) - print(f"Rank {rank} | calc_index | before invoking gather vertical to {v_gather_rank}", flush=True) dist.gather(vertical, gathered_vertical, dst=v_gather_rank, group=group, async_op=False) if rank == v_gather_rank: @@ -562,10 +554,6 @@ def calc_index( v_indices = v_indices.sort(dim=-1, descending=False).values else: v_indices = torch.empty((batch_size, num_qo_heads, max_v_size), dtype=torch.int32, device=k.device) - if os.getenv("COMM_DEBUG", False): - # For debugging purposes, print the rank and tensor shapes - rank = dist.get_rank(group) - print(f"Rank {rank} | calc_index | before invoking broadcast v_indices from rank={v_gather_rank}", flush=True) dist.broadcast(v_indices, src=v_gather_rank, group=group, async_op=False) # async s_gather_rank = 0 @@ -580,11 +568,6 @@ def calc_index( gathered_slash = [torch.empty_like(slash) for _ in range(world_size)] else: gathered_slash = None - - if os.getenv("COMM_DEBUG", False): - # For debugging purposes, print the rank and tensor shapes - rank = dist.get_rank(group) - print(f"Rank {rank} | calc_index | before invoking gather slash to rank=0", flush=True) dist.gather(slash, gathered_slash, dst=s_gather_rank, group=group, async_op=False) if rank == s_gather_rank: diff --git a/minference/ops/op_utils/xattn_utils.py b/minference/ops/op_utils/xattn_utils.py index 9484c8b..760e927 100644 --- a/minference/ops/op_utils/xattn_utils.py +++ b/minference/ops/op_utils/xattn_utils.py @@ -560,8 +560,8 @@ def flat_group_gemm_fuse_reshape(query_states, key_states, stride, chunk_start, output = torch.empty((batch_size, num_heads, q_len // stride, kv_len // stride), dtype=query_states.dtype, device=query_states.device) BLOCK_M = 128 BLOCK_N = 128 - assert (q_len % (stride * BLOCK_M) == 0) - assert (kv_len % (stride * BLOCK_N) == 0) + assert (q_len % (stride * BLOCK_M) == 0), f"q_len={q_len}, stride={stride}, BLOCK_M={BLOCK_M}" + assert (kv_len % (stride * BLOCK_N) == 0), f"kv_len={kv_len}, stride={stride}, BLOCK_N={BLOCK_N}" grid = (q_len // stride // BLOCK_M, kv_len // stride // BLOCK_N, batch_size * num_heads) flat_group_gemm_fuse_reshape_kernel[grid]( diff --git a/minference/ops/utils.py b/minference/ops/utils.py index 917e527..5b5a0a3 100644 --- a/minference/ops/utils.py +++ b/minference/ops/utils.py @@ -7,4 +7,30 @@ def set_seed(seed=42): torch.cuda.manual_seed_all(seed) def use_triton(): - return torch.version.hip is not None or os.getenv("FORCE_TRITON", "0") == "1" \ No newline at end of file + return torch.version.hip is not None or os.getenv("FORCE_TRITON", "0") == "1" + + +def check_correctness_by_row( + seq_len, + tensor_var, + ref_tensor_var, + tensor_name, + ATOL=1e-2, + RTOL=1e-2, +): + if not torch.allclose(tensor_var, ref_tensor_var, atol=ATOL, rtol=RTOL): + for h in range(tensor_var.shape[2]): + for i in range(seq_len): + tensor_var_row = tensor_var[:, i, h] + ref_tensor_var_row = ref_tensor_var[:, i, h] + + if not torch.allclose(tensor_var_row, ref_tensor_var_row, atol=ATOL, rtol=RTOL): + print('-' * 60 + '\n') + print(f"Mismatched {tensor_name} at Head {h}, Row {i}:\n") + print(f"Computed:\n{tensor_var_row}\n") + print(f"Ref:\n{ref_tensor_var_row}\n") + + max_diff = torch.max(torch.abs(tensor_var_row - ref_tensor_var_row)) + print(f"Maximal difference: {max_diff.item()}\n") + else: + print(f"All {tensor_name} values are correct within the specified tolerance.") \ No 
newline at end of file
From 4a8bbc1bd0c2ca4acd24d3af2bb018bb9b4ed959 Mon Sep 17 00:00:00 2001
From: Wenxuan Li
Date: Wed, 18 Jun 2025 08:06:36 +0000
Subject: [PATCH 07/12] Clean requirements.txt for mtraining

---
 mtraining/requirements.txt | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/mtraining/requirements.txt b/mtraining/requirements.txt
index 350ed6f..f76ab56 100644
--- a/mtraining/requirements.txt
+++ b/mtraining/requirements.txt
@@ -1,11 +1,6 @@
 transformers==4.48.0
 datasets==2.20.0
 tensorboard
-jieba
-rouge
-nltk
-rouge_score
-evaluate

 # For Data Preparation
 mosaicml-streaming==0.8.1
\ No newline at end of file
From 2b56a718dbd7cbdc11b5c7863b0e6907db0398e5 Mon Sep 17 00:00:00 2001
From: Wenxuan Li
Date: Wed, 18 Jun 2025 09:44:03 +0000
Subject: [PATCH 08/12] update setup.sh

---
 mtraining/setup.sh | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/mtraining/setup.sh b/mtraining/setup.sh
index 957ecbd..923fb32 100755
--- a/mtraining/setup.sh
+++ b/mtraining/setup.sh
@@ -1,23 +1,24 @@
 #!/usr/bin/bash
 set -e # Exit on first error
+
 BASE_DIR="$(cd "$(dirname "$0")" && pwd)"
 echo $BASE_DIR
 PIP="$(which pip)"

 if command -v nvidia-smi
 then
-    # assume base image: amlt-sing/acpt-torch2.3.1-py3.10-cuda12.1-ubuntu22.04
     $PIP install ninja cmake wheel pybind11
-    $PIP install --no-cache-dir torch==2.3.1 --index-url https://download.pytorch.org/whl/cu121
-    $PIP install git+https://github.com/Dao-AILab/flash-attention.git@v2.7.4.post1
     $PIP install -r "${BASE_DIR}/requirements.txt"
     $PIP install git+https://github.com/microsoft/nnscaler.git@2368540417bc3b77b7e714d3f1a0de8a51bb66e8
     $PIP install "rotary-emb @ git+https://github.com/Dao-AILab/flash-attention.git@9356a1c0389660d7e231ff3163c1ac17d9e3824a#subdirectory=csrc/rotary"
     $PIP install "block_sparse_attn @ git+https://github.com/HalberdOfPineapple/flash-attention.git@block-sparse"
-elif command -v rocm-smi
+    $PIP install git+https://github.com/Dao-AILab/flash-attention.git@v2.7.4.post1
+    $PIP install torch==2.6.0 torchvision==0.21.0
+    $PIP install triton==3.0.0
+elif command -v rocm-smi # TODO: to verify the correctness of dependencies in ROCm environment
 then
     $PIP install ninja cmake wheel pybind11
-    $PIP install --no-cache-dir --pre torch==2.3.1+rocm6.0 --index-url https://download.pytorch.org/whl/rocm6.0
+    $PIP install --pre torch==2.3.1+rocm6.0 --index-url https://download.pytorch.org/whl/rocm6.0
     $PIP install git+https://github.com/OpenAI/triton.git@e192dba#subdirectory=python
     $PIP install git+https://github.com/Dao-AILab/flash-attention.git@v2.7.4.post1
     $PIP install -r "${BASE_DIR}/requirements.txt"
@@ -32,4 +33,4 @@ NNSCALER_HOME=$(python -c "import nnscaler; print(nnscaler.__path__[0])")
 echo "export NNSCALER_HOME=${NNSCALER_HOME}" >> ~/.profile
 echo "export PYTHONPATH=${NNSCALER_HOME}:\${PYTHONPATH}" >> ~/.profile
 source ~/.profile
-pip install -e .
\ No newline at end of file +$PIP install -e $BASE_DIR \ No newline at end of file From 20f5add3fc23242fd45ac5f7e5ffc6a5ae44d448 Mon Sep 17 00:00:00 2001 From: Wenxuan Li Date: Sat, 21 Jun 2025 16:44:45 +0000 Subject: [PATCH 09/12] Passed unit testing for MInfer and XAttention --- minference/dist_ops/minfer_zigzag_comp.py | 403 ++++++++++++++++++ minference/dist_ops/moba_zigzag.py | 139 +++--- minference/dist_ops/test/minfer_ring_test.py | 71 ++- .../dist_ops/test/minfer_ring_test_raw.py | 66 +-- minference/dist_ops/test/moba_ring_test.py | 200 +++++++++ .../dist_ops/test/moba_ring_test_raw.py | 207 +++++++++ minference/dist_ops/test/xattn_ring_test.py | 200 +++++++++ ...ring_tes_raw.py => xattn_ring_test_raw.py} | 52 +-- minference/dist_ops/xattn_zigzag.py | 2 +- minference/ops/moba.py | 65 ++- minference/ops/op_utils/moba_utils.py | 12 + minference/ops/utils.py | 41 +- minference/ops/xattention_fa.py | 135 +++++- mtraining/attn_funcs/moba_func.py | 49 +-- .../scripts/train_qwen_mini_ProLong512K.sh | 3 + 15 files changed, 1431 insertions(+), 214 deletions(-) create mode 100644 minference/dist_ops/minfer_zigzag_comp.py create mode 100644 minference/dist_ops/test/moba_ring_test.py create mode 100644 minference/dist_ops/test/moba_ring_test_raw.py create mode 100644 minference/dist_ops/test/xattn_ring_test.py rename minference/dist_ops/test/{xattn_ring_tes_raw.py => xattn_ring_test_raw.py} (87%) diff --git a/minference/dist_ops/minfer_zigzag_comp.py b/minference/dist_ops/minfer_zigzag_comp.py new file mode 100644 index 0000000..e99b88c --- /dev/null +++ b/minference/dist_ops/minfer_zigzag_comp.py @@ -0,0 +1,403 @@ +import os +import torch +import triton +import torch.distributed as dist + +from typing import List, Tuple, Dict + +from MTraining.ops.ring_attn.core.utils import ( + RingComm, TIMING_LOGGER, + shuffle_zigzag_input, recover_zigzag_output, + single_gather_tensor, shuffle_block_mask_zigzag +) +from MTraining.ops.minfer import ( + block_attn_fwd, block_attn_bwd, bar_attn_fwd, bar_attn_bwd, block_bar_attn_fwd, convert_blockmask, + minference_flash_attn_func, build_index +) + + +def compute_sr_flops( + block_mask_offset: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, num_blocks] + bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] + step: int, + granularity: int, + q_len: int, + head_dim: int, + fwd: bool=True, +): + num_blocks = triton.cdiv(q_len, granularity) + bh = block_mask_offset.shape[0] * block_mask_offset.shape[1] + + # ------------------------------------- + # Block Compute + total_num_blocks = bh * num_blocks * num_blocks / 2 + num_active_blocks = block_mask_offset.sum(dtype=torch.float32).item() + if step == 0: + num_active_blocks -= bh * num_blocks / 2 + block_ratio = num_active_blocks / total_num_blocks + block_flops = num_active_blocks * (granularity * granularity) * head_dim * 2 * 2 + + # ------------------------------------- + # Bar Compute + bar_cnt_step = (bar_cnt[..., step + 1] - bar_cnt[..., step]).sum(dtype=torch.float32).item() + bar_ratio = bar_cnt_step / (granularity * total_num_blocks) + bar_flops = bar_cnt_step * granularity * head_dim * 2 * 2 + + # ------------------------------------- + # Sparsity Ratio and FLOPs + sparsity_ratio = 1 - block_ratio - bar_ratio + flops = block_flops + bar_flops + + if not fwd: + flops, block_flops, bar_flops = 2.5 * flops, 2.5 * block_flops, 2.5 * bar_flops + return block_ratio, bar_ratio, sparsity_ratio, block_flops, bar_flops, flops + +def compute_sr_by_heads( + block_mask_offset: 
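Not part of the patch: a tiny worked example of the block/bar bookkeeping performed by compute_sr_flops above, redone with plain torch and toy shapes so the arithmetic can be checked by hand (no triton or distributed setup required).

import torch

granularity, head_dim, q_len, step = 128, 64, 256, 0
num_blocks = (q_len + granularity - 1) // granularity              # triton.cdiv equivalent -> 2
block_mask = torch.tensor([[[[1, 0], [1, 1]]]])                    # [B=1, H=1, 2, 2], causal lower triangle
bar_cnt = torch.zeros(1, 1, num_blocks, 2 + 1, dtype=torch.int32)  # no "bar" columns selected

bh = block_mask.shape[0] * block_mask.shape[1]
total_blocks = bh * num_blocks * num_blocks / 2                    # causal block budget: 2
active_blocks = block_mask.sum().item() - bh * num_blocks / 2      # step 0 discounts diagonal blocks -> 2.0
block_ratio = active_blocks / total_blocks                         # 1.0: this causal mask is fully dense
bar_ratio = (bar_cnt[..., step + 1] - bar_cnt[..., step]).sum().item() / (granularity * total_blocks)
sparsity_ratio = 1 - block_ratio - bar_ratio                       # 0.0: nothing is skipped
block_flops = active_blocks * granularity * granularity * head_dim * 2 * 2   # same formula as block_flops above
print(block_ratio, bar_ratio, sparsity_ratio, block_flops)         # 1.0 0.0 0.0 8388608.0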
torch.Tensor, # [batch_size, num_qo_heads, num_blocks, num_blocks] + bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] + step: int, + granularity: int, + q_len: int, +): + batch_size, num_heads = block_mask_offset.shape[0], block_mask_offset.shape[1] + num_blocks = triton.cdiv(q_len, granularity) + + # ------------------------------------- + # Block Compute + total_num_blocks = batch_size * num_blocks * num_blocks / 2 + total_num_blocks_by_heads = torch.tensor([total_num_blocks for _ in range(num_heads)], dtype=torch.float32).to(block_mask_offset.device) + num_active_blocks = block_mask_offset.sum(-1).sum(-1).sum(0, dtype=torch.float32) # [num_qo_heads] + if step == 0: + num_active_blocks -= batch_size * num_blocks / 2 + block_ratio_by_heads = num_active_blocks / total_num_blocks_by_heads + + # ------------------------------------- + # Bar Compute + bar_cnt_step = (bar_cnt[..., step + 1] - bar_cnt[..., step]).sum(dim=-1).sum(dim=-1).sum(0, dtype=torch.float32) # [num_qo_heads] + bar_ratio_by_heads = bar_cnt_step / total_num_blocks_by_heads / granularity + + # ------------------------------------- + # Sparsity Ratio + sparsity_ratio_by_heads = 1 - block_ratio_by_heads - bar_ratio_by_heads + return sparsity_ratio_by_heads.detach().cpu().numpy().tolist() + +def minfer_zigzag_forward( + process_group: dist.ProcessGroup, + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + layer_idx: int, + softmax_scale: float, + block_mask: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks, num_blocks] + bar_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, max_v_size] + bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] + granularity: int = 128, +): + comm = RingComm(process_group, zigzag=True) + ring_list = comm.ring_list + ring_index = ring_list.index(comm.rank) + + out, lse = None, None + block_idx, block_cnt = convert_blockmask(block_mask, block_size_M=granularity, block_size_N=64) + + for step in range(comm.world_size): + if step + 1 != comm.world_size: + next_k, next_v = comm.send_recv_kv(k, v) + block_causal = step == 0 + offset = (ring_index - step) % comm.world_size + + out, lse = block_bar_attn_fwd( + q, k, v, out, lse, softmax_scale, + bar_idx, bar_cnt, block_idx[offset], block_cnt[offset], + granularity=granularity, + step=offset, + causal=block_causal, + ) + + if step + 1 != comm.world_size: + comm.wait() + k, v = next_k, next_v + + out = out.to(q.dtype) + # lse = lse.squeeze(dim=-1).transpose(1, 2) + return out, lse + + +def minfer_zigzag_backward( + process_group: dist.ProcessGroup, + dout: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + out: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + softmax_lse: torch.Tensor, # [batch_size, num_qo_heads, num_tokens] + layer_idx: int, + softmax_scale: float, + block_mask: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks, num_blocks] + bar_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, max_v_size] + bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] + granularity: int = 128, +): + kv_comm = RingComm(process_group, 
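As an illustration of the schedule used by minfer_zigzag_forward above (not part of the patch): at step s, a rank sitting at zigzag-ring position i applies the mask slice for offset (i - s) % world_size, which is also the ring position whose K/V shard it is currently holding, and only step 0 (its own shard) is block-causal. The plain-Python loop below prints that schedule for a 4-way ring.

world_size = 4
for ring_index in range(world_size):
    # offsets of the K/V shard (by ring position) combined with the local queries at each step
    schedule = [(ring_index - step) % world_size for step in range(world_size)]
    print(f"ring position {ring_index}: per-step KV offsets = {schedule}")
# e.g. ring position 2 -> [2, 1, 0, 3]; after world_size steps every shard has been visited exactly once.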
zigzag=True) + d_kv_comm = RingComm(process_group, zigzag=True) + ring_list = kv_comm.ring_list + ring_index = ring_list.index(kv_comm.rank) + + dq, dk, dv = None, None, None + next_dk, next_dv = None, None + next_k, next_v = None, None + dk_comm_buffer, dv_comm_buffer = None, None + + for step in range(kv_comm.world_size): + if step + 1 != kv_comm.world_size: + next_k, next_v = kv_comm.send_recv_kv(k, v) + block_causal = step == 0 + offset = (ring_index - step) % kv_comm.world_size + + # Block Mask + step_dq, step_dk, step_dv = block_attn_bwd( + dout, q, k, v, out, + softmax_lse, softmax_scale, + block_mask[offset], + granularity=granularity, + deterministic=False, + causal=block_causal, + ) + + # Bar Mask + step_dq, step_dk, step_dv = bar_attn_bwd( + dout, q, k, v, out, step_dq, step_dk, step_dv, + softmax_lse, softmax_scale, + bar_idx, bar_cnt, + granularity=granularity, + deterministic=False, + step=offset, + ) + + # Update dQ, dK, dV + if step == 0: + # TODO: check if float32 is necessary + dq = step_dq.to(torch.float32) + dk = step_dk.to(torch.float32) + dv = step_dv.to(torch.float32) + else: + d_kv_comm.wait() + dk_comm_buffer, dv_comm_buffer = dk, dv + dk, dv = next_dk, next_dv + + dq += step_dq + dk += step_dk + dv += step_dv + + if step + 1 != kv_comm.world_size: + kv_comm.wait() + k, v = next_k, next_v + next_dk, next_dv = d_kv_comm.send_recv_kv( + dk, dv, dk_comm_buffer, dv_comm_buffer + ) + + d_kv_comm.wait() + return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) + + +class MInferZigzagAttnFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + v_size, + s_size, + layer_idx, + softmax_scale, + granularity, + return_softmax, + group, + ): + if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) + batch_size, num_tokens_local, num_qo_heads, head_dim = q.shape + + # Indexing + TIMING_LOGGER.start('indexing') + block_mask, bar_idx, bar_cnt, bar_pos, v_idx, v_cnt = build_index( + q, k, v_size, s_size, num_tokens_local, + stripe_transform=False, + zigzag_transform=True, + granularity=granularity, group=group + ) + TIMING_LOGGER.end('indexing') + + # Shuffle + TIMING_LOGGER.start('shfl-fwd-input') + q = shuffle_zigzag_input(to_send=q, dim=1, process_group=group) + k = shuffle_zigzag_input(to_send=k, dim=1, process_group=group) + v = shuffle_zigzag_input(to_send=v, dim=1, process_group=group) + TIMING_LOGGER.end('shfl-fwd-input') + + # Compute + TIMING_LOGGER.start('forward') + out, softmax_lse = minfer_zigzag_forward( + group, q, k, v, + layer_idx, softmax_scale, + block_mask, bar_idx, bar_cnt, + granularity=granularity, + ) + TIMING_LOGGER.end('forward') + + # Saving tensors for backward + ctx.save_for_backward(q, k, v, out, softmax_lse, block_mask, bar_idx, bar_cnt) + ctx.softmax_scale = softmax_scale + ctx.granularity = granularity + ctx.group = group + ctx.layer_idx = layer_idx + + # Recover outputs + TIMING_LOGGER.start('shfl-fwd-output') + out = recover_zigzag_output(out, dim=1, process_group=group) + if return_softmax: + softmax_lse = recover_zigzag_output(softmax_lse, dim=2, process_group=group) + TIMING_LOGGER.end('shfl-fwd-output') + + # Output and Return + if return_softmax: + return (out, softmax_lse, None) + return out + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse, block_mask, bar_idx, bar_cnt = ctx.saved_tensors + softmax_scale = ctx.softmax_scale + granularity = ctx.granularity + layer_idx = ctx.layer_idx + group = ctx.group + + # Shuffle + TIMING_LOGGER.start('shfl-bwd-input') + dout = 
shuffle_zigzag_input(to_send=dout, dim=1, process_group=group) + TIMING_LOGGER.end('shfl-bwd-input') + + # Compute + TIMING_LOGGER.start('backward') + dq, dk, dv = minfer_zigzag_backward( + group, dout, q, k, v, out, softmax_lse, + layer_idx, softmax_scale, + block_mask, bar_idx, bar_cnt, + granularity=granularity, + ) + TIMING_LOGGER.end('backward') + + # Recover + TIMING_LOGGER.start('shfl-bwd-output') + dq = recover_zigzag_output(dq, dim=1, process_group=group) + dk = recover_zigzag_output(dk, dim=1, process_group=group) + dv = recover_zigzag_output(dv, dim=1, process_group=group) + TIMING_LOGGER.end('shfl-bwd-output') + + return dq, dk, dv, None, None, None, None, None, None, None + + +def minfer_zigzag_qkvpacked_func( + qkv: torch.Tensor, # [batch_size, num_tokens, 3, num_heads, head_dim] + v_size: List[int], # [num_heads] + s_size: List[int], # [num_heads] + layer_idx: int = 0, + dropout_p: float = 0.0, + softmax_scale: float = None, + granularity: int = 128, + causal: bool = True, + window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window + alibi_slopes: Tuple[int, int] = None, + deterministic: bool = False, + return_attn_probs: bool = False, + group: dist.ProcessGroup = None, +): + assert causal + assert dropout_p == 0 + assert window_size == (-1, -1) + assert alibi_slopes is None + assert not deterministic + return MInferZigzagAttnFunc.apply( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + v_size, + s_size, + layer_idx, + softmax_scale, + granularity, + return_attn_probs, + group, + ) + + +def minfer_zigzag_kvpacked_func( + q: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] + kv: torch.Tensor, # [batch_size, num_tokens, 2, num_heads, head_dim] + v_size: List[int], # [num_heads] + s_size: List[int], # [num_heads] + layer_idx: int = 0, + dropout_p: float = 0.0, + softmax_scale: float = None, + granularity: int = 128, + causal: bool = True, + window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window + alibi_slopes: Tuple[int, int] = None, + deterministic: bool = False, + return_attn_probs: bool = False, + group: dist.ProcessGroup = None, +): + assert causal + assert dropout_p == 0 + assert window_size == (-1, -1) + assert alibi_slopes is None + assert not deterministic + return MInferZigzagAttnFunc.apply( + q, + kv[:, :, 0], + kv[:, :, 1], + v_size, + s_size, + layer_idx, + softmax_scale, + granularity, + return_attn_probs, + group, + ) + + +def minfer_zigzag_func( # the one used for nnscaler training + q: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] + v_size: List[int], # [num_heads] + s_size: List[int], # [num_heads] + layer_idx: int = 0, + dropout_p: float = 0.0, + softmax_scale: float = None, + granularity: int = 128, + causal: bool = True, + window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window + alibi_slopes: Tuple[int, int] = None, + deterministic: bool = False, + return_attn_probs: bool = False, + group: dist.ProcessGroup = None, +) -> torch.Tensor: + assert causal + assert dropout_p == 0 + assert window_size == (-1, -1) + assert alibi_slopes is None + assert not deterministic + + return MInferZigzagAttnFunc.apply( + q, + k, + v, + v_size, + s_size, + layer_idx, + softmax_scale, + granularity, + return_attn_probs, + group, + ) diff --git a/minference/dist_ops/moba_zigzag.py b/minference/dist_ops/moba_zigzag.py index f3d60e4..1ba8d45 100644 --- 
a/minference/dist_ops/moba_zigzag.py +++ b/minference/dist_ops/moba_zigzag.py @@ -19,7 +19,8 @@ recover_zigzag_output, get_default_args, ) from minference.ops.op_utils.moba_utils import ( - shuffle_input_all, shuffle_input_only, compute_moba_gate + shuffle_input_all, shuffle_input_only, compute_moba_gate, + tensor_4d_to_3d ) @@ -212,7 +213,7 @@ def moba_zigzag_attn_fwd_step( # ----------------------------------------------------------------------------------- # If no queries need to be computed with the current KV chunk and no causal attention is needed, return None to skip the output update if not causal and moba_q.shape[0] == 0: - return None, None, 0, torch.zeros((num_head,), device=q.device, dtype=torch.float32) + return None, None # ----------------------------------------------------------------------------------- # Processing output and lse @@ -304,7 +305,7 @@ def moba_zigzag_attn_fwd( deterministic=False, ): assert causal == True, "zigzag ring is meaningless for causal=False" - comm = RingComm(process_group) + comm = RingComm(process_group, zigzag=True) block_seq_len = q.shape[0] // 2 seq_len, num_q_heads, head_dim = q.shape @@ -357,7 +358,10 @@ def fwd_step( k_seq_offsets=kv_seq_offsets, ) - out, lse = update_out_and_lse(out, lse, block_out, block_lse) + out, lse = update_out_and_lse( + out, lse, block_out, block_lse, + use_triton_kernel=False, + ) elif step <= comm.revert_rank: k0 = k[:block_seq_len] v0 = v[:block_seq_len] @@ -368,7 +372,10 @@ def fwd_step( ) if block_out is not None: - out, lse = update_out_and_lse(out, lse, block_out, block_lse) + out, lse = update_out_and_lse( + out, lse, block_out, block_lse, + use_triton_kernel=False, + ) else: q1 = q[block_seq_len:] block_out, block_lse = fwd_step( @@ -383,6 +390,7 @@ def fwd_step( block_out, block_lse, slice_=(slice(block_seq_len, None)), + use_triton_kernel=False, ) if step + 1 != comm.world_size: @@ -711,8 +719,8 @@ def moba_zigzag_attn_bwd( ): assert causal == True, "zigzag ring is meaningless for causal=False" - kv_comm = RingComm(process_group) - d_kv_comm = RingComm(process_group) + kv_comm = RingComm(process_group, zigzag=True) + d_kv_comm = RingComm(process_group, zigzag=True) kv_seq_offsets = torch.clone(seq_offsets) seq_len, num_q_heads, head_dim = q.shape @@ -864,17 +872,20 @@ def forward( return_softmax, group, ): - # print(f"Rank {dist.get_rank()} | forward | q shape: {q.shape}, k shape: {k.shape}, v shape: {v.shape}") + # Note seq_len here refers to the total sequence length + seq_lens = cu_seqlens[1:] - cu_seqlens[:-1] + assert seq_lens.min() == seq_lens.max(), "Current implementation of MoBA Zigzag Ring Attention does not support variable sequence lengths within a batch" + seq_len = seq_lens.detach().cpu()[0].item() # all sequences in the batch have the same length if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) assert alibi_slopes is None + # --------------------------- + # Compute gate values before shuffling ( - gate_mask, cu_chunk, - filtered_chunk_indices, - num_filtered_chunk, - chunk_to_batch + gate_mask, cu_chunk, filtered_chunk_indices, + num_filtered_chunk, chunk_to_batch ) = compute_moba_gate( q, k, v, seq_offset, @@ -883,38 +894,40 @@ def forward( moba_topk, ) - # gate_mask needs to be shuffled as it is coupled with q q, seq_offsets, gate_mask = shuffle_input_all( to_send=q, gate_mask=gate_mask, seq_offset=seq_offset, process_group=group ) k = shuffle_input_only(to_send=k, process_group=group) v = shuffle_input_only(to_send=v, process_group=group) - k = k.contiguous() v = 
v.contiguous() - out, softmax_lse = moba_zigzag_attn_fwd( - group, - q, k, v, - seq_offsets, # sequence offsets for Q - layer_idx, - - gate_mask, cu_chunk, - filtered_chunk_indices, - num_filtered_chunk, - chunk_to_batch, - moba_chunk_size, - moba_topk, - - softmax_scale=softmax_scale, - dropout_p=dropout_p, - causal=causal, - window_size=window_size, - alibi_slopes=alibi_slopes, - deterministic=False, - ) - # this should be out_padded + q_3d, k_3d, v_3d = \ + tensor_4d_to_3d(q), tensor_4d_to_3d(k), tensor_4d_to_3d(v) + + out_3d, softmax_lse = moba_zigzag_attn_fwd( + group, + q_3d, k_3d, v_3d, + seq_offsets, # sequence offsets for Q + layer_idx, + + gate_mask, cu_chunk, + filtered_chunk_indices, + num_filtered_chunk, + chunk_to_batch, + moba_chunk_size, + moba_topk, + + softmax_scale=softmax_scale, + dropout_p=dropout_p, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=False, + ) + + out = out_3d.reshape(*q.shape) ctx.save_for_backward( q, k, v, out, softmax_lse, seq_offsets, gate_mask, cu_chunk, filtered_chunk_indices, @@ -932,14 +945,13 @@ def forward( ctx.deterministic = deterministic ctx.group = group ctx.layer_idx = layer_idx + ctx.seq_len = seq_len - - out = recover_zigzag_output(out, process_group=group) + out = recover_zigzag_output(out, dim=1, process_group=group) return out if not return_softmax else (out, softmax_lse, None) @staticmethod def backward(ctx, dout, *args): - dout = shuffle_input_only(to_send=dout, process_group=ctx.group) ( q, k, v, out, softmax_lse, # [n_heads, seq_block_len] @@ -948,16 +960,22 @@ def backward(ctx, dout, *args): chunk_to_batch ) = ctx.saved_tensors + q_3d, k_3d, v_3d, out_3d = \ + tensor_4d_to_3d(q), tensor_4d_to_3d(k), tensor_4d_to_3d(v), \ + tensor_4d_to_3d(out) + + dout = shuffle_input_only(to_send=dout, process_group=ctx.group) + dout_3d = tensor_4d_to_3d(dout) + num_filtered_chunk = ctx.num_filtered_chunk moba_chunk_size = ctx.moba_chunk_size moba_topk = ctx.moba_topk - - dq, dk, dv = moba_zigzag_attn_bwd( + + dq_3d, dk_3d, dv_3d = moba_zigzag_attn_bwd( ctx.group, - dout, - q, k, v, - out, - softmax_lse, + dout_3d, + q_3d, k_3d, v_3d, + out_3d, softmax_lse, seq_offsets, ctx.layer_idx, @@ -977,11 +995,13 @@ def backward(ctx, dout, *args): deterministic=ctx.deterministic, ) - dq = recover_zigzag_output(dq, ctx.group) - dk = recover_zigzag_output(dk, ctx.group) - dv = recover_zigzag_output(dv, ctx.group) - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None + dq, dk, dv = \ + dq_3d.reshape(*q.shape), dk_3d.reshape(*k.shape), dv_3d.reshape(*v.shape) + dq = recover_zigzag_output(dq, dim=1, process_group=ctx.group) + dk = recover_zigzag_output(dk, dim=1, process_group=ctx.group) + dv = recover_zigzag_output(dv, dim=1, process_group=ctx.group) + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None def moba_zigzag_qkvpacked_func( qkv, @@ -992,7 +1012,7 @@ def moba_zigzag_qkvpacked_func( moba_topk, dropout_p=0.0, softmax_scale=None, - causal=False, + causal=True, window_size=(-1, -1), alibi_slopes=None, deterministic=False, @@ -1029,7 +1049,7 @@ def moba_zigzag_kvpacked_func( moba_topk, dropout_p=0.0, softmax_scale=None, - causal=False, + causal=True, window_size=(-1, -1), alibi_slopes=None, deterministic=False, @@ -1057,21 +1077,32 @@ def moba_zigzag_kvpacked_func( def moba_zigzag_func( - q, k, v, - seq_offset: torch.Tensor, + q, k, v, # [batch_size, seq_block_len, n_heads, head_dim] layer_idx: int, - cu_seqlens, + 
global_seq_len: int, moba_chunk_size, moba_topk, + dropout_p=0.0, softmax_scale=None, - causal=False, + causal=True, window_size=(-1, -1), alibi_slopes=None, deterministic=False, return_attn_probs=False, group=None, ): + batch_size = q.shape[0] + cu_seqlens = torch.cumsum( + torch.tensor([0] + [global_seq_len] * batch_size, device=q.device), + dim=0, + dtype=torch.int32, + ) + + rank = dist.get_rank() + world_size = dist.get_world_size() + seq_offset = torch.arange(0, global_seq_len, global_seq_len // world_size)[rank:rank+1] + return MoBAZigzagRingFlashAttnFunc.apply( q, k, v, seq_offset, diff --git a/minference/dist_ops/test/minfer_ring_test.py b/minference/dist_ops/test/minfer_ring_test.py index f7b3d21..ae39257 100644 --- a/minference/dist_ops/test/minfer_ring_test.py +++ b/minference/dist_ops/test/minfer_ring_test.py @@ -1,29 +1,16 @@ -# tests/test_minference_sparse_attention.py -""" -Distributed correctness tests for Minference sparse-attention kernels. - -Run with: - pytest -q -s tests/test_minference_sparse_attention.py -or manually choose GPUs, e.g. - CUDA_VISIBLE_DEVICES=0,1 pytest -q -s … - -The test spawns one process per GPU with torch.multiprocessing, so it does -**not** require `pytest-xdist`. It will be skipped automatically if you have -fewer than two visible CUDA devices. -""" from __future__ import annotations import os +import pytest import random -from types import SimpleNamespace from typing import Callable +from types import SimpleNamespace -import pytest import torch import torch.distributed as dist import torch.multiprocessing as mp -from minference.ops.utils import set_seed +from minference.ops.utils import set_seed, check_correct_rate from minference.dist_ops.minfer_zigzag import minfer_zigzag_func from minference.dist_ops.minfer_striped import minfer_stripe_func from minference.dist_ops.minfer_dr_striped import minfer_dr_stripe_func @@ -32,7 +19,7 @@ # ------------- constants ------------------------------------------------------ _ATOL = 1e-2 _RTOL = 1e-2 -_WORLD_SIZE = 2 +_WORLD_SIZE = 4 _ATTENTION_IMPLS: dict[str, Callable] = { "minfer_zigzag": minfer_zigzag_func, @@ -49,11 +36,12 @@ def _init_process_group(rank: int, world_size: int, port: str) -> None: "MASTER_PORT": port, "RANK": str(rank), "WORLD_SIZE": str(world_size), + "LOCAL_RANK": str(rank % min(world_size, torch.cuda.device_count())), + "LOCAL_WORLD_SIZE": str(min(world_size, torch.cuda.device_count())), } ) dist.init_process_group("nccl", rank=rank, world_size=world_size) - def _run_worker( rank: int, world_size: int, @@ -74,25 +62,22 @@ def _run_worker( # ----------------- generate identical tensors on every rank -------------- if rank == 0: - rand_or_one = ( - torch.randn if not cfg.ones else lambda s, **k: torch.ones(*s, **k) - ) - q = rand_or_one( + q = torch.randn( (cfg.batch_size, cfg.seq_len, cfg.num_qo_heads, cfg.head_dim), dtype=dtype, device=device, ) - k = rand_or_one( + k = torch.randn( (cfg.batch_size, cfg.seq_len, cfg.num_kv_heads, cfg.head_dim), dtype=dtype, device=device, ) - v = rand_or_one( + v = torch.randn( (cfg.batch_size, cfg.seq_len, cfg.num_kv_heads, cfg.head_dim), dtype=dtype, device=device, ) - dout = rand_or_one( + dout = torch.randn( (cfg.batch_size, cfg.seq_len, cfg.num_qo_heads, cfg.head_dim), dtype=dtype, device=device, @@ -159,30 +144,27 @@ def _run_worker( ) torch.autograd.backward(out_ref, dout) ref_grads = (q_ref.grad, k_ref.grad, v_ref.grad) - + # ----------------- assertions ---------------------------------------- - torch.testing.assert_close( - final_out, 
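For reference (not part of the patch), a worked example with hypothetical sizes of the two tensors that the moba_zigzag_func wrapper above now derives from global_seq_len instead of taking as arguments:

import torch

global_seq_len, batch_size = 16384, 1
world_size, rank = 4, 2                        # stand-ins for dist.get_world_size() / dist.get_rank()

cu_seqlens = torch.cumsum(
    torch.tensor([0] + [global_seq_len] * batch_size), dim=0, dtype=torch.int32
)                                              # tensor([    0, 16384], dtype=torch.int32)
seq_offset = torch.arange(0, global_seq_len, global_seq_len // world_size)[rank:rank + 1]
print(cu_seqlens, seq_offset)                  # seq_offset for rank 2 is tensor([8192])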
out_ref, atol=_ATOL, rtol=_RTOL, msg="forward mismatch" - ) + assert check_correct_rate(final_out, out_ref, ATOL=_ATOL, RTOL=_RTOL),\ + "forward output mismatch" + for got, ref, name in zip( grads, ref_grads, ("Q-grad", "K-grad", "V-grad"), ): - torch.testing.assert_close(got, ref, atol=_ATOL, rtol=_RTOL, msg=name) - + assert check_correct_rate(got, ref, ATOL=_ATOL, RTOL=_RTOL),\ + f"{name} mismatch" dist.destroy_process_group() - # ------------- pytest entry-point -------------------------------------------- @pytest.mark.skipif(torch.cuda.device_count() < _WORLD_SIZE, reason="Not enough GPUs") -@pytest.mark.parametrize("seq_len", [512, 4096, 8192]) +@pytest.mark.parametrize("seq_len", [131072, 262144, 524288]) @pytest.mark.parametrize("batch_sz", [1]) -@pytest.mark.parametrize("head_dim", [32, 64]) -@pytest.mark.parametrize("sparsity", [0.9, 0.95, 1.]) -@pytest.mark.parametrize("ones", [False, True]) -@pytest.mark.parametrize("num_qo_heads", [2, 4, 4]) -@pytest.mark.parametrize("num_kv_heads", [2, 1, 4]) +@pytest.mark.parametrize("head_dim", [64, 128]) +@pytest.mark.parametrize("sparsity", [0.9, 0.95]) +@pytest.mark.parametrize("num_qkv_head_pair", [(4, 1), (4, 4)]) @pytest.mark.parametrize("attn_op_name", ["minfer_zigzag", "minfer_stripe", "minfer_dr_stripe"] ) @@ -191,9 +173,7 @@ def test_sparse_attention_kernels( batch_sz: int, head_dim: int, sparsity: float, - ones: bool, - num_qo_heads: int, - num_kv_heads: int, + num_qkv_head_pair: tuple[int, int], attn_op_name: str, ): """ @@ -201,20 +181,21 @@ def test_sparse_attention_kernels( both forward pass and input-gradient w.r.t Q/K/V. """ port = str(random.randint(12000, 20000)) - cfg = SimpleNamespace( batch_size=batch_sz, seq_len=seq_len, head_dim=head_dim, sparsity=sparsity, - ones=ones, - num_qo_heads=num_qo_heads, - num_kv_heads=num_kv_heads, + num_qo_heads=num_qkv_head_pair[0], + num_kv_heads=num_qkv_head_pair[1], ) # derived sizes used by both candidate and reference kernels cfg.v_size = [int((1 - cfg.sparsity) * 0.1 * cfg.seq_len)] * cfg.num_qo_heads cfg.s_size = [int((1 - cfg.sparsity) * 0.2 * cfg.seq_len)] * cfg.num_qo_heads + print(f"=" * 80) + print(f"Testing {attn_op_name} with configuration:\n{cfg}") + print(f"=" * 80) mp.spawn( _run_worker, args=(_WORLD_SIZE, port, cfg, attn_op_name), diff --git a/minference/dist_ops/test/minfer_ring_test_raw.py b/minference/dist_ops/test/minfer_ring_test_raw.py index f12d478..03b5028 100644 --- a/minference/dist_ops/test/minfer_ring_test_raw.py +++ b/minference/dist_ops/test/minfer_ring_test_raw.py @@ -22,7 +22,7 @@ import torch.distributed as dist import torch.multiprocessing as mp -from minference.ops.utils import set_seed, check_correctness_by_row +from minference.ops.utils import set_seed, check_correctness_by_row, check_by_correct_rate from minference.dist_ops.minfer_zigzag import minfer_zigzag_func from minference.dist_ops.minfer_striped import minfer_stripe_func from minference.dist_ops.minfer_dr_striped import minfer_dr_stripe_func @@ -31,7 +31,7 @@ # ------------- constants ------------------------------------------------------ _ATOL = 1e-2 _RTOL = 1e-2 -_WORLD_SIZE = 2 +_WORLD_SIZE = 4 _ATTENTION_IMPLS: dict[str, Callable] = { "minfer_zigzag": minfer_zigzag_func, @@ -163,30 +163,37 @@ def _run_worker( ref_grads = (q_ref.grad, k_ref.grad, v_ref.grad) # ----------------- assertions ---------------------------------------- - check_correctness_by_row( - cfg.seq_len, final_out, out_ref, "forward output", ATOL=_ATOL, RTOL=_RTOL - ) - check_correctness_by_row( - cfg.seq_len, grads[0], 
ref_grads[0], "Q-grad", ATOL=_ATOL, RTOL=_RTOL - ) - check_correctness_by_row( - cfg.seq_len, grads[1], ref_grads[1], "K-grad", - ATOL=_ATOL, RTOL=_RTOL - ) - check_correctness_by_row( - cfg.seq_len, grads[2], ref_grads[2], "V-grad", - ATOL=_ATOL, RTOL=_RTOL - ) - - torch.testing.assert_close( - final_out, out_ref, atol=_ATOL, rtol=_RTOL, msg="forward mismatch" - ) - # for got, ref, name in zip( - # grads, - # ref_grads, - # ("Q-grad", "K-grad", "V-grad"), + # if check_correctness_by_row( + # cfg.seq_len, final_out, out_ref, "forward output", ATOL=_ATOL, RTOL=_RTOL # ): - # torch.testing.assert_close(got, ref, ATOL=_ATOL, RTOL=_RTOL, msg=name) + # check_correctness_by_row( + # cfg.seq_len, grads[0], ref_grads[0], "Q-grad", ATOL=_ATOL, RTOL=_RTOL + # ) + # check_correctness_by_row( + # cfg.seq_len, grads[1], ref_grads[1], "K-grad", + # ATOL=_ATOL, RTOL=_RTOL + # ) + # check_correctness_by_row( + # cfg.seq_len, grads[2], ref_grads[2], "V-grad", + # ATOL=_ATOL, RTOL=_RTOL + # ) + if check_by_correct_rate(final_out, out_ref, ATOL=_ATOL, RTOL=_RTOL): + for got, ref, name in zip( + grads, + ref_grads, + ("Q-grad", "K-grad", "V-grad"), + ): + check_by_correct_rate(got, ref, ATOL=_ATOL, RTOL=_RTOL) + + torch.testing.assert_close( + final_out, out_ref, atol=_ATOL, rtol=_RTOL, msg="forward mismatch" + ) + for got, ref, name in zip( + grads, + ref_grads, + ("Q-grad", "K-grad", "V-grad"), + ): + torch.testing.assert_close(got, ref, ATOL=_ATOL, RTOL=_RTOL, msg=name) dist.destroy_process_group() @@ -233,14 +240,13 @@ def test_sparse_attention_kernels( if __name__ == "__main__": # Run the test with default parameters - test_sparse_attention_kernels( - seq_len=4096, + seq_len=512 * 1024, batch_sz=1, - head_dim=64, + head_dim=128, sparsity=0.9, - ones=True, - num_qo_heads=2, + ones=False, + num_qo_heads=4, num_kv_heads=2, attn_op_name="minfer_zigzag", ) \ No newline at end of file diff --git a/minference/dist_ops/test/moba_ring_test.py b/minference/dist_ops/test/moba_ring_test.py new file mode 100644 index 0000000..f7f0a57 --- /dev/null +++ b/minference/dist_ops/test/moba_ring_test.py @@ -0,0 +1,200 @@ +from __future__ import annotations + +import os +import pytest +import random +from types import SimpleNamespace +from typing import Callable + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +from minference.ops.utils import set_seed, check_correctness_by_row +from minference.dist_ops.xattn_zigzag import xattn_zigzag_func +from minference.ops.xattention_fa import xattn_flash_attn_func + +# ------------- constants ------------------------------------------------------ +_ATOL = 1e-1 +_RTOL = 1e-1 +_WORLD_SIZE = 4 + +# ------------- helpers -------------------------------------------------------- +def _init_process_group(rank: int, world_size: int, port: str) -> None: + """Initialise NCCL backend for the current worker.""" + os.environ.update( + { + "MASTER_ADDR": "127.0.0.1", + "MASTER_PORT": port, + "RANK": str(rank), + "WORLD_SIZE": str(world_size), + "LOCAL_RANK": str(rank % min(world_size, torch.cuda.device_count())), + "LOCAL_WORLD_SIZE": str(min(world_size, torch.cuda.device_count())), + } + ) + dist.init_process_group("nccl", rank=rank, world_size=world_size) + + + +def _run_worker( + rank: int, + world_size: int, + port: str, + cfg: SimpleNamespace, +) -> None: + """Worker function executed in every spawned GPU process.""" + _init_process_group(rank, world_size, port) + + # Short-hand variables + device = torch.device(f"cuda:{rank}") + 
torch.cuda.set_device(device) + dtype = torch.bfloat16 + set_seed(2025 + rank) + + # ----------------- generate identical tensors on every rank -------------- + if rank == 0: + q = torch.randn( + (cfg.batch_size, cfg.seq_len, cfg.num_qo_heads, cfg.head_dim), + dtype=dtype, + device=device, + ) + k = torch.randn( + (cfg.batch_size, cfg.seq_len, cfg.num_kv_heads, cfg.head_dim), + dtype=dtype, + device=device, + ) + v = torch.randn( + (cfg.batch_size, cfg.seq_len, cfg.num_kv_heads, cfg.head_dim), + dtype=dtype, + device=device, + ) + dout = torch.randn( + (cfg.batch_size, cfg.seq_len, cfg.num_qo_heads, cfg.head_dim), + dtype=dtype, + device=device, + ) + else: + # placeholders that will be overwritten by broadcast + shape_q = (cfg.batch_size, cfg.seq_len, cfg.num_qo_heads, cfg.head_dim) + shape_kv = (cfg.batch_size, cfg.seq_len, cfg.num_kv_heads, cfg.head_dim) + q = torch.empty(shape_q, device=device, dtype=dtype) + k = torch.empty(shape_kv, device=device, dtype=dtype) + v = torch.empty(shape_kv, device=device, dtype=dtype) + dout = torch.empty(shape_q, device=device, dtype=dtype) + + # Make every rank see the same data + dist.broadcast(q, src=0) + dist.broadcast(k, src=0) + dist.broadcast(v, src=0) + dist.broadcast(dout, src=0) + + # ----------------- slice local context ----------------------------------- + local_ctx = cfg.seq_len // world_size + sl = slice(rank * local_ctx, (rank + 1) * local_ctx) + + q_local = q[:, sl].clone().detach().requires_grad_() + k_local = k[:, sl].clone().detach().requires_grad_() + v_local = v[:, sl].clone().detach().requires_grad_() + dout_local = dout[:, sl].clone() + + # ----------------- forward / backward on the candidate kernel ------------ + out_local = xattn_zigzag_func( + q_local, k_local, v_local, + layer_idx=0, + xattn_params=cfg.xattn_params, + granularity=128, + ) + torch.autograd.backward(out_local, dout_local) + + # ----------------- gather outputs & grads for reference comparison ------- + out_gather = [torch.empty_like(out_local) for _ in range(world_size)] + dist.all_gather(out_gather, out_local) + final_out = torch.cat(out_gather, dim=1) + + grads = [] + for g in (q_local.grad, k_local.grad, v_local.grad): + tmp = [torch.empty_like(g) for _ in range(world_size)] + dist.all_gather(tmp, g) + grads.append(torch.cat(tmp, dim=1)) + + # --------------------------------------- + if rank == 0: + q_ref = q.detach().clone().requires_grad_() + k_ref = k.detach().clone().requires_grad_() + v_ref = v.detach().clone().requires_grad_() + + single_machine_params = cfg.xattn_params.copy() + single_machine_params["chunk_size"] = cfg.seq_len // _WORLD_SIZE + out_ref = xattn_flash_attn_func( + q_ref, k_ref, v_ref, + head_indices=list(range(cfg.num_qo_heads)), + xattn_params=single_machine_params, + granularity=128, + ) + torch.autograd.backward(out_ref, dout) + ref_grads = (q_ref.grad, k_ref.grad, v_ref.grad) + + torch.testing.assert_close( + final_out, out_ref, atol=_ATOL, rtol=_RTOL, msg="forward output mismatch" + ) + for got, ref, name in zip( + grads, + ref_grads, + ("Q-grad", "K-grad", "V-grad"), + ): + torch.testing.assert_close(got, ref, atol=_ATOL, rtol=_RTOL, msg=f"{name} mismatch") + + dist.destroy_process_group() + + +# ------------- pytest entry-point -------------------------------------------- +@pytest.mark.skipif(torch.cuda.device_count() < _WORLD_SIZE, reason="Not enough GPUs") +@pytest.mark.parametrize("seq_len", [131072, 262144, 524288]) +@pytest.mark.parametrize("head_dim", [64, 128]) +@pytest.mark.parametrize("num_qkv_head_pair", [(4, 1), (4, 
4)]) +@pytest.mark.parametrize("stride", [16, 32]) +@pytest.mark.parametrize("threshold", [0.9, 1.]) +def test_xattention_kernels( + seq_len: int, + head_dim: int, + num_qkv_head_pair: tuple[int, int], + stride: int, + threshold: float, +): + """ + Compare every sparse kernel against the dense Flash-Attention reference on + both forward pass and input-gradient w.r.t Q/K/V. + """ + + port = str(random.randint(12000, 20000)) + xattn_params = { + "stride": stride, + "norm": 1, + "softmax": True, + "threshold": threshold, + "select_mode": "inverse", + "use_triton": True, + "causal": True, + "kdb": 1, + "keep_sink": False, + "keep_recent": False + } + cfg = SimpleNamespace( + batch_size=1, + seq_len=seq_len, + head_dim=head_dim, + num_qo_heads=num_qkv_head_pair[0], + num_kv_heads=num_qkv_head_pair[1], + xattn_params=xattn_params, + ) + + print(f"=" * 80) + print(f"Testing XAttention (w. Zigzag) with configuration:\n{cfg}") + print(f"=" * 80) + mp.spawn( + _run_worker, + args=(_WORLD_SIZE, port, cfg), + nprocs=_WORLD_SIZE, + join=True, + ) + diff --git a/minference/dist_ops/test/moba_ring_test_raw.py b/minference/dist_ops/test/moba_ring_test_raw.py new file mode 100644 index 0000000..1f7623a --- /dev/null +++ b/minference/dist_ops/test/moba_ring_test_raw.py @@ -0,0 +1,207 @@ +from __future__ import annotations + +import os +import random +from typing import Callable, Optional +from types import SimpleNamespace + + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from torch import Tensor + +from minference.ops.utils import set_seed, check_correctness_by_row +from minference.ops.moba import moba_attn_varlen, moba_layer, moba_attn_func, moba_attn_varlen_naive +from minference.dist_ops.moba_zigzag import moba_zigzag_func + +# ------------- constants ------------------------------------------------------ +_ATOL = 1e-2 +_RTOL = 1e-2 +_WORLD_SIZE = 4 + +# ------------- helpers -------------------------------------------------------- +def _init_process_group(rank: int, world_size: int, port: str) -> None: + """Initialise NCCL backend for the current worker.""" + os.environ.update( + { + "MASTER_ADDR": "127.0.0.1", + "MASTER_PORT": port, + "RANK": str(rank), + "WORLD_SIZE": str(world_size), + "LOCAL_RANK": str(rank % min(world_size, torch.cuda.device_count())), + "LOCAL_WORLD_SIZE": str(min(world_size, torch.cuda.device_count())), + } + ) + dist.init_process_group("nccl", rank=rank, world_size=world_size) + + +def _run_worker( + rank: int, + world_size: int, + port: str, + cfg: SimpleNamespace, +) -> None: + """Worker function executed in every spawned GPU process.""" + _init_process_group(rank, world_size, port) + + # Short-hand variables + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + dtype = torch.bfloat16 + set_seed(2025 + rank) + + # [0, BLK_SZ, 2 * BLK_SZ, 3 * BLK_SZ, ..., seq_len - 1] + rank = dist.get_rank() + world_size = dist.get_world_size() + + # ----------------- generate identical tensors on every rank -------------- + if rank == 0: + rand_or_one = ( + torch.randn if not cfg.ones else lambda s, **k: torch.ones(*s, **k) + ) + q = rand_or_one( + (cfg.batch_size, cfg.seq_len, cfg.num_qo_heads, cfg.head_dim), + dtype=dtype, + device=device, + ) + k = rand_or_one( + (cfg.batch_size, cfg.seq_len, cfg.num_kv_heads, cfg.head_dim), + dtype=dtype, + device=device, + ) + v = rand_or_one( + (cfg.batch_size, cfg.seq_len, cfg.num_kv_heads, cfg.head_dim), + dtype=dtype, + device=device, + ) + dout = rand_or_one( + (cfg.batch_size, cfg.seq_len, 
cfg.num_qo_heads, cfg.head_dim), + dtype=dtype, + device=device, + ) + else: + # placeholders that will be overwritten by broadcast + shape_q = (cfg.batch_size, cfg.seq_len, cfg.num_qo_heads, cfg.head_dim) + shape_kv = (cfg.batch_size, cfg.seq_len, cfg.num_kv_heads, cfg.head_dim) + q = torch.empty(shape_q, device=device, dtype=dtype) + k = torch.empty(shape_kv, device=device, dtype=dtype) + v = torch.empty(shape_kv, device=device, dtype=dtype) + dout = torch.empty(shape_q, device=device, dtype=dtype) + + # Make every rank see the same data + dist.broadcast(q, src=0) + dist.broadcast(k, src=0) + dist.broadcast(v, src=0) + dist.broadcast(dout, src=0) + + # ----------------- slice local context ----------------------------------- + local_ctx = cfg.seq_len // world_size + sl = slice(rank * local_ctx, (rank + 1) * local_ctx) + + q_local = q[:, sl].clone().detach().requires_grad_() + k_local = k[:, sl].clone().detach().requires_grad_() + v_local = v[:, sl].clone().detach().requires_grad_() + dout_local = dout[:, sl].clone() + + # ----------------- forward / backward on the candidate kernel ------------ + out_local = moba_zigzag_func( + q_local, k_local, v_local, + layer_idx=0, + global_seq_len=cfg.seq_len, + moba_chunk_size=cfg.moba_chunk_size, + moba_topk=cfg.moba_topk, + ) + torch.autograd.backward(out_local, dout_local) + + # ----------------- gather outputs & grads for reference comparison ------- + out_gather = [torch.empty_like(out_local) for _ in range(world_size)] + dist.all_gather(out_gather, out_local) + final_out = torch.cat(out_gather, dim=1) + + grads = [] + for g in (q_local.grad, k_local.grad, v_local.grad): + tmp = [torch.empty_like(g) for _ in range(world_size)] + dist.all_gather(tmp, g) + grads.append(torch.cat(tmp, dim=1)) + + # --------------------------------------- + if rank == 0: + q_ref = q.detach().clone().requires_grad_() + k_ref = k.detach().clone().requires_grad_() + v_ref = v.detach().clone().requires_grad_() + + out_ref = moba_attn_func( + q_ref, k_ref, v_ref, + global_seq_len=cfg.seq_len, + moba_chunk_size=cfg.moba_chunk_size, + moba_topk=cfg.moba_topk, + ) + torch.autograd.backward(out_ref, dout) + ref_grads = (q_ref.grad, k_ref.grad, v_ref.grad) + + # ----------------- assertions ---------------------------------------- + if check_correctness_by_row( + cfg.seq_len, final_out, out_ref, "forward output", ATOL=_ATOL, RTOL=_RTOL + ): + check_correctness_by_row( + cfg.seq_len, grads[0], ref_grads[0], "Q-grad", ATOL=_ATOL, RTOL=_RTOL + ) + check_correctness_by_row( + cfg.seq_len, grads[1], ref_grads[1], "K-grad", + ATOL=_ATOL, RTOL=_RTOL + ) + check_correctness_by_row( + cfg.seq_len, grads[2], ref_grads[2], "V-grad", + ATOL=_ATOL, RTOL=_RTOL + ) + dist.destroy_process_group() + + +# ------------- pytest entry-point -------------------------------------------- +def test_moba_kernels( + seq_len: int = 4096, + batch_size: int = 1, + head_dim: int = 64, + ones: bool = True, + num_qkv_head_pair: tuple[int, int]=(2, 2), + moba_chunk_size: int = 512, + moba_topk: int = 8, +): + """ + Compare every sparse kernel against the dense Flash-Attention reference on + both forward pass and input-gradient w.r.t Q/K/V. + """ + port = str(random.randint(12000, 20000)) + cfg = SimpleNamespace( + batch_size=batch_size, + seq_len=seq_len, + head_dim=head_dim, + ones=ones, + num_qo_heads=num_qkv_head_pair[0], + num_kv_heads=num_qkv_head_pair[1], + moba_chunk_size=moba_chunk_size, + moba_topk=moba_topk, + ) + + print(f"=" * 80) + print(f"Testing MoBA (w. 
Zigzag) with configuration:\n{cfg}") + print(f"=" * 80) + mp.spawn( + _run_worker, + args=(_WORLD_SIZE, port, cfg), + nprocs=_WORLD_SIZE, + join=True, + ) + +if __name__ == "__main__": + test_moba_kernels( + seq_len=16384, + batch_size=1, + head_dim=64, + ones=False, + num_qkv_head_pair=(1, 1), + + moba_chunk_size=512, + moba_topk=4, + ) \ No newline at end of file diff --git a/minference/dist_ops/test/xattn_ring_test.py b/minference/dist_ops/test/xattn_ring_test.py new file mode 100644 index 0000000..3ac5359 --- /dev/null +++ b/minference/dist_ops/test/xattn_ring_test.py @@ -0,0 +1,200 @@ +from __future__ import annotations + +import os +import pytest +import random +from types import SimpleNamespace +from typing import Callable + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +from minference.ops.utils import set_seed, check_correctness_by_row +from minference.dist_ops.xattn_zigzag import xattn_zigzag_func +from minference.ops.xattention_fa import xattn_flash_attn_func + +# ------------- constants ------------------------------------------------------ +_ATOL = 1e-1 +_RTOL = 1e-1 +_WORLD_SIZE = 4 + +# ------------- helpers -------------------------------------------------------- +def _init_process_group(rank: int, world_size: int, port: str) -> None: + """Initialise NCCL backend for the current worker.""" + os.environ.update( + { + "MASTER_ADDR": "127.0.0.1", + "MASTER_PORT": port, + "RANK": str(rank), + "WORLD_SIZE": str(world_size), + "LOCAL_RANK": str(rank % min(world_size, torch.cuda.device_count())), + "LOCAL_WORLD_SIZE": str(min(world_size, torch.cuda.device_count())), + } + ) + dist.init_process_group("nccl", rank=rank, world_size=world_size) + + + +def _run_worker( + rank: int, + world_size: int, + port: str, + cfg: SimpleNamespace, +) -> None: + """Worker function executed in every spawned GPU process.""" + _init_process_group(rank, world_size, port) + + # Short-hand variables + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + dtype = torch.bfloat16 + set_seed(2025 + rank) + + # ----------------- generate identical tensors on every rank -------------- + if rank == 0: + q = torch.randn( + (cfg.batch_size, cfg.seq_len, cfg.num_qo_heads, cfg.head_dim), + dtype=dtype, + device=device, + ) + k = torch.randn( + (cfg.batch_size, cfg.seq_len, cfg.num_kv_heads, cfg.head_dim), + dtype=dtype, + device=device, + ) + v = torch.randn( + (cfg.batch_size, cfg.seq_len, cfg.num_kv_heads, cfg.head_dim), + dtype=dtype, + device=device, + ) + dout = torch.randn( + (cfg.batch_size, cfg.seq_len, cfg.num_qo_heads, cfg.head_dim), + dtype=dtype, + device=device, + ) + else: + # placeholders that will be overwritten by broadcast + shape_q = (cfg.batch_size, cfg.seq_len, cfg.num_qo_heads, cfg.head_dim) + shape_kv = (cfg.batch_size, cfg.seq_len, cfg.num_kv_heads, cfg.head_dim) + q = torch.empty(shape_q, device=device, dtype=dtype) + k = torch.empty(shape_kv, device=device, dtype=dtype) + v = torch.empty(shape_kv, device=device, dtype=dtype) + dout = torch.empty(shape_q, device=device, dtype=dtype) + + # Make every rank see the same data + dist.broadcast(q, src=0) + dist.broadcast(k, src=0) + dist.broadcast(v, src=0) + dist.broadcast(dout, src=0) + + # ----------------- slice local context ----------------------------------- + local_ctx = cfg.seq_len // world_size + sl = slice(rank * local_ctx, (rank + 1) * local_ctx) + + q_local = q[:, sl].clone().detach().requires_grad_() + k_local = k[:, sl].clone().detach().requires_grad_() + v_local = v[:, 
sl].clone().detach().requires_grad_() + dout_local = dout[:, sl].clone() + + # ----------------- forward / backward on the candidate kernel ------------ + out_local = xattn_zigzag_func( + q_local, k_local, v_local, + layer_idx=0, + xattn_params=cfg.xattn_params, + granularity=128, + ) + torch.autograd.backward(out_local, dout_local) + + # ----------------- gather outputs & grads for reference comparison ------- + out_gather = [torch.empty_like(out_local) for _ in range(world_size)] + dist.all_gather(out_gather, out_local) + final_out = torch.cat(out_gather, dim=1) + + grads = [] + for g in (q_local.grad, k_local.grad, v_local.grad): + tmp = [torch.empty_like(g) for _ in range(world_size)] + dist.all_gather(tmp, g) + grads.append(torch.cat(tmp, dim=1)) + + # --------------------------------------- + if rank == 0: + q_ref = q.detach().clone().requires_grad_() + k_ref = k.detach().clone().requires_grad_() + v_ref = v.detach().clone().requires_grad_() + + single_machine_params = cfg.xattn_params.copy() + single_machine_params["chunk_size"] = cfg.seq_len // _WORLD_SIZE + out_ref = xattn_flash_attn_func( + q_ref, k_ref, v_ref, + head_indices=list(range(cfg.num_qo_heads)), + xattn_params=single_machine_params, + granularity=128, + ) + torch.autograd.backward(out_ref, dout) + ref_grads = (q_ref.grad, k_ref.grad, v_ref.grad) + + torch.testing.assert_close( + final_out, out_ref, atol=_ATOL, rtol=_RTOL, msg="forward output mismatch" + ) + for got, ref, name in zip( + grads, + ref_grads, + ("Q-grad", "K-grad", "V-grad"), + ): + torch.testing.assert_close(got, ref, atol=_ATOL, rtol=_RTOL, msg=f"{name} mismatch") + + dist.destroy_process_group() + + +# ------------- pytest entry-point -------------------------------------------- +@pytest.mark.skipif(torch.cuda.device_count() < _WORLD_SIZE, reason="Not enough GPUs") +@pytest.mark.parametrize("seq_len", [131072, 262144, 524288]) +@pytest.mark.parametrize("head_dim", [64, 128]) +@pytest.mark.parametrize("num_qkv_head_pair", [(4, 1), (4, 4)]) +@pytest.mark.parametrize("stride", [16, 32]) +@pytest.mark.parametrize("threshold", [0.9, 0.95]) +def test_xattention_kernels( + seq_len: int, + head_dim: int, + num_qkv_head_pair: tuple[int, int], + stride: int, + threshold: float, +): + """ + Compare every sparse kernel against the dense Flash-Attention reference on + both forward pass and input-gradient w.r.t Q/K/V. + """ + + port = str(random.randint(12000, 20000)) + xattn_params = { + "stride": stride, + "norm": 1, + "softmax": True, + "threshold": threshold, + "select_mode": "inverse", + "use_triton": True, + "causal": True, + "kdb": 1, + "keep_sink": False, + "keep_recent": False + } + cfg = SimpleNamespace( + batch_size=1, + seq_len=seq_len, + head_dim=head_dim, + num_qo_heads=num_qkv_head_pair[0], + num_kv_heads=num_qkv_head_pair[1], + xattn_params=xattn_params, + ) + + print(f"=" * 80) + print(f"Testing XAttention (w. 
Zigzag) with configuration:\n{cfg}") + print(f"=" * 80) + mp.spawn( + _run_worker, + args=(_WORLD_SIZE, port, cfg), + nprocs=_WORLD_SIZE, + join=True, + ) + diff --git a/minference/dist_ops/test/xattn_ring_tes_raw.py b/minference/dist_ops/test/xattn_ring_test_raw.py similarity index 87% rename from minference/dist_ops/test/xattn_ring_tes_raw.py rename to minference/dist_ops/test/xattn_ring_test_raw.py index 30f485e..473960f 100644 --- a/minference/dist_ops/test/xattn_ring_tes_raw.py +++ b/minference/dist_ops/test/xattn_ring_test_raw.py @@ -29,7 +29,7 @@ # ------------- constants ------------------------------------------------------ _ATOL = 1e-1 _RTOL = 1e-1 -_WORLD_SIZE = 2 +_WORLD_SIZE = 4 # ------------- helpers -------------------------------------------------------- def _init_process_group(rank: int, world_size: int, port: str) -> None: @@ -119,7 +119,6 @@ def _run_worker( xattn_params=cfg.xattn_params, granularity=128, ) - print(f"Rank {rank} | out_local shape: {out_local.shape}") torch.autograd.backward(out_local, dout_local) # ----------------- gather outputs & grads for reference comparison ------- @@ -151,30 +150,20 @@ def _run_worker( ref_grads = (q_ref.grad, k_ref.grad, v_ref.grad) # ----------------- assertions ---------------------------------------- - check_correctness_by_row( + if check_correctness_by_row( cfg.seq_len, final_out, out_ref, "forward output", ATOL=_ATOL, RTOL=_RTOL - ) - check_correctness_by_row( - cfg.seq_len, grads[0], ref_grads[0], "Q-grad", ATOL=_ATOL, RTOL=_RTOL - ) - check_correctness_by_row( - cfg.seq_len, grads[1], ref_grads[1], "K-grad", - ATOL=_ATOL, RTOL=_RTOL - ) - check_correctness_by_row( - cfg.seq_len, grads[2], ref_grads[2], "V-grad", - ATOL=_ATOL, RTOL=_RTOL - ) - - # torch.testing.assert_close( - # final_out, out_ref, atol=_ATOL, rtol=_RTOL, msg="forward mismatch" - # ) - # for got, ref, name in zip( - # grads, - # ref_grads, - # ("Q-grad", "K-grad", "V-grad"), - # ): - # torch.testing.assert_close(got, ref, ATOL=_ATOL, RTOL=_RTOL, msg=name) + ): + check_correctness_by_row( + cfg.seq_len, grads[0], ref_grads[0], "Q-grad", ATOL=_ATOL, RTOL=_RTOL + ) + check_correctness_by_row( + cfg.seq_len, grads[1], ref_grads[1], "K-grad", + ATOL=_ATOL, RTOL=_RTOL + ) + check_correctness_by_row( + cfg.seq_len, grads[2], ref_grads[2], "V-grad", + ATOL=_ATOL, RTOL=_RTOL + ) dist.destroy_process_group() @@ -218,6 +207,9 @@ def test_xattention_kernels( xattn_params=xattn_params, ) + print(f"=" * 80) + print(f"Testing XAttention (w. 
Zigzag) with configuration:\n{cfg}") + print(f"=" * 80) mp.spawn( _run_worker, args=(_WORLD_SIZE, port, cfg), @@ -227,15 +219,13 @@ def test_xattention_kernels( if __name__ == "__main__": # Run the test with default parameters - test_xattention_kernels( - seq_len=16384, + seq_len=512 * 1024, batch_sz=1, head_dim=64, ones=False, - num_qo_heads=2, - num_kv_heads=2, - + num_qo_heads=4, + num_kv_heads=1, stride=16, - threshold=0.9, + threshold=0.95, ) \ No newline at end of file diff --git a/minference/dist_ops/xattn_zigzag.py b/minference/dist_ops/xattn_zigzag.py index bb6a80e..47a89b6 100644 --- a/minference/dist_ops/xattn_zigzag.py +++ b/minference/dist_ops/xattn_zigzag.py @@ -206,7 +206,7 @@ def xattn_zigzag_forward( step_idx=step, ) - out, lse = update_out_and_lse(out, lse, block_out, block_lse) + out, lse = update_out_and_lse(out, lse, block_out, block_lse, use_triton_kernel=False) if step + 1 != comm.world_size: comm.wait() k, v = next_k, next_v diff --git a/minference/ops/moba.py b/minference/ops/moba.py index fef02e5..3d77394 100644 --- a/minference/ops/moba.py +++ b/minference/ops/moba.py @@ -141,7 +141,7 @@ def forward( ctx.softmax_scale = softmax_scale = q.shape[-1] ** (-0.5) # self attn - _, _, _, _, self_attn_out_sh, self_attn_lse_hs, _, _ = ( + self_attn_out_sh, self_attn_lse_hs, _, _ = ( _flash_attn_varlen_forward( q=q, k=k, @@ -156,7 +156,7 @@ def forward( ) ) - _, _, _, _, moba_attn_out, moba_attn_lse_hs, _, _ = _flash_attn_varlen_forward( + moba_attn_out, moba_attn_lse_hs, _, _ = _flash_attn_varlen_forward( q=moba_q, k=moba_kv[:, 0], v=moba_kv[:, 1], @@ -267,16 +267,19 @@ def backward(ctx, d_output, *args): d_output = d_output.contiguous() - dq, dk, dv, _ = _flash_attn_varlen_backward( + dq = torch.zeros_like(q, dtype=q.dtype, device=q.device) + dk = torch.zeros_like(k, dtype=k.dtype, device=k.device) + dv = torch.zeros_like(v, dtype=v.dtype, device=v.device) + _flash_attn_varlen_backward( dout=d_output, q=q, k=k, v=v, out=output, softmax_lse=mixed_attn_vlse_sh.t().contiguous(), - dq=None, - dk=None, - dv=None, + dq=dq, + dk=dk, + dv=dv, cu_seqlens_q=self_attn_cu_seqlen, cu_seqlens_k=self_attn_cu_seqlen, max_seqlen_q=max_seqlen, @@ -284,7 +287,8 @@ def backward(ctx, d_output, *args): softmax_scale=softmax_scale, causal=True, dropout_p=0.0, - window_size=(-1, -1), + window_size_left=-1, + window_size_right=-1, softcap=0.0, alibi_slopes=None, deterministic=True, @@ -302,16 +306,19 @@ def backward(ctx, d_output, *args): mixed_attn_vlse_sh.view(-1).index_select(0, moba_q_sh_indices).view(1, -1) ) - dmq, dmk, dmv, _ = _flash_attn_varlen_backward( + dmq = torch.zeros_like(moba_q, dtype=moba_q.dtype, device=moba_q.device) + dmk = torch.zeros_like(moba_kv[:, 0], dtype=moba_kv.dtype, device=moba_kv.device) + dmv = torch.zeros_like(moba_kv[:, 1], dtype=moba_kv.dtype, device=moba_kv.device) + _flash_attn_varlen_backward( dout=d_moba_output, q=moba_q, k=moba_kv[:, 0], v=moba_kv[:, 1], out=moba_output, softmax_lse=mixed_attn_vlse, - dq=None, - dk=None, - dv=None, + dq=dmq, + dk=dmk, + dv=dmv, cu_seqlens_q=moba_cu_seqlen_q, cu_seqlens_k=moba_cu_seqlen_kv, max_seqlen_q=max_seqlen, @@ -319,7 +326,8 @@ def backward(ctx, d_output, *args): softmax_scale=softmax_scale, causal=False, dropout_p=0.0, - window_size=(-1, -1), + window_size_left=-1, + window_size_right=-1, softcap=0.0, alibi_slopes=None, deterministic=True, @@ -358,7 +366,7 @@ def moba_attn_varlen( Returns: attn_output (torch.Tensor): [seqlen, head, head_dim] """ - print(f"moba_attn_varlen | cu_seqlens: {cu_seqlens}, max_seqlen: 
{max_seqlen}, moba_chunk_size: {moba_chunk_size}, moba_topk: {moba_topk}, return_lse: {return_lse}") + # --------------------------------------------------------------------------------------------- kv = torch.stack((k, v), dim=1) # stack along a new dimension -> [S, 2, H, D] @@ -371,7 +379,6 @@ def moba_attn_varlen( cu_chunk, filtered_chunk_indices, num_filtered_chunk, - filtered_chunk_indices, chunk_to_batch, ) = calc_chunks(cu_seqlens, moba_chunk_size) @@ -489,7 +496,6 @@ def moba_attn_varlen( ).to(torch.int32) # ----------------------------------------------- - print(f"filtered_kv shape: {filtered_kv.shape}") moba_kv = rearrange(filtered_kv, "s x h d -> h s x d") # here `x` only stands for a dimension (stack dimension for KV) moba_kv = moba_kv.split(moba_chunk_size, dim=1) @@ -533,7 +539,36 @@ def moba_attn_varlen( ) +def moba_attn_func( + q: torch.Tensor, # [batch, q_len, q_heads, head_dim] + k: torch.Tensor, + v: torch.Tensor, + global_seq_len: int, + moba_chunk_size: int, + moba_topk: int, + **kwargs, +): + batch_size = q.shape[0] + cu_seqlens = torch.cumsum( + torch.tensor([0] + [global_seq_len] * batch_size, device=q.device), + dim=0, + dtype=torch.int32, + ) + + q_3d, k_3d, v_3d = \ + q.reshape(-1, q.shape[2], q.shape[3]), \ + k.reshape(-1, k.shape[2], k.shape[3]), \ + v.reshape(-1, v.shape[2], v.shape[3]) + # output: [batch_size, global_seq_len, q_heads, head_dim] + return moba_attn_varlen( + q_3d, k_3d, v_3d, + cu_seqlens, + global_seq_len, + moba_chunk_size, + moba_topk, + ).view(q.shape) + def moba_layer( moba_impl: Callable, diff --git a/minference/ops/op_utils/moba_utils.py b/minference/ops/op_utils/moba_utils.py index 99d73f6..d829a83 100644 --- a/minference/ops/op_utils/moba_utils.py +++ b/minference/ops/op_utils/moba_utils.py @@ -6,6 +6,13 @@ from functools import lru_cache from dataclasses import dataclass +def tensor_4d_to_3d(tensor: torch.Tensor) -> torch.Tensor: + """Convert a 4D tensor to a 3D tensor by collapsing the first two dimensions.""" + if tensor.ndim != 4: + raise ValueError("Input tensor must be 4D.") + return tensor.reshape(tensor.shape[0] * tensor.shape[1], tensor.shape[2], tensor.shape[3]) + + @dataclass class MoBAConfig: moba_chunk_size: int @@ -254,6 +261,11 @@ def compute_moba_gate( moba_chunk_size: int, moba_topk: int, ): + if len(q.shape) == 4: + q, k, v = \ + tensor_4d_to_3d(q), \ + tensor_4d_to_3d(k), \ + tensor_4d_to_3d(v) seq_offset: int = seq_offset.detach().cpu().item() seqlen_block, num_head, head_dim = q.shape _, k_num_head, _ = k.shape diff --git a/minference/ops/utils.py b/minference/ops/utils.py index 5b5a0a3..2c43f20 100644 --- a/minference/ops/utils.py +++ b/minference/ops/utils.py @@ -32,5 +32,44 @@ def check_correctness_by_row( max_diff = torch.max(torch.abs(tensor_var_row - ref_tensor_var_row)) print(f"Maximal difference: {max_diff.item()}\n") + return False else: - print(f"All {tensor_name} values are correct within the specified tolerance.") \ No newline at end of file + print(f"All {tensor_name} values are correct within the specified tolerance.") + return True + +def check_correct_rate( + tensor_var, + ref_tensor_var, + ATOL=1e-2, + RTOL=1e-2, +): + assert len(tensor_var.shape) == 4, "Input tensor must be 3D (B, N, H, D)" + assert tensor_var.shape == ref_tensor_var.shape, ( + "Input and reference tensors must have the same shape" + ) + bsz, seq_len, num_heads, _ = tensor_var.shape + + # Boolean mask of element-wise closeness + elem_close = torch.isclose(tensor_var, ref_tensor_var, atol=ATOL, rtol=RTOL) + + # A row “matches” only 
if *all* its D elements are close + row_matches = elem_close.all(dim=-1) # shape (B, N, H) + + # Count rows that do *not* match + num_mismatching = (~row_matches).sum().item() + num_mismatching_prop = num_mismatching / (bsz * seq_len * num_heads) + return 1 - num_mismatching_prop + +def check_by_correct_rate( + tensor_var, + ref_tensor_var, + ATOL=1e-2, + RTOL=1e-2, + threshold=0.99 +): + """ + Check if the tensor_var is correct by comparing it with ref_tensor_var. + Returns True if the correctness rate is above 0.99, otherwise False. + """ + correctness_rate = check_correct_rate(tensor_var, ref_tensor_var, ATOL, RTOL) + return correctness_rate >= threshold \ No newline at end of file diff --git a/minference/ops/xattention_fa.py b/minference/ops/xattention_fa.py index 28790a2..2383263 100644 --- a/minference/ops/xattention_fa.py +++ b/minference/ops/xattention_fa.py @@ -6,7 +6,7 @@ from typing import List, Tuple, Dict, Any from minference.ops.pit_sparse_flash_attention_v3 import block_attn_fwd, block_attn_bwd -from .op_utils.xattn_utils import ( +from minference.ops.op_utils.xattn_utils import ( LN2, find_blocks_chunked, flat_group_gemm_fuse_reshape, softmax_fuse_block_sum ) @@ -234,3 +234,136 @@ def xattn_flash_attn_func( return_attn_probs, deterministic, ) + + + +if __name__ == "__main__": + import argparse + from flash_attn import flash_attn_func + from minference.ops.utils import set_seed + + + parser = argparse.ArgumentParser(description="XAttn Test") + parser.add_argument("--use_ones", action="store_true", help="Use ones for q, k, v") + parser.add_argument("--enable_sparse", action="store_true", help="Enable Sparse XAttenion") + parser.add_argument("--test_backward", action="store_true", help="Test backward pass") + parser.add_argument("--seq_len", type=int, default=16384, help="Sequence length") + args = parser.parse_args() + + ATOL, RTOL = 1e-2, 1e-2 + # dtype = torch.bfloat16 + dtype = torch.float16 + device = torch.device(f"cuda:0") + torch.cuda.set_device(device) + set_seed(2025) + + batch_size, seq_len, num_q_heads, head_dim = 1, args.seq_len, 8, 128 + num_kv_heads = 4 + head_indices = list(range(num_q_heads)) + + granularity = 128 + xattn_params = { + "stride": 16, + "norm": 1, + "softmax": True, + "threshold": 0.9 if args.enable_sparse else 1, + "chunk_size": 16384, + "select_mode": "inverse", + "use_triton": True, + "causal": True, + "kdb": 1, + "keep_sink": False, + "keep_recent": False + } + + if args.use_ones: + q = torch.ones((batch_size, seq_len, num_q_heads, head_dim), device=device, dtype=dtype, requires_grad=args.test_backward) + k = torch.ones((batch_size, seq_len, num_kv_heads, head_dim), device=device, dtype=dtype, requires_grad=args.test_backward) + v = torch.ones((batch_size, seq_len, num_kv_heads, head_dim), device=device, dtype=dtype, requires_grad=args.test_backward) + else: + q = torch.randn((batch_size, seq_len, num_q_heads, head_dim), device=device, dtype=dtype, requires_grad=args.test_backward) + k = torch.randn((batch_size, seq_len, num_kv_heads, head_dim), device=device, dtype=dtype, requires_grad=args.test_backward) + v = torch.randn((batch_size, seq_len, num_kv_heads, head_dim), device=device, dtype=dtype, requires_grad=args.test_backward) + + # Clone inputs for reference implementation to ensure separate gradient computation + if args.test_backward: + q_ref = q.clone().detach().requires_grad_(True) + k_ref = k.clone().detach().requires_grad_(True) + v_ref = v.clone().detach().requires_grad_(True) + else: + q_ref, k_ref, v_ref = q, k, v + + out = 
xattn_flash_attn_func( + q, k, v, + head_indices, + xattn_params, + granularity=granularity, + ) + print(f"out shape: {out.shape}") + + ref_out = flash_attn_func( + q_ref, k_ref, v_ref, + causal=True, + softmax_scale=head_dim ** (-0.5) + ) + + + # Compare out and ref_out + if not torch.allclose(out, ref_out, atol=ATOL, rtol=RTOL): + num_blocks = seq_len // granularity + for i in range(num_blocks): + start = i * granularity + end = (i + 1) * granularity + out_chunk = out[:, start:end, :, :] + ref_out_chunk = ref_out[:, start:end, :, :] + + print('-' * 60) + if not torch.allclose(out_chunk, ref_out_chunk, atol=ATOL, rtol=RTOL): + print(f"Forward Output mismatch at chunk {i}:") + print(f"Forward out_chunk: {out_chunk}") + print(f"Forward ref_out_chunk: {ref_out_chunk}") + else: + print(f"Forward Output match at chunk {i}") + else: + print("Forward Output match") + + + # Backward pass testing + if args.test_backward: + print("\nTesting backward pass...") + + # Create gradient for backward pass + grad_output = torch.randn_like(out) + grad_output_ref = grad_output.clone() + + # Backward pass for custom implementation + out.backward(grad_output) + + # Backward pass for reference implementation + ref_out.backward(grad_output_ref) + + # Compare gradients + print("\nGradient comparison:") + + # Compare q gradients + q_grad_match = torch.allclose(q.grad, q_ref.grad, atol=ATOL, rtol=RTOL) + print(f"q grad match: {q_grad_match}") + if not q_grad_match: + q_diff = (q.grad - q_ref.grad).abs() + print(f"q grad max diff: {q_diff.max().item()}, mean diff: {q_diff.mean().item()}") + + # Compare k gradients + k_grad_match = torch.allclose(k.grad, k_ref.grad, atol=ATOL, rtol=RTOL) + print(f"k grad match: {k_grad_match}") + if not k_grad_match: + k_diff = (k.grad - k_ref.grad).abs() + print(f"k grad max diff: {k_diff.max().item()}, mean diff: {k_diff.mean().item()}") + + # Compare v gradients + v_grad_match = torch.allclose(v.grad, v_ref.grad, atol=ATOL, rtol=RTOL) + print(f"v grad match: {v_grad_match}") + if not v_grad_match: + v_diff = (v.grad - v_ref.grad).abs() + print(f"v grad max diff: {v_diff.max().item()}, mean diff: {v_diff.mean().item()}") + + print(f"\nOverall gradient match: {q_grad_match and k_grad_match and v_grad_match}") diff --git a/mtraining/attn_funcs/moba_func.py b/mtraining/attn_funcs/moba_func.py index 9b68bc5..d89e9e9 100644 --- a/mtraining/attn_funcs/moba_func.py +++ b/mtraining/attn_funcs/moba_func.py @@ -10,8 +10,8 @@ from nnscaler.runtime.device import DeviceGroup from nnscaler.graph.parser.register import register_op +from minference.ops.moba import moba_attn_func from minference.ops.op_utils.moba_utils import MoBAConfig -from minference.ops.moba import moba_attn_varlen, moba_layer from minference.dist_ops.moba_zigzag import moba_zigzag_func def load_moba_config(moba_config_dict: Dict[str, Any]): @@ -30,16 +30,18 @@ def moba_attention_forward( softcap: Optional[float] = None, **kwargs, ) -> Tuple[torch.Tensor, None]: + seq_len = query.shape[2] moba_topk, moba_chunk_size = module.moba_topk, module.moba_chunk_size implementation = module.implementation if implementation == "default": return wrapped_moba_func( query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), + seq_len, moba_topk, moba_chunk_size, attention_mask, dropout, scaling, sliding_window, softcap, **kwargs ), None else: - seq_len = query.shape[2] + layer_idx = module.layer_idx return wrapped_moba_zigzag_func( query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), @@ -52,19 +54,14 @@ def 
moba_attention_forward( # ------------------------------------------ def wrapped_moba_func( q: Tensor, k: Tensor, v: Tensor, + seq_len: int, moba_topk: int, moba_chunk_size: int, - attention_mask: Optional[torch.Tensor], - dropout: float = 0.0, - scaling: Optional[float] = None, - sliding_window: Optional[int] = None, - softcap: Optional[float] = None, **kwargs, ): - return moba_layer( - moba_attn_varlen, - moba_chunk_size, moba_topk, - q, k, v, - attention_mask, dropout, scaling, sliding_window, softcap, + return moba_attn_func( + q, k, v, + seq_len, + moba_chunk_size, moba_topk, **kwargs ) @@ -91,33 +88,14 @@ def wrapped_moba_zigzag_func( output = flash_attn_func(query, key, value, 0.0, softmax_scale, True) return output - batch, block_seq_len, q_heads, head_dim = query.shape - assert batch == 1, "Current implementation only supports batch size = 1" - - # [0, BLK_SZ, 2 * BLK_SZ, 3 * BLK_SZ, ..., seq_len - 1] - rank = dist.get_rank() - world_size = dist.get_world_size() - seq_offsets = torch.arange(0, seq_len, seq_len // world_size)[rank:rank+1] - - _, _, kv_heads, _ = key.shape - - query = query.reshape(-1, q_heads, head_dim) # [B * N, H, D] - key = key.reshape(-1, kv_heads, head_dim) - value = value.reshape(-1, kv_heads, head_dim) - - # Assume only one batch or all batches have the same length - cu_seqlens = torch.cumsum( - torch.tensor([0] + [seq_len] * batch, device=query.device), - dim=0, - dtype=torch.int32, - ) + batch_size, block_seq_len, q_heads, head_dim = query.shape + assert batch_size == 1, "Current implementation only supports batch size = 1" local_process_group = DeviceGroup().get_group(process_group) output = moba_zigzag_func( query, key, value, - seq_offsets, layer_idx, - cu_seqlens, + seq_len, moba_chunk_size, moba_topk, dropout, softmax_scale, @@ -128,8 +106,7 @@ def wrapped_moba_zigzag_func( False, # return_softmax, local_process_group, # group ).contiguous() - return output.view(batch, block_seq_len, q_heads, head_dim) - + return output.view(batch_size, block_seq_len, q_heads, head_dim) # -------------------------------------------------- def moba_attn_anno(query_states, key_states, value_states, *args, **kwargs) -> str: diff --git a/mtraining/experiments/scripts/train_qwen_mini_ProLong512K.sh b/mtraining/experiments/scripts/train_qwen_mini_ProLong512K.sh index a5acf3d..6943697 100755 --- a/mtraining/experiments/scripts/train_qwen_mini_ProLong512K.sh +++ b/mtraining/experiments/scripts/train_qwen_mini_ProLong512K.sh @@ -1,8 +1,10 @@ #!/usr/bin/bash +echo $(which pip) i=$(hostname | awk -F'-' '{print $2}') NODE_RANK=$i export NUM_NODES=1 export REUSE_TYPE="match" +export FORCE_TRITON=1 export HF_TRUST_REMOTE_CODE=true export HF_DATASETS_TRUST_REMOTE_CODE=true @@ -72,6 +74,7 @@ mkdir -p $LOG_PATH echo "Logging directed to $LOG_PATH/train.log" export TRACE_STRATEGY="reuse_cache" +echo $(which torchrun) torchrun --nproc_per_node=$GPU_PER_NODE \ --nnodes=$NUM_NODES \ --node_rank=$NODE_RANK \ From db6fc95ae555f718f0c8a41d93cf41fdd65c4b74 Mon Sep 17 00:00:00 2001 From: Wenxuan Li Date: Sat, 21 Jun 2025 16:45:34 +0000 Subject: [PATCH 10/12] delete draft from mtrain repo --- minference/dist_ops/minfer_zigzag_comp.py | 403 ---------------------- 1 file changed, 403 deletions(-) delete mode 100644 minference/dist_ops/minfer_zigzag_comp.py diff --git a/minference/dist_ops/minfer_zigzag_comp.py b/minference/dist_ops/minfer_zigzag_comp.py deleted file mode 100644 index e99b88c..0000000 --- a/minference/dist_ops/minfer_zigzag_comp.py +++ /dev/null @@ -1,403 +0,0 @@ -import os 
-import torch -import triton -import torch.distributed as dist - -from typing import List, Tuple, Dict - -from MTraining.ops.ring_attn.core.utils import ( - RingComm, TIMING_LOGGER, - shuffle_zigzag_input, recover_zigzag_output, - single_gather_tensor, shuffle_block_mask_zigzag -) -from MTraining.ops.minfer import ( - block_attn_fwd, block_attn_bwd, bar_attn_fwd, bar_attn_bwd, block_bar_attn_fwd, convert_blockmask, - minference_flash_attn_func, build_index -) - - -def compute_sr_flops( - block_mask_offset: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, num_blocks] - bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] - step: int, - granularity: int, - q_len: int, - head_dim: int, - fwd: bool=True, -): - num_blocks = triton.cdiv(q_len, granularity) - bh = block_mask_offset.shape[0] * block_mask_offset.shape[1] - - # ------------------------------------- - # Block Compute - total_num_blocks = bh * num_blocks * num_blocks / 2 - num_active_blocks = block_mask_offset.sum(dtype=torch.float32).item() - if step == 0: - num_active_blocks -= bh * num_blocks / 2 - block_ratio = num_active_blocks / total_num_blocks - block_flops = num_active_blocks * (granularity * granularity) * head_dim * 2 * 2 - - # ------------------------------------- - # Bar Compute - bar_cnt_step = (bar_cnt[..., step + 1] - bar_cnt[..., step]).sum(dtype=torch.float32).item() - bar_ratio = bar_cnt_step / (granularity * total_num_blocks) - bar_flops = bar_cnt_step * granularity * head_dim * 2 * 2 - - # ------------------------------------- - # Sparsity Ratio and FLOPs - sparsity_ratio = 1 - block_ratio - bar_ratio - flops = block_flops + bar_flops - - if not fwd: - flops, block_flops, bar_flops = 2.5 * flops, 2.5 * block_flops, 2.5 * bar_flops - return block_ratio, bar_ratio, sparsity_ratio, block_flops, bar_flops, flops - -def compute_sr_by_heads( - block_mask_offset: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, num_blocks] - bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] - step: int, - granularity: int, - q_len: int, -): - batch_size, num_heads = block_mask_offset.shape[0], block_mask_offset.shape[1] - num_blocks = triton.cdiv(q_len, granularity) - - # ------------------------------------- - # Block Compute - total_num_blocks = batch_size * num_blocks * num_blocks / 2 - total_num_blocks_by_heads = torch.tensor([total_num_blocks for _ in range(num_heads)], dtype=torch.float32).to(block_mask_offset.device) - num_active_blocks = block_mask_offset.sum(-1).sum(-1).sum(0, dtype=torch.float32) # [num_qo_heads] - if step == 0: - num_active_blocks -= batch_size * num_blocks / 2 - block_ratio_by_heads = num_active_blocks / total_num_blocks_by_heads - - # ------------------------------------- - # Bar Compute - bar_cnt_step = (bar_cnt[..., step + 1] - bar_cnt[..., step]).sum(dim=-1).sum(dim=-1).sum(0, dtype=torch.float32) # [num_qo_heads] - bar_ratio_by_heads = bar_cnt_step / total_num_blocks_by_heads / granularity - - # ------------------------------------- - # Sparsity Ratio - sparsity_ratio_by_heads = 1 - block_ratio_by_heads - bar_ratio_by_heads - return sparsity_ratio_by_heads.detach().cpu().numpy().tolist() - -def minfer_zigzag_forward( - process_group: dist.ProcessGroup, - q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] - k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] - v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] - layer_idx: int, - softmax_scale: float, - block_mask: torch.Tensor, # 
[world_size, batch_size, num_qo_heads, num_blocks, num_blocks] - bar_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, max_v_size] - bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] - granularity: int = 128, -): - comm = RingComm(process_group, zigzag=True) - ring_list = comm.ring_list - ring_index = ring_list.index(comm.rank) - - out, lse = None, None - block_idx, block_cnt = convert_blockmask(block_mask, block_size_M=granularity, block_size_N=64) - - for step in range(comm.world_size): - if step + 1 != comm.world_size: - next_k, next_v = comm.send_recv_kv(k, v) - block_causal = step == 0 - offset = (ring_index - step) % comm.world_size - - out, lse = block_bar_attn_fwd( - q, k, v, out, lse, softmax_scale, - bar_idx, bar_cnt, block_idx[offset], block_cnt[offset], - granularity=granularity, - step=offset, - causal=block_causal, - ) - - if step + 1 != comm.world_size: - comm.wait() - k, v = next_k, next_v - - out = out.to(q.dtype) - # lse = lse.squeeze(dim=-1).transpose(1, 2) - return out, lse - - -def minfer_zigzag_backward( - process_group: dist.ProcessGroup, - dout: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] - q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] - k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] - v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] - out: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] - softmax_lse: torch.Tensor, # [batch_size, num_qo_heads, num_tokens] - layer_idx: int, - softmax_scale: float, - block_mask: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks, num_blocks] - bar_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, max_v_size] - bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] - granularity: int = 128, -): - kv_comm = RingComm(process_group, zigzag=True) - d_kv_comm = RingComm(process_group, zigzag=True) - ring_list = kv_comm.ring_list - ring_index = ring_list.index(kv_comm.rank) - - dq, dk, dv = None, None, None - next_dk, next_dv = None, None - next_k, next_v = None, None - dk_comm_buffer, dv_comm_buffer = None, None - - for step in range(kv_comm.world_size): - if step + 1 != kv_comm.world_size: - next_k, next_v = kv_comm.send_recv_kv(k, v) - block_causal = step == 0 - offset = (ring_index - step) % kv_comm.world_size - - # Block Mask - step_dq, step_dk, step_dv = block_attn_bwd( - dout, q, k, v, out, - softmax_lse, softmax_scale, - block_mask[offset], - granularity=granularity, - deterministic=False, - causal=block_causal, - ) - - # Bar Mask - step_dq, step_dk, step_dv = bar_attn_bwd( - dout, q, k, v, out, step_dq, step_dk, step_dv, - softmax_lse, softmax_scale, - bar_idx, bar_cnt, - granularity=granularity, - deterministic=False, - step=offset, - ) - - # Update dQ, dK, dV - if step == 0: - # TODO: check if float32 is necessary - dq = step_dq.to(torch.float32) - dk = step_dk.to(torch.float32) - dv = step_dv.to(torch.float32) - else: - d_kv_comm.wait() - dk_comm_buffer, dv_comm_buffer = dk, dv - dk, dv = next_dk, next_dv - - dq += step_dq - dk += step_dk - dv += step_dv - - if step + 1 != kv_comm.world_size: - kv_comm.wait() - k, v = next_k, next_v - next_dk, next_dv = d_kv_comm.send_recv_kv( - dk, dv, dk_comm_buffer, dv_comm_buffer - ) - - d_kv_comm.wait() - return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) - - -class MInferZigzagAttnFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - v_size, - s_size, - 
layer_idx, - softmax_scale, - granularity, - return_softmax, - group, - ): - if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) - batch_size, num_tokens_local, num_qo_heads, head_dim = q.shape - - # Indexing - TIMING_LOGGER.start('indexing') - block_mask, bar_idx, bar_cnt, bar_pos, v_idx, v_cnt = build_index( - q, k, v_size, s_size, num_tokens_local, - stripe_transform=False, - zigzag_transform=True, - granularity=granularity, group=group - ) - TIMING_LOGGER.end('indexing') - - # Shuffle - TIMING_LOGGER.start('shfl-fwd-input') - q = shuffle_zigzag_input(to_send=q, dim=1, process_group=group) - k = shuffle_zigzag_input(to_send=k, dim=1, process_group=group) - v = shuffle_zigzag_input(to_send=v, dim=1, process_group=group) - TIMING_LOGGER.end('shfl-fwd-input') - - # Compute - TIMING_LOGGER.start('forward') - out, softmax_lse = minfer_zigzag_forward( - group, q, k, v, - layer_idx, softmax_scale, - block_mask, bar_idx, bar_cnt, - granularity=granularity, - ) - TIMING_LOGGER.end('forward') - - # Saving tensors for backward - ctx.save_for_backward(q, k, v, out, softmax_lse, block_mask, bar_idx, bar_cnt) - ctx.softmax_scale = softmax_scale - ctx.granularity = granularity - ctx.group = group - ctx.layer_idx = layer_idx - - # Recover outputs - TIMING_LOGGER.start('shfl-fwd-output') - out = recover_zigzag_output(out, dim=1, process_group=group) - if return_softmax: - softmax_lse = recover_zigzag_output(softmax_lse, dim=2, process_group=group) - TIMING_LOGGER.end('shfl-fwd-output') - - # Output and Return - if return_softmax: - return (out, softmax_lse, None) - return out - - @staticmethod - def backward(ctx, dout, *args): - q, k, v, out, softmax_lse, block_mask, bar_idx, bar_cnt = ctx.saved_tensors - softmax_scale = ctx.softmax_scale - granularity = ctx.granularity - layer_idx = ctx.layer_idx - group = ctx.group - - # Shuffle - TIMING_LOGGER.start('shfl-bwd-input') - dout = shuffle_zigzag_input(to_send=dout, dim=1, process_group=group) - TIMING_LOGGER.end('shfl-bwd-input') - - # Compute - TIMING_LOGGER.start('backward') - dq, dk, dv = minfer_zigzag_backward( - group, dout, q, k, v, out, softmax_lse, - layer_idx, softmax_scale, - block_mask, bar_idx, bar_cnt, - granularity=granularity, - ) - TIMING_LOGGER.end('backward') - - # Recover - TIMING_LOGGER.start('shfl-bwd-output') - dq = recover_zigzag_output(dq, dim=1, process_group=group) - dk = recover_zigzag_output(dk, dim=1, process_group=group) - dv = recover_zigzag_output(dv, dim=1, process_group=group) - TIMING_LOGGER.end('shfl-bwd-output') - - return dq, dk, dv, None, None, None, None, None, None, None - - -def minfer_zigzag_qkvpacked_func( - qkv: torch.Tensor, # [batch_size, num_tokens, 3, num_heads, head_dim] - v_size: List[int], # [num_heads] - s_size: List[int], # [num_heads] - layer_idx: int = 0, - dropout_p: float = 0.0, - softmax_scale: float = None, - granularity: int = 128, - causal: bool = True, - window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window - alibi_slopes: Tuple[int, int] = None, - deterministic: bool = False, - return_attn_probs: bool = False, - group: dist.ProcessGroup = None, -): - assert causal - assert dropout_p == 0 - assert window_size == (-1, -1) - assert alibi_slopes is None - assert not deterministic - return MInferZigzagAttnFunc.apply( - qkv[:, :, 0], - qkv[:, :, 1], - qkv[:, :, 2], - v_size, - s_size, - layer_idx, - softmax_scale, - granularity, - return_attn_probs, - group, - ) - - -def minfer_zigzag_kvpacked_func( - q: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] 
- kv: torch.Tensor, # [batch_size, num_tokens, 2, num_heads, head_dim] - v_size: List[int], # [num_heads] - s_size: List[int], # [num_heads] - layer_idx: int = 0, - dropout_p: float = 0.0, - softmax_scale: float = None, - granularity: int = 128, - causal: bool = True, - window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window - alibi_slopes: Tuple[int, int] = None, - deterministic: bool = False, - return_attn_probs: bool = False, - group: dist.ProcessGroup = None, -): - assert causal - assert dropout_p == 0 - assert window_size == (-1, -1) - assert alibi_slopes is None - assert not deterministic - return MInferZigzagAttnFunc.apply( - q, - kv[:, :, 0], - kv[:, :, 1], - v_size, - s_size, - layer_idx, - softmax_scale, - granularity, - return_attn_probs, - group, - ) - - -def minfer_zigzag_func( # the one used for nnscaler training - q: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] - k: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] - v: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] - v_size: List[int], # [num_heads] - s_size: List[int], # [num_heads] - layer_idx: int = 0, - dropout_p: float = 0.0, - softmax_scale: float = None, - granularity: int = 128, - causal: bool = True, - window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window - alibi_slopes: Tuple[int, int] = None, - deterministic: bool = False, - return_attn_probs: bool = False, - group: dist.ProcessGroup = None, -) -> torch.Tensor: - assert causal - assert dropout_p == 0 - assert window_size == (-1, -1) - assert alibi_slopes is None - assert not deterministic - - return MInferZigzagAttnFunc.apply( - q, - k, - v, - v_size, - s_size, - layer_idx, - softmax_scale, - granularity, - return_attn_probs, - group, - ) From 908db991072f29cef14b8977454882c77b6ac57d Mon Sep 17 00:00:00 2001 From: Wenxuan Li Date: Tue, 24 Jun 2025 10:43:05 +0000 Subject: [PATCH 11/12] Basically finish unit testing --- minference/dist_ops/moba_zigzag.py | 11 +-- minference/dist_ops/test/minfer_ring_test.py | 6 +- .../dist_ops/test/minfer_ring_test_raw.py | 24 ----- minference/dist_ops/test/moba_ring_test.py | 96 ++++++++++--------- .../dist_ops/test/moba_ring_test_raw.py | 8 +- minference/ops/moba.py | 10 +- mtraining/attn_funcs/moba_func.py | 23 ++--- mtraining/train.py | 8 +- 8 files changed, 83 insertions(+), 103 deletions(-) diff --git a/minference/dist_ops/moba_zigzag.py b/minference/dist_ops/moba_zigzag.py index 1ba8d45..3f607e0 100644 --- a/minference/dist_ops/moba_zigzag.py +++ b/minference/dist_ops/moba_zigzag.py @@ -411,9 +411,6 @@ def moba_zigzag_attn_bwd_step( q: torch.Tensor, # [blk_S, H, D] k: torch.Tensor, # [blk_S, H, D] v: torch.Tensor, # [blk_S, H, D] - # dq: torch.Tensor, - # dk: torch.Tensor, - # dv: torch.Tensor, # [blk_S, H, D] softmax_lse: torch.Tensor, # [H, blk_S] num_q_blocks: int, @@ -513,14 +510,10 @@ def moba_zigzag_attn_bwd_step( # ----------------------------------------------------------- # select all q that needs moba attn based on the moba_q_indices moba_q = rearrange(q, "s h d -> ( h s ) d") - moba_dq = rearrange(dq, "s h d -> ( h s ) d") - moba_q = moba_q.index_select(0, moba_q_indices) # [ selected_HS, D ] - moba_dq = moba_dq.index_select(0, moba_q_indices) # [ selected_HS, D ] # [ selected_S, 1, D ] (pseudo head dim for flash attn) moba_q = moba_q.unsqueeze(1) - moba_dq = moba_dq.unsqueeze(1) # moba_q_sh_indices represents the position in the origin q tensor of each q token inside moba_q # note that original q has shape (S, H, D) while 
moba_q_indices is based on (H S) @@ -658,7 +651,8 @@ def moba_zigzag_attn_bwd_step( ) dq.view(-1, q.shape[-1]).index_add_( - 0, moba_q_sh_indices, moba_dq.view(-1, head_dim).to(dq.dtype) + 0, moba_q_sh_indices, + moba_dq_.view(-1, head_dim).to(dq.dtype) ) moba_dkv[:, 0] = moba_dkv[:, 0] + moba_dk_ moba_dkv[:, 1] = moba_dkv[:, 1] + moba_dv_ @@ -1038,7 +1032,6 @@ def moba_zigzag_qkvpacked_func( group, ) - def moba_zigzag_kvpacked_func( q, kv, diff --git a/minference/dist_ops/test/minfer_ring_test.py b/minference/dist_ops/test/minfer_ring_test.py index ae39257..aa78fed 100644 --- a/minference/dist_ops/test/minfer_ring_test.py +++ b/minference/dist_ops/test/minfer_ring_test.py @@ -10,7 +10,7 @@ import torch.distributed as dist import torch.multiprocessing as mp -from minference.ops.utils import set_seed, check_correct_rate +from minference.ops.utils import set_seed, check_by_correct_rate from minference.dist_ops.minfer_zigzag import minfer_zigzag_func from minference.dist_ops.minfer_striped import minfer_stripe_func from minference.dist_ops.minfer_dr_striped import minfer_dr_stripe_func @@ -146,7 +146,7 @@ def _run_worker( ref_grads = (q_ref.grad, k_ref.grad, v_ref.grad) # ----------------- assertions ---------------------------------------- - assert check_correct_rate(final_out, out_ref, ATOL=_ATOL, RTOL=_RTOL),\ + assert check_by_correct_rate(final_out, out_ref, ATOL=_ATOL, RTOL=_RTOL),\ "forward output mismatch" for got, ref, name in zip( @@ -154,7 +154,7 @@ def _run_worker( ref_grads, ("Q-grad", "K-grad", "V-grad"), ): - assert check_correct_rate(got, ref, ATOL=_ATOL, RTOL=_RTOL),\ + assert check_by_correct_rate(got, ref, ATOL=_ATOL, RTOL=_RTOL),\ f"{name} mismatch" dist.destroy_process_group() diff --git a/minference/dist_ops/test/minfer_ring_test_raw.py b/minference/dist_ops/test/minfer_ring_test_raw.py index 03b5028..97f1d76 100644 --- a/minference/dist_ops/test/minfer_ring_test_raw.py +++ b/minference/dist_ops/test/minfer_ring_test_raw.py @@ -163,20 +163,6 @@ def _run_worker( ref_grads = (q_ref.grad, k_ref.grad, v_ref.grad) # ----------------- assertions ---------------------------------------- - # if check_correctness_by_row( - # cfg.seq_len, final_out, out_ref, "forward output", ATOL=_ATOL, RTOL=_RTOL - # ): - # check_correctness_by_row( - # cfg.seq_len, grads[0], ref_grads[0], "Q-grad", ATOL=_ATOL, RTOL=_RTOL - # ) - # check_correctness_by_row( - # cfg.seq_len, grads[1], ref_grads[1], "K-grad", - # ATOL=_ATOL, RTOL=_RTOL - # ) - # check_correctness_by_row( - # cfg.seq_len, grads[2], ref_grads[2], "V-grad", - # ATOL=_ATOL, RTOL=_RTOL - # ) if check_by_correct_rate(final_out, out_ref, ATOL=_ATOL, RTOL=_RTOL): for got, ref, name in zip( grads, @@ -185,16 +171,6 @@ def _run_worker( ): check_by_correct_rate(got, ref, ATOL=_ATOL, RTOL=_RTOL) - torch.testing.assert_close( - final_out, out_ref, atol=_ATOL, rtol=_RTOL, msg="forward mismatch" - ) - for got, ref, name in zip( - grads, - ref_grads, - ("Q-grad", "K-grad", "V-grad"), - ): - torch.testing.assert_close(got, ref, ATOL=_ATOL, RTOL=_RTOL, msg=name) - dist.destroy_process_group() diff --git a/minference/dist_ops/test/moba_ring_test.py b/minference/dist_ops/test/moba_ring_test.py index f7f0a57..9291d79 100644 --- a/minference/dist_ops/test/moba_ring_test.py +++ b/minference/dist_ops/test/moba_ring_test.py @@ -3,23 +3,34 @@ import os import pytest import random +import functools from types import SimpleNamespace -from typing import Callable import torch import torch.distributed as dist import torch.multiprocessing as mp -from 
minference.ops.utils import set_seed, check_correctness_by_row -from minference.dist_ops.xattn_zigzag import xattn_zigzag_func -from minference.ops.xattention_fa import xattn_flash_attn_func +from minference.ops.utils import set_seed, check_correctness_by_row, check_by_correct_rate +from minference.dist_ops.moba_zigzag import moba_zigzag_func +from minference.ops.moba import moba_attn_func # ------------- constants ------------------------------------------------------ -_ATOL = 1e-1 -_RTOL = 1e-1 +_ATOL = 1e-2 +_RTOL = 1e-2 _WORLD_SIZE = 4 # ------------- helpers -------------------------------------------------------- +def skip_if_cuda_oom(test_func): + """Decorator: convert OutOfMemoryError raised inside the test into a skip.""" + @functools.wraps(test_func) + def _wrapper(*args, **kwargs): + try: + return test_func(*args, **kwargs) + except torch.OutOfMemoryError: + torch.cuda.empty_cache() + pytest.skip("skipped because the GPU ran out of memory") + return _wrapper + def _init_process_group(rank: int, world_size: int, port: str) -> None: """Initialise NCCL backend for the current worker.""" os.environ.update( @@ -35,7 +46,6 @@ def _init_process_group(rank: int, world_size: int, port: str) -> None: dist.init_process_group("nccl", rank=rank, world_size=world_size) - def _run_worker( rank: int, world_size: int, @@ -98,11 +108,12 @@ def _run_worker( dout_local = dout[:, sl].clone() # ----------------- forward / backward on the candidate kernel ------------ - out_local = xattn_zigzag_func( - q_local, k_local, v_local, + out_local = moba_zigzag_func( + q_local, k_local, v_local, layer_idx=0, - xattn_params=cfg.xattn_params, - granularity=128, + global_seq_len=cfg.seq_len, + moba_chunk_size=cfg.moba_chunk_size, + moba_topk=cfg.moba_topk, ) torch.autograd.backward(out_local, dout_local) @@ -123,43 +134,54 @@ def _run_worker( k_ref = k.detach().clone().requires_grad_() v_ref = v.detach().clone().requires_grad_() - single_machine_params = cfg.xattn_params.copy() - single_machine_params["chunk_size"] = cfg.seq_len // _WORLD_SIZE - out_ref = xattn_flash_attn_func( + out_ref = moba_attn_func( q_ref, k_ref, v_ref, - head_indices=list(range(cfg.num_qo_heads)), - xattn_params=single_machine_params, - granularity=128, + global_seq_len=cfg.seq_len, + moba_chunk_size=cfg.moba_chunk_size, + moba_topk=cfg.moba_topk, ) torch.autograd.backward(out_ref, dout) ref_grads = (q_ref.grad, k_ref.grad, v_ref.grad) - torch.testing.assert_close( - final_out, out_ref, atol=_ATOL, rtol=_RTOL, msg="forward output mismatch" - ) + # ----------------- assertions ---------------------------------------- + assert check_by_correct_rate(final_out, out_ref, ATOL=_ATOL, RTOL=_RTOL),\ + "forward output mismatch" + for got, ref, name in zip( grads, ref_grads, ("Q-grad", "K-grad", "V-grad"), ): - torch.testing.assert_close(got, ref, atol=_ATOL, rtol=_RTOL, msg=f"{name} mismatch") + assert check_by_correct_rate(got, ref, ATOL=_ATOL, RTOL=_RTOL),\ + f"{name} mismatch" + + # torch.testing.assert_close( + # final_out, out_ref, atol=_ATOL, rtol=_RTOL, msg="forward output mismatch" + # ) + # for got, ref, name in zip( + # grads, + # ref_grads, + # ("Q-grad", "K-grad", "V-grad"), + # ): + # torch.testing.assert_close(got, ref, atol=_ATOL, rtol=_RTOL, msg=f"{name} mismatch") dist.destroy_process_group() # ------------- pytest entry-point -------------------------------------------- +@skip_if_cuda_oom @pytest.mark.skipif(torch.cuda.device_count() < _WORLD_SIZE, reason="Not enough GPUs") -@pytest.mark.parametrize("seq_len", [131072, 262144, 
524288]) +@pytest.mark.parametrize("seq_len", [16384, 32768]) @pytest.mark.parametrize("head_dim", [64, 128]) @pytest.mark.parametrize("num_qkv_head_pair", [(4, 1), (4, 4)]) -@pytest.mark.parametrize("stride", [16, 32]) -@pytest.mark.parametrize("threshold", [0.9, 1.]) -def test_xattention_kernels( +@pytest.mark.parametrize("moba_chunk_size", [128, 256]) +@pytest.mark.parametrize("moba_topk", [8, 16]) +def test_moba_kernels( seq_len: int, head_dim: int, num_qkv_head_pair: tuple[int, int], - stride: int, - threshold: float, + moba_chunk_size: int, + moba_topk: int, ): """ Compare every sparse kernel against the dense Flash-Attention reference on @@ -167,34 +189,22 @@ def test_xattention_kernels( """ port = str(random.randint(12000, 20000)) - xattn_params = { - "stride": stride, - "norm": 1, - "softmax": True, - "threshold": threshold, - "select_mode": "inverse", - "use_triton": True, - "causal": True, - "kdb": 1, - "keep_sink": False, - "keep_recent": False - } cfg = SimpleNamespace( batch_size=1, seq_len=seq_len, head_dim=head_dim, num_qo_heads=num_qkv_head_pair[0], num_kv_heads=num_qkv_head_pair[1], - xattn_params=xattn_params, + moba_chunk_size=moba_chunk_size, + moba_topk=moba_topk, ) print(f"=" * 80) - print(f"Testing XAttention (w. Zigzag) with configuration:\n{cfg}") + print(f"Testing MoBA (w. Zigzag) with configuration:\n{cfg}") print(f"=" * 80) mp.spawn( _run_worker, args=(_WORLD_SIZE, port, cfg), nprocs=_WORLD_SIZE, join=True, - ) - + ) \ No newline at end of file diff --git a/minference/dist_ops/test/moba_ring_test_raw.py b/minference/dist_ops/test/moba_ring_test_raw.py index 1f7623a..4943337 100644 --- a/minference/dist_ops/test/moba_ring_test_raw.py +++ b/minference/dist_ops/test/moba_ring_test_raw.py @@ -198,10 +198,10 @@ def test_moba_kernels( test_moba_kernels( seq_len=16384, batch_size=1, - head_dim=64, + head_dim=128, ones=False, - num_qkv_head_pair=(1, 1), + num_qkv_head_pair=(4, 1), - moba_chunk_size=512, - moba_topk=4, + moba_chunk_size=128, + moba_topk=8, ) \ No newline at end of file diff --git a/minference/ops/moba.py b/minference/ops/moba.py index 3d77394..d9d860a 100644 --- a/minference/ops/moba.py +++ b/minference/ops/moba.py @@ -267,9 +267,7 @@ def backward(ctx, d_output, *args): d_output = d_output.contiguous() - dq = torch.zeros_like(q, dtype=q.dtype, device=q.device) - dk = torch.zeros_like(k, dtype=k.dtype, device=k.device) - dv = torch.zeros_like(v, dtype=v.dtype, device=v.device) + dq, dk, dv = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v) _flash_attn_varlen_backward( dout=d_output, q=q, @@ -366,7 +364,11 @@ def moba_attn_varlen( Returns: attn_output (torch.Tensor): [seqlen, head, head_dim] """ - + head_group_size = q.shape[1] // k.shape[1] + if head_group_size > 1: + k = torch.repeat_interleave(k, head_group_size, dim=1) + v = torch.repeat_interleave(v, head_group_size, dim=1) + # --------------------------------------------------------------------------------------------- kv = torch.stack((k, v), dim=1) # stack along a new dimension -> [S, 2, H, D] diff --git a/mtraining/attn_funcs/moba_func.py b/mtraining/attn_funcs/moba_func.py index d89e9e9..b604dbb 100644 --- a/mtraining/attn_funcs/moba_func.py +++ b/mtraining/attn_funcs/moba_func.py @@ -33,15 +33,14 @@ def moba_attention_forward( seq_len = query.shape[2] moba_topk, moba_chunk_size = module.moba_topk, module.moba_chunk_size implementation = module.implementation + if implementation == "default": return wrapped_moba_func( query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 
2), seq_len, moba_topk, moba_chunk_size, - attention_mask, dropout, scaling, sliding_window, softcap, **kwargs ), None - else: - + elif implementation == "zigzag": layer_idx = module.layer_idx return wrapped_moba_zigzag_func( query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), @@ -50,19 +49,21 @@ def moba_attention_forward( layer_idx, attention_mask, dropout, scaling, sliding_window, softcap, ), None + else: + raise ValueError(f"Unsupported MoBA implementation: {implementation}. " + f"Supported implementations are 'default' and 'zigzag'.") + # ------------------------------------------ def wrapped_moba_func( q: Tensor, k: Tensor, v: Tensor, seq_len: int, moba_topk: int, moba_chunk_size: int, - **kwargs, ): return moba_attn_func( q, k, v, seq_len, moba_chunk_size, moba_topk, - **kwargs ) def wrapped_moba_zigzag_func( @@ -110,16 +111,16 @@ def wrapped_moba_zigzag_func( # -------------------------------------------------- def moba_attn_anno(query_states, key_states, value_states, *args, **kwargs) -> str: - if query_states.shape[1] != key_states.shape[1]: - assert query_states.shape[1] % key_states.shape[1] == 0 - group_size = query_states.shape[1] // key_states.shape[1] - assert query_states.shape[1] == value_states.shape[1] * group_size + if query_states.shape[2] != key_states.shape[2]: + assert query_states.shape[2] % key_states.shape[2] == 0 + group_size = query_states.shape[2] // key_states.shape[2] + assert query_states.shape[2] == value_states.shape[2] * group_size q_anno = f'(group_num {group_size})' kv_anno = 'group_num' else: q_anno = kv_anno = 'num_heads' - return f'b {q_anno} l^ hd^, b {kv_anno} s^ hd^, b {kv_anno} s^ vd^, {q_anno} -> b l^ {q_anno} vd^' + return f'b l^ {q_anno} hd^, b s^ {kv_anno} hd^, b s^ {kv_anno} vd^ -> b l^ {q_anno} vd^' def moba_zigzag_attn_anno(query_states, key_states, value_states, *args, **kwargs) -> str: num_q_heads, num_kv_heads = query_states.shape[2], key_states.shape[2] @@ -135,6 +136,7 @@ def moba_zigzag_attn_anno(query_states, key_states, value_states, *args, **kwarg attn_anno = f'b l {q_anno} hd^, b l {kv_anno} hd^, b l {kv_anno} vd^ -> b l {q_anno} vd^' return attn_anno + def emit_moba_zigzag(node: IRFwOperation, args: List[str], kwargs: Dict[str, str], runtime_devid: int, plan_ndevs: int, runtime_ndevs: int) -> str: """Special rule to generate zigzag_attn node""" @@ -170,7 +172,6 @@ def emit_moba_zigzag(node: IRFwOperation, args: List[str], kwargs: Dict[str, str args = ", ".join(list(args) + kw_pairs) return f"{signature}({args})" - if __name__ != "__main__": register_op(moba_attn_anno)(wrapped_moba_func) diff --git a/mtraining/train.py b/mtraining/train.py index 5075f54..dfa4577 100644 --- a/mtraining/train.py +++ b/mtraining/train.py @@ -198,12 +198,9 @@ def __init__( config_path=config_path, **kwargs, ) - from minference.ops.op_utils.moba_utils import MoBAConfig - # -------------------------------------------- - print(f"MoBAConfig: {moba_config_dict}") - moba_config = MoBAConfig(**moba_config_dict) - moba_topk, moba_chunk_size = moba_config.moba_topk, moba_config.moba_chunk_size + moba_topk, moba_chunk_size = moba_config_dict["moba_topk"], moba_config_dict["moba_chunk_size"] + moba_implementation = moba_config_dict.get("implementation", 'default') # -------------------------------------------- # We still need to attach the function object to the model @@ -214,6 +211,7 @@ def update_module(m): if isinstance(m, Attention): m.moba_topk = moba_topk m.moba_chunk_size = moba_chunk_size + m.implementation = moba_implementation 
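# Illustrative sketch (hypothetical ToyAttention / attach_moba_cfg names, not part of
# this patch) of the attach-then-read pattern used above: MoBA hyperparameters are
# stamped onto every attention submodule once at init via Module.apply, and the
# attention forward later reads them back from module.moba_topk,
# module.moba_chunk_size and module.implementation.
import torch.nn as nn

class ToyAttention(nn.Module):
    def forward(self, x):
        # Reads the attributes attached below, mirroring how moba_attention_forward
        # pulls its settings off the module rather than from function arguments.
        return x, (self.moba_topk, self.moba_chunk_size, self.implementation)

def attach_moba_cfg(model: nn.Module, cfg: dict) -> None:
    def _update(m: nn.Module) -> None:
        if isinstance(m, ToyAttention):
            m.moba_topk = cfg["moba_topk"]
            m.moba_chunk_size = cfg["moba_chunk_size"]
            m.implementation = cfg.get("implementation", "default")
    model.apply(_update)  # Module.apply visits every submodule recursively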
self.model.apply(update_module) ATTN_TO_MODEL = { From f74982d3f46d32d1a421a102c187b6ba2da434cd Mon Sep 17 00:00:00 2001 From: Wenxuan Li Date: Wed, 25 Jun 2025 06:25:23 +0000 Subject: [PATCH 12/12] Merged CUDA and Triton-based implementation --- minference/dist_ops/__init__.py | 4 - .../dist_ops/minfer_dr_stripe_triton.py | 404 ------ minference/dist_ops/minfer_dr_striped.py | 406 +++++- minference/dist_ops/minfer_striped.py | 298 +++- minference/dist_ops/minfer_striped_triton.py | 284 ---- minference/dist_ops/minfer_zigzag.py | 6 +- minference/dist_ops/test/minfer_ring_test.py | 10 + .../dist_ops/test/minfer_ring_test_raw.py | 2 + minference/dist_ops/xattn_zigzag.py | 3 +- .../ops/op_utils/vertical_slash_utils.py | 35 - .../ops/pit_sparse_flash_attention_v3.py | 1243 +++++++++++++---- .../pit_sparse_flash_attention_v3_triton.py | 1076 -------------- mtraining/attn_funcs/minfer_func.py | 106 +- 13 files changed, 1607 insertions(+), 2270 deletions(-) delete mode 100644 minference/dist_ops/minfer_dr_stripe_triton.py delete mode 100644 minference/dist_ops/minfer_striped_triton.py delete mode 100644 minference/ops/pit_sparse_flash_attention_v3_triton.py diff --git a/minference/dist_ops/__init__.py b/minference/dist_ops/__init__.py index cc82224..95e6cd5 100644 --- a/minference/dist_ops/__init__.py +++ b/minference/dist_ops/__init__.py @@ -1,9 +1,5 @@ from .minfer_striped import minfer_stripe_func from .minfer_zigzag import minfer_zigzag_func from .minfer_dr_striped import minfer_dr_stripe_func - -from .minfer_striped_triton import minfer_stripe_triton_func -from .minfer_dr_stripe_triton import minfer_dr_stripe_triton_func - from .moba_zigzag import moba_zigzag_func from .xattn_zigzag import xattn_zigzag_func diff --git a/minference/dist_ops/minfer_dr_stripe_triton.py b/minference/dist_ops/minfer_dr_stripe_triton.py deleted file mode 100644 index 7c8ea3c..0000000 --- a/minference/dist_ops/minfer_dr_stripe_triton.py +++ /dev/null @@ -1,404 +0,0 @@ -import os -import torch -import torch.distributed as dist - -from typing import List, Tuple - -from .utils import ( - RingComm, - shuffle_striped_input, recover_striped_output, - get_inner_ring, get_outer_ring -) -from minference.ops.op_utils.vertical_slash_utils import build_index, convert_blockmask -from minference.ops.pit_sparse_flash_attention_v3_triton import block_bar_attn_fwd, block_bar_attn_bwd - -def minfer_dr_stripe_triton_forward_inner( - process_group: dist.ProcessGroup, - outer_step: int, - outer_offset: int, - inner_ring: List[int], - q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] - k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] - v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] - out: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] - lse: torch.Tensor, # [batch_size, num_qo_heads, num_tokens] - softmax_scale: float, - block_idx: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks, num_blocks] - block_cnt: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks] - bar_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, max_v_size] - bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] - granularity: int = 128, -): - inner_comm = RingComm(process_group, False, inner_ring) - inner_rank = inner_ring.index(inner_comm.rank) - num_inner_steps = len(inner_ring) - - next_k, next_v = None, None - - for inner_step in range(num_inner_steps): - if inner_step + 1 != num_inner_steps: - next_k, next_v = 
inner_comm.send_recv_kv(k, v) - - block_causal = (outer_step == 0) and (inner_step == 0) - offset = outer_offset * num_inner_steps + (inner_rank - inner_step) % num_inner_steps - - out, lse = block_bar_attn_fwd( - q, k, v, out, lse, softmax_scale, - bar_idx, bar_cnt, block_idx[offset], block_cnt[offset], - granularity=granularity, - step=offset, - causal=block_causal, - ) - - if inner_step + 1 != num_inner_steps: - inner_comm.wait() - k, v = next_k, next_v - - return out, lse - - -def minfer_dr_stripe_triton_forward_outer( - process_group: dist.ProcessGroup, - outer_ring: List[int], - inner_ring: List[int], - q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] - k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] - v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] - softmax_scale: float, - block_idx: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks, num_blocks] - block_cnt: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks] - bar_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, max_v_size] - bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] - granularity: int = 128, -): - outer_comm = RingComm(process_group, False, outer_ring) - outer_rank = outer_ring.index(outer_comm.rank) - num_outer_steps = len(outer_ring) - - out = None - lse = None - - next_k, next_v = None, None - for outer_step in range(num_outer_steps): - if outer_step + 1 != num_outer_steps: - next_k, next_v = outer_comm.send_recv_kv(k, v) - - outer_offset = (outer_rank - outer_step) % num_outer_steps - out, lse = minfer_dr_stripe_triton_forward_inner( - process_group, outer_step, outer_offset, inner_ring, - q, k, v, out, lse, softmax_scale, - block_idx, block_cnt, bar_idx, bar_cnt, - granularity, - ) - - if outer_step + 1 != num_outer_steps: - outer_comm.wait() - k, v = next_k, next_v - - # out = out.to(q.dtype) - # lse = lse.squeeze(dim=-1).transpose(1, 2) - return out, lse - - -def minfer_dr_stripe_triton_backward_inner( - process_group: dist.ProcessGroup, - outer_step: int, - outer_offset: int, - inner_ring: List[int], - dout: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] - q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] - k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] - v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] - out: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] - softmax_lse: torch.Tensor, # [batch_size, num_qo_heads, num_tokens] - softmax_scale: float, - block_idx: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks, num_blocks] - block_cnt: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks] - bar_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, max_v_size] - bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] - granularity: int = 128, -): - inner_kv_comm = RingComm(process_group, False, inner_ring) - inner_d_kv_comm = RingComm(process_group, False, inner_ring) - inner_rank = inner_ring.index(inner_kv_comm.rank) - num_inner_steps = len(inner_ring) - - dq, dk, dv = None, None, None - next_dk, next_dv = None, None - next_k, next_v = None, None - dk_comm_buffer, dv_comm_buffer = None, None - - for inner_step in range(num_inner_steps): - if inner_step + 1 != num_inner_steps: - next_k, next_v = inner_kv_comm.send_recv_kv(k, v) - - block_causal = (outer_step == 0) and (inner_step == 0) - offset = outer_offset * num_inner_steps + 
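# ----------------------------------------------------------------------------
# Illustrative aside (not from the patch): the expression
#     offset = outer_offset * num_inner_steps + (inner_rank - inner_step) % num_inner_steps
# used by these loops reconstructs, in the flattened double-ring ordering, which
# position's KV shard is currently held, and that index selects the matching slice
# of block_idx/block_cnt. Worked example with 2 outer steps x 4 inner steps
# (values made up):
offset_example = 1 * 4 + (2 - 3) % 4   # outer_offset=1, inner_rank=2, inner_step=3 -> 7
# i.e. this step consumes the shard at flattened position 7 and uses block_idx[7] / block_cnt[7].
# ----------------------------------------------------------------------------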
(inner_rank - inner_step) % num_inner_steps - - dq, step_dk, step_dv = block_bar_attn_bwd( - dout, q, k, v, out, dq, None, None, - softmax_lse, softmax_scale, - bar_idx, bar_cnt, block_idx[offset], block_cnt[offset], - granularity=granularity, - deterministic=False, - step=offset, - causal=block_causal, - ) - - # Update dQ, dK, dV - if inner_step == 0: - # TODO: check if float32 is necessary - dk = step_dk.to(torch.float32) - dv = step_dv.to(torch.float32) - else: - inner_d_kv_comm.wait() - dk_comm_buffer, dv_comm_buffer = dk, dv - dk, dv = next_dk, next_dv - dk += step_dk - dv += step_dv - - if inner_step + 1 != num_inner_steps: - inner_kv_comm.wait() - k, v = next_k, next_v - - next_dk, next_dv = inner_d_kv_comm.send_recv_kv( - dk, dv, dk_comm_buffer, dv_comm_buffer - ) - - inner_d_kv_comm.wait() - return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) - - -def minfer_dr_stripe_triton_backward_outer( - process_group: dist.ProcessGroup, - outer_ring: List[int], - inner_ring: List[int], - dout: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] - q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] - k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] - v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] - out: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] - softmax_lse: torch.Tensor, # [batch_size, num_qo_heads, num_tokens] - softmax_scale: float, - block_idx: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks, num_blocks] - block_cnt: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks] - bar_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, max_v_size] - bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] - granularity: int = 128, -): - outer_kv_comm = RingComm(process_group, False, outer_ring) - outer_d_kv_comm = RingComm(process_group, False, outer_ring) - outer_rank = outer_ring.index(outer_kv_comm.rank) - num_outer_steps = len(outer_ring) - - dq, dk, dv = None, None, None - next_dk, next_dv = None, None - next_k, next_v = None, None - dk_comm_buffer, dv_comm_buffer = None, None - - for outer_step in range(num_outer_steps): - if outer_step + 1 != num_outer_steps: - next_k, next_v = outer_kv_comm.send_recv_kv(k, v) - - outer_offset = (outer_rank - outer_step) % num_outer_steps - step_dq, step_dk, step_dv = minfer_dr_stripe_triton_backward_inner( - process_group, outer_step, outer_offset, inner_ring, - dout, q, k, v, out, softmax_lse, softmax_scale, - block_idx, block_cnt, bar_idx, bar_cnt, granularity, - ) - - if outer_step == 0: - # TODO: check if float32 is necessary - dq = step_dq.to(torch.float32) - dk = step_dk.to(torch.float32) - dv = step_dv.to(torch.float32) - else: - dq += step_dq - outer_d_kv_comm.wait() - dk_comm_buffer, dv_comm_buffer = dk, dv - dk, dv = next_dk, next_dv - dk += step_dk - dv += step_dv - - if outer_step + 1 != num_outer_steps: - outer_kv_comm.wait() - k, v = next_k, next_v - - next_dk, next_dv = outer_d_kv_comm.send_recv_kv( - dk, dv, dk_comm_buffer, dv_comm_buffer - ) - - outer_d_kv_comm.wait() - return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) - - -class MInferDRStripeTritonFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - v_size, - s_size, - layer_idx, - softmax_scale, - granularity, - return_softmax, - group, - ): - batch_size, num_tokens_local, num_qo_heads, head_dim = q.shape - if softmax_scale is None: - softmax_scale = head_dim ** (-0.5) - - # 
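# ----------------------------------------------------------------------------
# Illustrative aside (not from the patch): why dK/dV circulate in the backward loops
# above. A rank's K/V shard receives a gradient contribution from the queries of
# every ring step, so the running (dk, dv) is shipped along the ring, the local
# partial is added, and the accumulator is kept in float32 (see the TODO) until the
# final cast back to q.dtype. Shape-only illustration with made-up sizes:
import torch
dk_acc = torch.zeros(1, 8, 2, 64, dtype=torch.float32)           # running accumulator
step_dk = torch.randn(1, 8, 2, 64).to(torch.bfloat16)            # this step's partial
dk_acc += step_dk                                                # accumulate in fp32
dk_out = dk_acc.to(torch.bfloat16)                               # final cast, as in *_backward_outer
# ----------------------------------------------------------------------------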
build index TODO: move convert_indices() into the first step - block_mask, bar_idx, bar_cnt = build_index(q, k, v_size, s_size, num_tokens_local, granularity=granularity, group=group) - block_idx, block_cnt = convert_blockmask(block_mask, block_size_M=granularity, block_size_N=64) - - # TODO: remove shuffle - q = shuffle_striped_input(to_send=q, dim=1, granularity=granularity, process_group=group) - k = shuffle_striped_input(to_send=k, dim=1, granularity=granularity, process_group=group) - v = shuffle_striped_input(to_send=v, dim=1, granularity=granularity, process_group=group) - - inner_ring = get_inner_ring(group) - outer_ring = get_outer_ring(group) - out, softmax_lse = minfer_dr_stripe_triton_forward_outer( - group, outer_ring, inner_ring, - q, k, v, softmax_scale, - block_idx, block_cnt, bar_idx, bar_cnt, granularity, - ) - - # this should be out_padded - ctx.save_for_backward(q, k, v, out, softmax_lse, block_idx, block_cnt, bar_idx, bar_cnt) - ctx.softmax_scale = softmax_scale - ctx.granularity = granularity - ctx.group = group - ctx.inner_ring = inner_ring - ctx.outer_ring = outer_ring - ctx.layer_idx = layer_idx - - out = recover_striped_output(out, dim=1, granularity=granularity, process_group=group) - if return_softmax: - softmax_lse = recover_striped_output(softmax_lse, dim=2, granularity=granularity, process_group=group) - return (out, softmax_lse, None) - return out - - @staticmethod - def backward(ctx, dout, *args): - dout = shuffle_striped_input(to_send=dout, dim=1, granularity=ctx.granularity, process_group=ctx.group) - q, k, v, out, softmax_lse, block_idx, block_cnt, bar_idx, bar_cnt = ctx.saved_tensors - - dq, dk, dv = minfer_dr_stripe_triton_backward_outer( - ctx.group, ctx.outer_ring, ctx.inner_ring, - dout, q, k, v, out, softmax_lse, ctx.softmax_scale, - block_idx, block_cnt, bar_idx, bar_cnt, ctx.granularity, - ) - dq = recover_striped_output(dq, dim=1, granularity=ctx.granularity, process_group=ctx.group) - dk = recover_striped_output(dk, dim=1, granularity=ctx.granularity, process_group=ctx.group) - dv = recover_striped_output(dv, dim=1, granularity=ctx.granularity, process_group=ctx.group) - return dq, dk, dv, None, None, None, None, None, None, None - - -def minfer_dr_stripe_triton_qkvpacked_func( - qkv: torch.Tensor, # [batch_size, num_tokens, 3, num_heads, head_dim] - v_size: List[int], # [num_heads] - s_size: List[int], # [num_heads] - layer_idx, - dropout_p: float = 0.0, - softmax_scale: float = None, - granularity: int = 128, - causal: bool = False, - window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window - alibi_slopes: Tuple[int, int] = None, - deterministic: bool = False, - return_attn_probs: bool = False, - group: dist.ProcessGroup = None, -): - assert causal - assert dropout_p == 0 - assert window_size == (-1, -1) - assert alibi_slopes is None - assert not deterministic - return MInferDRStripeTritonFunc.apply( - qkv[:, :, 0], - qkv[:, :, 1], - qkv[:, :, 2], - v_size, - s_size, - layer_idx, - softmax_scale, - granularity, - return_attn_probs, - group, - ) - - -def minfer_dr_stripe_triton_kvpacked_func( - q: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] - kv: torch.Tensor, # [batch_size, num_tokens, 2, num_heads, head_dim] - v_size: List[int], # [num_heads] - s_size: List[int], # [num_heads] - layer_idx, - dropout_p: float = 0.0, - softmax_scale: float = None, - granularity: int = 128, - causal: bool = False, - window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window - alibi_slopes: Tuple[int, int] = 
None, - deterministic: bool = False, - return_attn_probs: bool = False, - group: dist.ProcessGroup = None, -): - assert causal - assert dropout_p == 0 - assert window_size == (-1, -1) - assert alibi_slopes is None - assert not deterministic - return MInferDRStripeTritonFunc.apply( - q, - kv[:, :, 0], - kv[:, :, 1], - v_size, - s_size, - layer_idx, - softmax_scale, - granularity, - return_attn_probs, - group, - ) - - -def minfer_dr_stripe_triton_func( - q: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] - k: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] - v: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] - v_size: List[int], # [num_heads] - s_size: List[int], # [num_heads] - layer_idx, - dropout_p: float = 0.0, - softmax_scale: float = None, - granularity: int = 128, - causal: bool = False, - window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window - alibi_slopes: Tuple[int, int] = None, - deterministic: bool = False, - return_attn_probs: bool = False, - group: dist.ProcessGroup = None, -): - assert causal - assert dropout_p == 0 - assert window_size == (-1, -1) - assert alibi_slopes is None - assert not deterministic - return MInferDRStripeTritonFunc.apply( - q, - k, - v, - v_size, - s_size, - layer_idx, - softmax_scale, - granularity, - return_attn_probs, - group, - ) diff --git a/minference/dist_ops/minfer_dr_striped.py b/minference/dist_ops/minfer_dr_striped.py index ba8e59e..d5fb346 100644 --- a/minference/dist_ops/minfer_dr_striped.py +++ b/minference/dist_ops/minfer_dr_striped.py @@ -8,12 +8,12 @@ shuffle_striped_input, recover_striped_output, get_inner_ring, get_outer_ring ) - -from minference.ops.pit_sparse_flash_attention_v3_triton import block_bar_attn_fwd -from minference.ops.pit_sparse_flash_attention_v3 import block_attn_bwd, bar_attn_bwd +from minference.ops.utils import use_triton from minference.ops.op_utils.vertical_slash_utils import build_index, extract_kv, merge_kv, convert_blockmask +from minference.ops.pit_sparse_flash_attention_v3 import block_bar_attn_fwd, block_attn_bwd, bar_attn_bwd, block_bar_attn_bwd - +# ------------------------------------------------------------------------ +# CUDA-based Implementation (Block-Sparse-Attention version) def minfer_dr_stripe_forward_inner( process_group: dist.ProcessGroup, outer_step: int, @@ -271,6 +271,230 @@ def minfer_dr_stripe_backward_outer( return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) + +# --------------------------------------------------------------------- +# Purely Triton-based Implementation +def minfer_dr_stripe_triton_forward_inner( + process_group: dist.ProcessGroup, + outer_step: int, + outer_offset: int, + inner_ring: List[int], + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + out: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + lse: torch.Tensor, # [batch_size, num_qo_heads, num_tokens] + softmax_scale: float, + block_idx: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks, num_blocks] + block_cnt: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks] + bar_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, max_v_size] + bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] + granularity: int = 128, +): + inner_comm = RingComm(process_group, False, inner_ring) + inner_rank = 
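# ----------------------------------------------------------------------------
# Illustrative sketch (not from the patch): the nesting of the forward_outer /
# forward_inner pair kept in this file. KV hops around the outer ring (e.g. across
# nodes) and, for each outer position, a full sweep of the inner ring (e.g. within a
# node) runs before the next outer hop; the causal step is the one where both rings
# sit at their own shard (outer_step == 0 and inner_step == 0).
def double_ring_sketch(num_outer_steps, num_inner_steps, run_inner_sweep):
    out, lse = None, None
    for outer_step in range(num_outer_steps):
        # (post the async outer KV exchange here, as in the real code)
        out, lse = run_inner_sweep(outer_step)  # inner-ring sweep with the current outer KV
        # (wait for the exchange and swap the outer KV here)
    return out, lse
# ----------------------------------------------------------------------------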
inner_ring.index(inner_comm.rank) + num_inner_steps = len(inner_ring) + + next_k, next_v = None, None + + for inner_step in range(num_inner_steps): + if inner_step + 1 != num_inner_steps: + next_k, next_v = inner_comm.send_recv_kv(k, v) + + block_causal = (outer_step == 0) and (inner_step == 0) + offset = outer_offset * num_inner_steps + (inner_rank - inner_step) % num_inner_steps + + out, lse = block_bar_attn_fwd( + q, k, v, out, lse, softmax_scale, + bar_idx, bar_cnt, block_idx[offset], block_cnt[offset], + granularity=granularity, + step=offset, + causal=block_causal, + ) + + if inner_step + 1 != num_inner_steps: + inner_comm.wait() + k, v = next_k, next_v + + return out, lse + + +def minfer_dr_stripe_triton_forward_outer( + process_group: dist.ProcessGroup, + outer_ring: List[int], + inner_ring: List[int], + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + softmax_scale: float, + block_idx: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks, num_blocks] + block_cnt: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks] + bar_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, max_v_size] + bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] + granularity: int = 128, +): + outer_comm = RingComm(process_group, False, outer_ring) + outer_rank = outer_ring.index(outer_comm.rank) + num_outer_steps = len(outer_ring) + + out = None + lse = None + + next_k, next_v = None, None + for outer_step in range(num_outer_steps): + if outer_step + 1 != num_outer_steps: + next_k, next_v = outer_comm.send_recv_kv(k, v) + + outer_offset = (outer_rank - outer_step) % num_outer_steps + out, lse = minfer_dr_stripe_triton_forward_inner( + process_group, outer_step, outer_offset, inner_ring, + q, k, v, out, lse, softmax_scale, + block_idx, block_cnt, bar_idx, bar_cnt, + granularity, + ) + + if outer_step + 1 != num_outer_steps: + outer_comm.wait() + k, v = next_k, next_v + + # out = out.to(q.dtype) + # lse = lse.squeeze(dim=-1).transpose(1, 2) + return out, lse + + +def minfer_dr_stripe_triton_backward_inner( + process_group: dist.ProcessGroup, + outer_step: int, + outer_offset: int, + inner_ring: List[int], + dout: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + out: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + softmax_lse: torch.Tensor, # [batch_size, num_qo_heads, num_tokens] + softmax_scale: float, + block_idx: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks, num_blocks] + block_cnt: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks] + bar_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, max_v_size] + bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] + granularity: int = 128, +): + inner_kv_comm = RingComm(process_group, False, inner_ring) + inner_d_kv_comm = RingComm(process_group, False, inner_ring) + inner_rank = inner_ring.index(inner_kv_comm.rank) + num_inner_steps = len(inner_ring) + + dq, dk, dv = None, None, None + next_dk, next_dv = None, None + next_k, next_v = None, None + dk_comm_buffer, dv_comm_buffer = None, None + + for inner_step in 
range(num_inner_steps): + if inner_step + 1 != num_inner_steps: + next_k, next_v = inner_kv_comm.send_recv_kv(k, v) + + block_causal = (outer_step == 0) and (inner_step == 0) + offset = outer_offset * num_inner_steps + (inner_rank - inner_step) % num_inner_steps + + dq, step_dk, step_dv = block_bar_attn_bwd( + dout, q, k, v, out, dq, None, None, + softmax_lse, softmax_scale, + bar_idx, bar_cnt, block_idx[offset], block_cnt[offset], + granularity=granularity, + deterministic=False, + step=offset, + causal=block_causal, + ) + + # Update dQ, dK, dV + if inner_step == 0: + # TODO: check if float32 is necessary + dk = step_dk.to(torch.float32) + dv = step_dv.to(torch.float32) + else: + inner_d_kv_comm.wait() + dk_comm_buffer, dv_comm_buffer = dk, dv + dk, dv = next_dk, next_dv + dk += step_dk + dv += step_dv + + if inner_step + 1 != num_inner_steps: + inner_kv_comm.wait() + k, v = next_k, next_v + + next_dk, next_dv = inner_d_kv_comm.send_recv_kv( + dk, dv, dk_comm_buffer, dv_comm_buffer + ) + + inner_d_kv_comm.wait() + return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) + + +def minfer_dr_stripe_triton_backward_outer( + process_group: dist.ProcessGroup, + outer_ring: List[int], + inner_ring: List[int], + dout: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + out: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + softmax_lse: torch.Tensor, # [batch_size, num_qo_heads, num_tokens] + softmax_scale: float, + block_idx: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks, num_blocks] + block_cnt: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks] + bar_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, max_v_size] + bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] + granularity: int = 128, +): + outer_kv_comm = RingComm(process_group, False, outer_ring) + outer_d_kv_comm = RingComm(process_group, False, outer_ring) + outer_rank = outer_ring.index(outer_kv_comm.rank) + num_outer_steps = len(outer_ring) + + dq, dk, dv = None, None, None + next_dk, next_dv = None, None + next_k, next_v = None, None + dk_comm_buffer, dv_comm_buffer = None, None + + for outer_step in range(num_outer_steps): + if outer_step + 1 != num_outer_steps: + next_k, next_v = outer_kv_comm.send_recv_kv(k, v) + + outer_offset = (outer_rank - outer_step) % num_outer_steps + step_dq, step_dk, step_dv = minfer_dr_stripe_triton_backward_inner( + process_group, outer_step, outer_offset, inner_ring, + dout, q, k, v, out, softmax_lse, softmax_scale, + block_idx, block_cnt, bar_idx, bar_cnt, granularity, + ) + + if outer_step == 0: + # TODO: check if float32 is necessary + dq = step_dq.to(torch.float32) + dk = step_dk.to(torch.float32) + dv = step_dv.to(torch.float32) + else: + dq += step_dq + outer_d_kv_comm.wait() + dk_comm_buffer, dv_comm_buffer = dk, dv + dk, dv = next_dk, next_dv + dk += step_dk + dv += step_dv + + if outer_step + 1 != num_outer_steps: + outer_kv_comm.wait() + k, v = next_k, next_v + + next_dk, next_dv = outer_d_kv_comm.send_recv_kv( + dk, dv, dk_comm_buffer, dv_comm_buffer + ) + + outer_d_kv_comm.wait() + return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) + +# ----------------------------------------------------------- +# Attention Classes class 
MInferDRStripeFunc(torch.autograd.Function): @staticmethod def forward( @@ -365,76 +589,74 @@ def backward(ctx, dout, *args): return dq, dk, dv, None, None, None, None, None, None, None - -def minfer_dr_stripe_qkvpacked_func( - qkv: torch.Tensor, # [batch_size, num_tokens, 3, num_heads, head_dim] - v_size: List[int], # [num_heads] - s_size: List[int], # [num_heads] - layer_idx: int = 0, - dropout_p: float = 0.0, - softmax_scale: float = None, - granularity: int = 128, - causal: bool = True, - window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window - alibi_slopes: Tuple[int, int] = None, - deterministic: bool = False, - return_attn_probs: bool = False, - group: dist.ProcessGroup = None, -): - assert causal - assert dropout_p == 0 - assert window_size == (-1, -1) - assert alibi_slopes is None - assert not deterministic - return MInferDRStripeFunc.apply( - qkv[:, :, 0], - qkv[:, :, 1], - qkv[:, :, 2], +class MInferDRStripeTritonFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, v_size, s_size, layer_idx, softmax_scale, granularity, - return_attn_probs, + return_softmax, group, - ) + ): + batch_size, num_tokens_local, num_qo_heads, head_dim = q.shape + if softmax_scale is None: + softmax_scale = head_dim ** (-0.5) + # build index TODO: move convert_indices() into the first step + block_mask, bar_idx, bar_cnt, _, _, _ = build_index(q, k, v_size, s_size, num_tokens_local, granularity=granularity, group=group) + block_idx, block_cnt = convert_blockmask(block_mask, block_size_M=granularity, block_size_N=64) -def minfer_dr_stripe_kvpacked_func( - q: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] - kv: torch.Tensor, # [batch_size, num_tokens, 2, num_heads, head_dim] - v_size: List[int], # [num_heads] - s_size: List[int], # [num_heads] - layer_idx: int = 0, - dropout_p: float = 0.0, - softmax_scale: float = None, - granularity: int = 128, - causal: bool = True, - window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window - alibi_slopes: Tuple[int, int] = None, - deterministic: bool = False, - return_attn_probs: bool = False, - group: dist.ProcessGroup = None, -): - assert causal - assert dropout_p == 0 - assert window_size == (-1, -1) - assert alibi_slopes is None - assert not deterministic - return MInferDRStripeFunc.apply( - q, - kv[:, :, 0], - kv[:, :, 1], - v_size, - s_size, - layer_idx, - softmax_scale, - granularity, - return_attn_probs, - group, - ) + # TODO: remove shuffle + q = shuffle_striped_input(to_send=q, dim=1, granularity=granularity, process_group=group) + k = shuffle_striped_input(to_send=k, dim=1, granularity=granularity, process_group=group) + v = shuffle_striped_input(to_send=v, dim=1, granularity=granularity, process_group=group) + + inner_ring = get_inner_ring(group) + outer_ring = get_outer_ring(group) + out, softmax_lse = minfer_dr_stripe_triton_forward_outer( + group, outer_ring, inner_ring, + q, k, v, softmax_scale, + block_idx, block_cnt, bar_idx, bar_cnt, granularity, + ) + + # this should be out_padded + ctx.save_for_backward(q, k, v, out, softmax_lse, block_idx, block_cnt, bar_idx, bar_cnt) + ctx.softmax_scale = softmax_scale + ctx.granularity = granularity + ctx.group = group + ctx.inner_ring = inner_ring + ctx.outer_ring = outer_ring + ctx.layer_idx = layer_idx + + out = recover_striped_output(out, dim=1, granularity=granularity, process_group=group) + if return_softmax: + softmax_lse = recover_striped_output(softmax_lse, dim=2, granularity=granularity, process_group=group) + return 
(out, softmax_lse, None) + return out + + @staticmethod + def backward(ctx, dout, *args): + dout = shuffle_striped_input(to_send=dout, dim=1, granularity=ctx.granularity, process_group=ctx.group) + q, k, v, out, softmax_lse, block_idx, block_cnt, bar_idx, bar_cnt = ctx.saved_tensors + dq, dk, dv = minfer_dr_stripe_triton_backward_outer( + ctx.group, ctx.outer_ring, ctx.inner_ring, + dout, q, k, v, out, softmax_lse, ctx.softmax_scale, + block_idx, block_cnt, bar_idx, bar_cnt, ctx.granularity, + ) + dq = recover_striped_output(dq, dim=1, granularity=ctx.granularity, process_group=ctx.group) + dk = recover_striped_output(dk, dim=1, granularity=ctx.granularity, process_group=ctx.group) + dv = recover_striped_output(dv, dim=1, granularity=ctx.granularity, process_group=ctx.group) + return dq, dk, dv, None, None, None, None, None, None, None +# --------------------------------------------------------------------- +# Wrapped Attention Functions def minfer_dr_stripe_func( q: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] k: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] @@ -457,15 +679,49 @@ def minfer_dr_stripe_func( assert window_size == (-1, -1) assert alibi_slopes is None assert not deterministic - return MInferDRStripeFunc.apply( - q, - k, - v, - v_size, - s_size, - layer_idx, - softmax_scale, - granularity, - return_attn_probs, - group, + + if not use_triton(): + return MInferDRStripeFunc.apply( + q, k, v, + v_size, s_size, + layer_idx, + softmax_scale, + granularity, + return_attn_probs, + group, + ) + else: + print(f"Rank {dist.get_rank()} | minfer_dr_stripe_func | Using Triton implementation for MTraining") + return MInferDRStripeTritonFunc.apply( + q, k, v, + v_size, s_size, + layer_idx, + softmax_scale, + granularity, + return_attn_probs, + group, + ) + + +def minfer_dr_stripe_qkvpacked_func( + qkv: torch.Tensor, # [batch_size, num_tokens, 3, num_heads, head_dim] + *args, **kwargs, +): + return minfer_dr_stripe_func( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + *args, **kwargs ) + +def minfer_dr_stripe_kvpacked_func( + q: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] + kv: torch.Tensor, # [batch_size, num_tokens, 2, num_heads, head_dim] + *args, **kwargs, +): + return minfer_dr_stripe_func( + q, + kv[:, :, 0], + kv[:, :, 1], + *args, **kwargs + ) \ No newline at end of file diff --git a/minference/dist_ops/minfer_striped.py b/minference/dist_ops/minfer_striped.py index 91bb272..9225d24 100644 --- a/minference/dist_ops/minfer_striped.py +++ b/minference/dist_ops/minfer_striped.py @@ -5,14 +5,17 @@ import torch.distributed as dist from typing import List, Tuple, Dict -from .utils import ( - RingComm, - shuffle_striped_input, recover_striped_output, +from minference.ops.utils import use_triton +from minference.dist_ops.utils import ( + RingComm, shuffle_striped_input, recover_striped_output, +) +from minference.ops.pit_sparse_flash_attention_v3 import ( + block_bar_attn_fwd, block_attn_bwd, bar_attn_bwd, block_bar_attn_bwd ) -from minference.ops.pit_sparse_flash_attention_v3_triton import block_bar_attn_fwd -from minference.ops.pit_sparse_flash_attention_v3 import block_attn_bwd, bar_attn_bwd from minference.ops.op_utils.vertical_slash_utils import build_index, convert_blockmask, extract_kv, merge_kv + + if torch.version.hip is None: original_flags = sys.getdlopenflags() try: @@ -27,7 +30,8 @@ sys.setdlopenflags(original_flags) # NOTE: Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_blockmask.h: add head_idx to blockmask_ptr - +# 
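# ----------------------------------------------------------------------------
# Illustrative aside (not from the patch): after this merge the public entry points
# (minfer_stripe_func, minfer_dr_stripe_func) keep a single name and pick the
# backend at call time via minference.ops.utils.use_triton(). The exact policy
# lives in that helper; the tests in this commit drive it through the FORCE_TRITON
# environment variable, so a caller-side sketch might look like:
import os
os.environ["FORCE_TRITON"] = "1"   # request the Triton autograd classes (test-style toggle)
# out = minfer_dr_stripe_func(q, k, v, v_size, s_size, layer_idx, causal=True, group=group)
os.environ["FORCE_TRITON"] = "0"
# ----------------------------------------------------------------------------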
------------------------------------------------------------------ +# CUDA-based Implementation def minfer_stripe_forward( process_group: dist.ProcessGroup, q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] @@ -155,7 +159,112 @@ def minfer_stripe_backward( d_kv_comm.wait() return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) +# ------------------------------------------------------------------ +# Triton-based Implementation +def minfer_stripe_triton_forward( + process_group: dist.ProcessGroup, + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + layer_idx: int, + softmax_scale: float, + block_idx: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks, num_blocks] + block_cnt: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks] + bar_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, max_v_size] + bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] + granularity: int = 128, +): + comm = RingComm(process_group) + out, lse = None, None + next_k, next_v = None, None + + for step in range(comm.world_size): + if step + 1 != comm.world_size: + next_k, next_v = comm.send_recv_kv(k, v) + block_causal = step == 0 + offset = (comm.rank - step) % comm.world_size + + + out, lse = block_bar_attn_fwd( + q, k, v, out, lse, softmax_scale, + bar_idx, bar_cnt, block_idx[offset], block_cnt[offset], + granularity=granularity, + step=offset, + causal=block_causal, + ) + + if step + 1 != comm.world_size: + comm.wait() + k, v = next_k, next_v + + return out, lse + + +def minfer_stripe_triton_backward( + process_group: dist.ProcessGroup, + dout: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + out: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + softmax_lse: torch.Tensor, # [batch_size, num_qo_heads, num_tokens] + layer_idx: int, + softmax_scale: float, + block_idx: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks, num_blocks] + block_cnt: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks] + bar_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, max_v_size] + bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] + granularity: int = 128, +): + kv_comm = RingComm(process_group) + d_kv_comm = RingComm(process_group) + dq, dk, dv = None, None, None + next_dk, next_dv = None, None + next_k, next_v = None, None + dk_comm_buffer, dv_comm_buffer = None, None + + for step in range(kv_comm.world_size): + if step + 1 != kv_comm.world_size: + next_k, next_v = kv_comm.send_recv_kv(k, v) + block_causal = step == 0 + offset = (kv_comm.rank - step) % kv_comm.world_size + + dq, step_dk, step_dv = block_bar_attn_bwd( + dout, q, k, v, out, dq, None, None, + softmax_lse, softmax_scale, + bar_idx, bar_cnt, block_idx[offset], block_cnt[offset], + granularity=granularity, + deterministic=False, + step=offset, + causal=block_causal, + ) + + # Update dQ, dK, dV + if step == 0: + dk = step_dk + dv = step_dv + else: + d_kv_comm.wait() + + dk_comm_buffer, dv_comm_buffer = dk, dv + dk, dv = next_dk, next_dv + dk += step_dk + dv += step_dv + + + if step + 1 != kv_comm.world_size: + 
kv_comm.wait() + k, v = next_k, next_v + next_dk, next_dv = d_kv_comm.send_recv_kv( + dk, dv, dk_comm_buffer, dv_comm_buffer + ) + + d_kv_comm.wait() + return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) + +# ------------------------------------------------------------------ +# Attention Classes class MInferStripeFunc(torch.autograd.Function): @staticmethod def forward( @@ -234,75 +343,79 @@ def backward(ctx, dout, *args): dk = recover_striped_output(dk, dim=1, granularity=granularity, process_group=group) dv = recover_striped_output(dv, dim=1, granularity=granularity, process_group=group) return dq, dk, dv, None, None, None, None, None, None, None - -def minfer_stripe_qkvpacked_func( - qkv: torch.Tensor, # [batch_size, num_tokens, 3, num_heads, head_dim] - v_size: List[int], # [num_heads] - s_size: List[int], # [num_heads] - layer_idx: int = 0, - dropout_p: float = 0.0, - softmax_scale: float = None, - granularity: int = 128, - causal: bool = True, - window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window - alibi_slopes: Tuple[int, int] = None, - deterministic: bool = False, - return_attn_probs: bool = False, - group: dist.ProcessGroup = None, -): - assert causal - assert dropout_p == 0 - assert window_size == (-1, -1) - assert alibi_slopes is None - assert not deterministic - return MInferStripeFunc.apply( - qkv[:, :, 0], - qkv[:, :, 1], - qkv[:, :, 2], + +class MInferStripeTritonFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, v_size, s_size, layer_idx, softmax_scale, granularity, - return_attn_probs, + return_softmax, group, - ) + ): + batch_size, num_tokens_local, num_qo_heads, head_dim = q.shape + if softmax_scale is None: + softmax_scale = head_dim ** (-0.5) + # built block_idx: [world_size, batch_size, num_qo_heads, num_blocks_local, num_blocks_local] + block_mask, bar_idx, bar_cnt, _, _, _ = build_index(q, k, v_size, s_size, num_tokens_local, granularity=granularity, group=group) + block_idx, block_cnt = convert_blockmask(block_mask, block_size_M=granularity, block_size_N=64) -def minfer_stripe_kvpacked_func( - q: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] - kv: torch.Tensor, # [batch_size, num_tokens, 2, num_heads, head_dim] - v_size: List[int], # [num_heads] - s_size: List[int], # [num_heads] - layer_idx: int = 0, - dropout_p: float = 0.0, - softmax_scale: float = None, - granularity: int = 128, - causal: bool = True, - window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window - alibi_slopes: Tuple[int, int] = None, - deterministic: bool = False, - return_attn_probs: bool = False, - group: dist.ProcessGroup = None, -): - assert causal - assert dropout_p == 0 - assert window_size == (-1, -1) - assert alibi_slopes is None - assert not deterministic - return MInferStripeFunc.apply( - q, - kv[:, :, 0], - kv[:, :, 1], - v_size, - s_size, - layer_idx, - softmax_scale, - granularity, - return_attn_probs, - group, - ) + q = shuffle_striped_input(to_send=q, dim=1, granularity=granularity, process_group=group) + k = shuffle_striped_input(to_send=k, dim=1, granularity=granularity, process_group=group) + v = shuffle_striped_input(to_send=v, dim=1, granularity=granularity, process_group=group) + + # slash attn + out, softmax_lse = minfer_stripe_triton_forward( + group, q, k, v, + layer_idx, softmax_scale, + block_idx, block_cnt, bar_idx, bar_cnt, + granularity=granularity, + ) + + # this should be out_padded + ctx.save_for_backward(q, k, v, out, softmax_lse, block_idx, block_cnt, bar_idx, 
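# ----------------------------------------------------------------------------
# Illustrative aside (not from the patch): the indexing pipeline the Triton autograd
# classes run once per forward, read off the shape annotations in this file.
#   build_index(q, k, v_size, s_size, ...) -> block_mask, bar_idx, bar_cnt, ...
#       block_mask : [world_size, batch, q_heads, num_blocks, num_blocks] boolean block pattern
#       bar_idx    : [batch, q_heads, num_blocks, max_v_size] vertical ("bar") column indices
#       bar_cnt    : [batch, q_heads, num_blocks, world_size + 1] per-rank prefix counts
#   convert_blockmask(block_mask, block_size_M=granularity, block_size_N=64)
#       -> block_idx, block_cnt: the mask compressed into per-row column-block lists for the kernel
#   shuffle_striped_input(...) then reorders q/k/v into the striped layout before the ring loop.
# For example, with world_size=4, batch=1, 8 query heads, 4096 local tokens and
# granularity=128 (32 row blocks), the annotated shapes become:
#   block_mask: [4, 1, 8, 32, 32]   bar_idx: [1, 8, 32, max_v_size]   bar_cnt: [1, 8, 32, 5]
# (Treat these as descriptive examples, not normative sizes.)
# ----------------------------------------------------------------------------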
bar_cnt) + ctx.softmax_scale = softmax_scale + ctx.granularity = granularity + ctx.group = group + ctx.layer_idx = layer_idx + out = recover_striped_output(out, dim=1, granularity=granularity, process_group=group) + if return_softmax: + softmax_lse = recover_striped_output(softmax_lse, dim=2, granularity=granularity, process_group=group) + return (out, softmax_lse, None) + return out + + @staticmethod + def backward(ctx, dout, *args): + layer_idx = ctx.layer_idx + dout = shuffle_striped_input(to_send=dout, dim=1, granularity=ctx.granularity, process_group=ctx.group) + q, k, v, out, softmax_lse, block_idx, block_cnt, bar_idx, bar_cnt = ctx.saved_tensors + + dq, dk, dv = minfer_stripe_triton_backward( + ctx.group, dout, q, k, v, out, softmax_lse, + layer_idx, ctx.softmax_scale, + block_idx, block_cnt, bar_idx, bar_cnt, + granularity=ctx.granularity, + ) + + dq = recover_striped_output(dq, dim=1, granularity=ctx.granularity, process_group=ctx.group) + dk = recover_striped_output(dk, dim=1, granularity=ctx.granularity, process_group=ctx.group) + dv = recover_striped_output(dv, dim=1, granularity=ctx.granularity, process_group=ctx.group) + + return dq, dk, dv, None, None, None, None, None, None, None + + +# ------------------------------------------------------------------ +# Wrapped Attention Functions +# ------------------ +# CUDA-based def minfer_stripe_func( # the one used for nnscaler training q: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] k: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] @@ -326,15 +439,48 @@ def minfer_stripe_func( # the one used for nnscaler training assert alibi_slopes is None assert not deterministic - return MInferStripeFunc.apply( + if not use_triton(): + return MInferStripeFunc.apply( + q, k, v, + v_size, s_size, + layer_idx, + softmax_scale, + granularity, + return_attn_probs, + group, + ) + else: + print(f"Rank {dist.get_rank()} | minfer_stripe_func | using Triton implementation for MTraining w. 
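# ----------------------------------------------------------------------------
# Illustrative aside (not from the patch): the shuffle/recover symmetry the classes
# above rely on. shuffle_striped_input puts local tokens into the striped layout used
# during the ring sweep, and recover_striped_output is its inverse, so the output
# (dim=1), the LSE (dim=2) and all gradients come back in the caller's token order.
# Tiny stand-in with a hypothetical permutation (the real striping spans ranks):
import torch
perm = torch.tensor([0, 2, 1, 3])          # pretend striping of 4 token chunks
inv = torch.argsort(perm)                  # "recover" = inverse permutation
x = torch.arange(4)
assert torch.equal(x[perm][inv], x)        # recover(shuffle(x)) == x
# ----------------------------------------------------------------------------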
Striped Ring Attention") + return MInferStripeTritonFunc.apply( + q, k, v, + v_size, s_size, + layer_idx, + softmax_scale, + granularity, + return_attn_probs, + group, + ) + +def minfer_stripe_qkvpacked_func( + qkv: torch.Tensor, # [batch_size, num_tokens, 3, num_heads, head_dim] + *args, **kwargs +): + return minfer_stripe_func( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + *args, **kwargs + ) + + +def minfer_stripe_kvpacked_func( + q: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] + kv: torch.Tensor, # [batch_size, num_tokens, 2, num_heads, head_dim] + *args, **kwargs +): + return minfer_stripe_func( q, - k, - v, - v_size, - s_size, - layer_idx, - softmax_scale, - granularity, - return_attn_probs, - group, + kv[:, :, 0], + kv[:, :, 1], + *args, **kwargs ) diff --git a/minference/dist_ops/minfer_striped_triton.py b/minference/dist_ops/minfer_striped_triton.py deleted file mode 100644 index 2946c3f..0000000 --- a/minference/dist_ops/minfer_striped_triton.py +++ /dev/null @@ -1,284 +0,0 @@ -import os -import torch -import torch.distributed as dist -from typing import List, Tuple, Dict - -from .utils import ( - RingComm, - shuffle_striped_input, recover_striped_output, -) -from minference.ops.op_utils.vertical_slash_utils import build_index, convert_blockmask -from minference.ops.pit_sparse_flash_attention_v3_triton import block_bar_attn_fwd, block_bar_attn_bwd - - -def minfer_stripe_triton_forward( - process_group: dist.ProcessGroup, - q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] - k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] - v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] - layer_idx: int, - softmax_scale: float, - block_idx: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks, num_blocks] - block_cnt: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks] - bar_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, max_v_size] - bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] - granularity: int = 128, -): - comm = RingComm(process_group) - out, lse = None, None - next_k, next_v = None, None - - for step in range(comm.world_size): - if step + 1 != comm.world_size: - next_k, next_v = comm.send_recv_kv(k, v) - block_causal = step == 0 - offset = (comm.rank - step) % comm.world_size - - - out, lse = block_bar_attn_fwd( - q, k, v, out, lse, softmax_scale, - bar_idx, bar_cnt, block_idx[offset], block_cnt[offset], - granularity=granularity, - step=offset, - causal=block_causal, - ) - - if step + 1 != comm.world_size: - comm.wait() - k, v = next_k, next_v - - return out, lse - - -def minfer_stripe_triton_backward( - process_group: dist.ProcessGroup, - dout: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] - q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] - k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] - v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] - out: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] - softmax_lse: torch.Tensor, # [batch_size, num_qo_heads, num_tokens] - layer_idx: int, - softmax_scale: float, - block_idx: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks, num_blocks] - block_cnt: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks] - bar_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, max_v_size] - bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] - granularity: 
int = 128, -): - kv_comm = RingComm(process_group) - d_kv_comm = RingComm(process_group) - - dq, dk, dv = None, None, None - next_dk, next_dv = None, None - next_k, next_v = None, None - dk_comm_buffer, dv_comm_buffer = None, None - - for step in range(kv_comm.world_size): - if step + 1 != kv_comm.world_size: - next_k, next_v = kv_comm.send_recv_kv(k, v) - block_causal = step == 0 - offset = (kv_comm.rank - step) % kv_comm.world_size - - dq, step_dk, step_dv = block_bar_attn_bwd( - dout, q, k, v, out, dq, None, None, - softmax_lse, softmax_scale, - bar_idx, bar_cnt, block_idx[offset], block_cnt[offset], - granularity=granularity, - deterministic=False, - step=offset, - causal=block_causal, - ) - - # Update dQ, dK, dV - if step == 0: - dk = step_dk - dv = step_dv - else: - d_kv_comm.wait() - - dk_comm_buffer, dv_comm_buffer = dk, dv - dk, dv = next_dk, next_dv - dk += step_dk - dv += step_dv - - - if step + 1 != kv_comm.world_size: - kv_comm.wait() - k, v = next_k, next_v - next_dk, next_dv = d_kv_comm.send_recv_kv( - dk, dv, dk_comm_buffer, dv_comm_buffer - ) - - d_kv_comm.wait() - return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) - - -class MInferStripeTritonFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - v_size, - s_size, - layer_idx, - softmax_scale, - granularity, - return_softmax, - group, - ): - batch_size, num_tokens_local, num_qo_heads, head_dim = q.shape - if softmax_scale is None: - softmax_scale = head_dim ** (-0.5) - - # built block_idx: [world_size, batch_size, num_qo_heads, num_blocks_local, num_blocks_local] - block_mask, bar_idx, bar_cnt, _, _, _ = build_index(q, k, v_size, s_size, num_tokens_local, granularity=granularity, group=group) - block_idx, block_cnt = convert_blockmask(block_mask, block_size_M=granularity, block_size_N=64) - - q = shuffle_striped_input(to_send=q, dim=1, granularity=granularity, process_group=group) - k = shuffle_striped_input(to_send=k, dim=1, granularity=granularity, process_group=group) - v = shuffle_striped_input(to_send=v, dim=1, granularity=granularity, process_group=group) - - # slash attn - out, softmax_lse = minfer_stripe_triton_forward( - group, q, k, v, - layer_idx, softmax_scale, - block_idx, block_cnt, bar_idx, bar_cnt, - granularity=granularity, - ) - - # this should be out_padded - ctx.save_for_backward(q, k, v, out, softmax_lse, block_idx, block_cnt, bar_idx, bar_cnt) - ctx.softmax_scale = softmax_scale - ctx.granularity = granularity - ctx.group = group - ctx.layer_idx = layer_idx - - out = recover_striped_output(out, dim=1, granularity=granularity, process_group=group) - if return_softmax: - softmax_lse = recover_striped_output(softmax_lse, dim=2, granularity=granularity, process_group=group) - return (out, softmax_lse, None) - return out - - @staticmethod - def backward(ctx, dout, *args): - layer_idx = ctx.layer_idx - dout = shuffle_striped_input(to_send=dout, dim=1, granularity=ctx.granularity, process_group=ctx.group) - q, k, v, out, softmax_lse, block_idx, block_cnt, bar_idx, bar_cnt = ctx.saved_tensors - - dq, dk, dv = minfer_stripe_triton_backward( - ctx.group, dout, q, k, v, out, softmax_lse, - layer_idx, ctx.softmax_scale, - block_idx, block_cnt, bar_idx, bar_cnt, - granularity=ctx.granularity, - ) - - dq = recover_striped_output(dq, dim=1, granularity=ctx.granularity, process_group=ctx.group) - dk = recover_striped_output(dk, dim=1, granularity=ctx.granularity, process_group=ctx.group) - dv = recover_striped_output(dv, dim=1, granularity=ctx.granularity, 
process_group=ctx.group) - - return dq, dk, dv, None, None, None, None, None, None, None - - -def minfer_stripe_triton_qkvpacked_func( - qkv: torch.Tensor, # [batch_size, num_tokens, 3, num_heads, head_dim] - v_size: List[int], # [num_heads] - s_size: List[int], # [num_heads] - dropout_p: float = 0.0, - softmax_scale: float = None, - granularity: int = 128, - causal: bool = False, - window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window - alibi_slopes: Tuple[int, int] = None, - deterministic: bool = False, - return_attn_probs: bool = False, - group: dist.ProcessGroup = None, -): - assert causal - assert dropout_p == 0 - assert window_size == (-1, -1) - assert alibi_slopes is None - assert not deterministic - return MInferStripeTritonFunc.apply( - qkv[:, :, 0], - qkv[:, :, 1], - qkv[:, :, 2], - v_size, - s_size, - softmax_scale, - granularity, - return_attn_probs, - group, - ) - - -def minfer_stripe_triton_kvpacked_func( - q: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] - kv: torch.Tensor, # [batch_size, num_tokens, 2, num_heads, head_dim] - v_size: List[int], # [num_heads] - s_size: List[int], # [num_heads] - dropout_p: float = 0.0, - softmax_scale: float = None, - granularity: int = 128, - causal: bool = False, - window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window - alibi_slopes: Tuple[int, int] = None, - deterministic: bool = False, - return_attn_probs: bool = False, - group: dist.ProcessGroup = None, -): - assert causal - assert dropout_p == 0 - assert window_size == (-1, -1) - assert alibi_slopes is None - assert not deterministic - return MInferStripeTritonFunc.apply( - q, - kv[:, :, 0], - kv[:, :, 1], - v_size, - s_size, - softmax_scale, - granularity, - return_attn_probs, - group, - ) - - -def minfer_stripe_triton_func( - q: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] - k: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] - v: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] - v_size: List[int], # [num_heads] - s_size: List[int], # [num_heads] - layer_idx: int, - dropout_p: float = 0.0, - softmax_scale: float = None, - granularity: int = 128, - causal: bool = False, - window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window - alibi_slopes: Tuple[int, int] = None, - deterministic: bool = False, - return_attn_probs: bool = False, - group: dist.ProcessGroup = None, -) -> torch.Tensor: - assert causal - assert dropout_p == 0 - assert window_size == (-1, -1) - assert alibi_slopes is None - assert not deterministic - - return MInferStripeTritonFunc.apply( - q, - k, - v, - v_size, - s_size, - layer_idx, - softmax_scale, - granularity, - return_attn_probs, - group, - ) diff --git a/minference/dist_ops/minfer_zigzag.py b/minference/dist_ops/minfer_zigzag.py index d519cd1..e8b1d79 100644 --- a/minference/dist_ops/minfer_zigzag.py +++ b/minference/dist_ops/minfer_zigzag.py @@ -5,12 +5,10 @@ from typing import List, Tuple, Dict from .utils import ( - RingComm, - shuffle_zigzag_input, recover_zigzag_output, + RingComm, shuffle_zigzag_input, recover_zigzag_output, ) from minference.ops.op_utils.vertical_slash_utils import build_index, convert_blockmask -from minference.ops.pit_sparse_flash_attention_v3_triton import block_bar_attn_fwd -from minference.ops.pit_sparse_flash_attention_v3 import block_attn_bwd, bar_attn_bwd +from minference.ops.pit_sparse_flash_attention_v3 import block_bar_attn_fwd, block_attn_bwd, bar_attn_bwd def minfer_zigzag_forward( process_group: 
dist.ProcessGroup, diff --git a/minference/dist_ops/test/minfer_ring_test.py b/minference/dist_ops/test/minfer_ring_test.py index aa78fed..6eccbce 100644 --- a/minference/dist_ops/test/minfer_ring_test.py +++ b/minference/dist_ops/test/minfer_ring_test.py @@ -165,6 +165,7 @@ def _run_worker( @pytest.mark.parametrize("head_dim", [64, 128]) @pytest.mark.parametrize("sparsity", [0.9, 0.95]) @pytest.mark.parametrize("num_qkv_head_pair", [(4, 1), (4, 4)]) +@pytest.mark.parametrize("use_triton", [True, False]) @pytest.mark.parametrize("attn_op_name", ["minfer_zigzag", "minfer_stripe", "minfer_dr_stripe"] ) @@ -174,6 +175,7 @@ def test_sparse_attention_kernels( head_dim: int, sparsity: float, num_qkv_head_pair: tuple[int, int], + use_triton: bool, attn_op_name: str, ): """ @@ -181,6 +183,12 @@ def test_sparse_attention_kernels( both forward pass and input-gradient w.r.t Q/K/V. """ port = str(random.randint(12000, 20000)) + if attn_op_name == "minfer_zigzag" and use_triton: + pytest.skip("minfer_zigzag is not implemented with the Triton path") + + if use_triton: + os.environ['FORCE_TRITON'] = "1" + cfg = SimpleNamespace( batch_size=batch_sz, seq_len=seq_len, @@ -202,3 +210,5 @@ def test_sparse_attention_kernels( nprocs=_WORLD_SIZE, join=True, ) + + os.environ['FORCE_TRITON'] = "0" diff --git a/minference/dist_ops/test/minfer_ring_test_raw.py b/minference/dist_ops/test/minfer_ring_test_raw.py index 97f1d76..181d86d 100644 --- a/minference/dist_ops/test/minfer_ring_test_raw.py +++ b/minference/dist_ops/test/minfer_ring_test_raw.py @@ -184,6 +184,7 @@ def test_sparse_attention_kernels( num_qo_heads: int, num_kv_heads: int, attn_op_name: str, + use_triton: bool, ): """ Compare every sparse kernel against the dense Flash-Attention reference on @@ -225,4 +226,5 @@ def test_sparse_attention_kernels( num_qo_heads=4, num_kv_heads=2, attn_op_name="minfer_zigzag", + use_triton=True ) \ No newline at end of file diff --git a/minference/dist_ops/xattn_zigzag.py b/minference/dist_ops/xattn_zigzag.py index 47a89b6..5dfb641 100644 --- a/minference/dist_ops/xattn_zigzag.py +++ b/minference/dist_ops/xattn_zigzag.py @@ -14,9 +14,8 @@ from minference.ops.utils import use_triton from minference.ops.op_utils.xattn_utils import LN2, find_blocks_chunked from minference.ops.op_utils.vertical_slash_utils import convert_blockmask -from minference.ops.pit_sparse_flash_attention_v3 import block_attn_fwd, block_attn_bwd from minference.ops.xattention_fa import flat_group_gemm_fuse_reshape, softmax_fuse_block_sum -from minference.ops.pit_sparse_flash_attention_v3_triton import triton_block_attn_fwd, triton_block_attn_bwd +from minference.ops.pit_sparse_flash_attention_v3 import block_attn_fwd, block_attn_bwd, triton_block_attn_fwd, triton_block_attn_bwd def xattn_zigzag_estimate( diff --git a/minference/ops/op_utils/vertical_slash_utils.py b/minference/ops/op_utils/vertical_slash_utils.py index ba04cd5..eba1700 100644 --- a/minference/ops/op_utils/vertical_slash_utils.py +++ b/minference/ops/op_utils/vertical_slash_utils.py @@ -752,41 +752,6 @@ def build_index( ) return block_mask, bar_idx, bar_cnt, bar_pos, v_idx, v_cnt - -def _build_mask_local( - q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] - k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] - v_size: List[int], - s_size: List[int], - num_tokens: int, - granularity: int, - world_size: int = 1, - rank: int = 0, -): - with torch.no_grad(): - block_mask, bar_idx, bar_cnt = build_index_local(q, k, v_size, s_size, num_tokens, granularity, 
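# ----------------------------------------------------------------------------
# Illustrative aside (not from the patch): how the new use_triton test parameter is
# driven. The test above exports FORCE_TRITON=1 before spawning the workers so that
# use_triton() selects the Triton autograd classes, skips the one combination without
# a Triton path (minfer_zigzag), and resets the flag at the end. If flag leakage
# between parametrized cases ever becomes a concern, pytest's monkeypatch fixture is
# a possible alternative (shown only as a hypothetical variant):
#     def test_sparse_attention_kernels(monkeypatch, ..., use_triton, ...):
#         if use_triton:
#             monkeypatch.setenv("FORCE_TRITON", "1")   # restored automatically per test
# ----------------------------------------------------------------------------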
world_size, rank) - batch_size, num_tokens, num_heads, head_dim = q.shape - num_blocks = block_mask.shape[-1] - num_tokens_pad = num_blocks * granularity - # Block Mask - mask = block_mask.unsqueeze(3).unsqueeze(5).repeat((1, 1, 1, granularity, 1, granularity)) - mask = mask.reshape((batch_size, num_heads, num_tokens_pad, num_tokens_pad)) - # Bar Mask - for batch_idx in range(batch_size): - for head_idx in range(num_heads): - for row_idx in range(num_blocks): - row_u = row_idx * granularity - row_d = row_u + granularity - bar_l = bar_cnt[batch_idx, head_idx, row_idx, rank] - bar_r = bar_cnt[batch_idx, head_idx, row_idx, rank + 1] - for col_idx in bar_idx[batch_idx, head_idx, row_idx, bar_l:bar_r]: - mask[batch_idx, head_idx, row_u:row_d, col_idx] = True - # Causal Mask - arange = torch.arange(0, num_tokens_pad, dtype=torch.int32, device=q.device) - mask.masked_fill_(arange[None, None, :, None] < arange[None, None, None, :], False) - return mask[:, :, :num_tokens, :num_tokens] - - def convert_blockmask( blockmask: torch.Tensor, # [world_size, batch_size, num_heads, num_blocks, num_blocks] block_size_M: int, diff --git a/minference/ops/pit_sparse_flash_attention_v3.py b/minference/ops/pit_sparse_flash_attention_v3.py index d044f6d..ce59bd9 100644 --- a/minference/ops/pit_sparse_flash_attention_v3.py +++ b/minference/ops/pit_sparse_flash_attention_v3.py @@ -26,9 +26,10 @@ sys.setdlopenflags(original_flags) # NOTE: Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_blockmask.h: add head_idx to blockmask_ptr -from .op_utils.vertical_slash_utils import build_index_local - +from .op_utils.vertical_slash_utils import build_index_local, convert_blockmask +# ---------------------------------------------------------------------------- +# CUDA-based kernels (based on Block-Sparse-Attention) def block_attn_fwd( q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] @@ -428,246 +429,959 @@ def bar_attn_bwd( ) return dq, dk.to(dq.dtype), dv.to(dq.dtype) +# ---------------------------------------------------------------------------- +# Purely Triton-based kernels +@triton.jit +def _triton_block_attn_fwd_kernel( + Q, K, V, sm_scale, + block_cnt, # [BATCH, N_Q_HEADS, NUM_ROWS] + block_idx, # [BATCH, N_Q_HEADS, NUM_ROWS, NUM_COLS] + Out, # [BATCH, N_Q_HEADS, N_CTX, D_HEAD] + softmax_lse, # [BATCH, N_Q_HEADS, N_CTX] + stride_qz, stride_qh, stride_qm, stride_qd, + stride_kz, stride_kh, stride_kn, stride_kd, + stride_vz, stride_vh, stride_vn, stride_vd, + stride_oz, stride_oh, stride_om, stride_od, + stride_2cz, stride_2ch, stride_2cm, + stride_2iz, stride_2ih, stride_2im, stride_2in, + stride_sz, stride_sh, stride_sm, + num_qo_heads, num_kv_heads, num_tokens, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + CAUSAL: tl.constexpr, +): + start_m = tl.program_id(0) + qo_head_idx = tl.program_id(1) + batch_idx = tl.program_id(2) + kv_head_idx = qo_head_idx // (num_qo_heads // num_kv_heads) + if start_m * BLOCK_M >= num_tokens: + return -class MInferenceAttnFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - v_size, - s_size, - softmax_scale, - granularity, - return_softmax, - deterministic, - ): - batch_size, num_tokens, num_qo_heads, head_dim = q.shape - if softmax_scale is None: - softmax_scale = head_dim ** (-0.5) + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) - 
block_mask, bar_idx, bar_cnt = build_index_local(q, k, v_size, s_size, num_tokens, granularity) + qo_offset = batch_idx * stride_qz + qo_head_idx * stride_qh + kv_offset = batch_idx * stride_kz + kv_head_idx * stride_kh - # Block Mask - out, softmax_lse = block_attn_fwd( - q, k, v, softmax_scale, - block_mask, - granularity=granularity, - causal=True, - ) - # Bar Mask - out, softmax_lse = bar_attn_fwd( - q, k, v, out, softmax_lse, softmax_scale, - bar_idx, bar_cnt, - granularity=granularity, - step=0, - ) + q_ptrs = Q + qo_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd + k_ptrs = K + kv_offset + offs_d[:, None] * stride_kd + v_ptrs = V + kv_offset + offs_d[None, :] * stride_vd + o_ptrs = Out + qo_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od + lse_ptrs = softmax_lse + batch_idx * stride_sz + qo_head_idx * stride_sh + offs_m * stride_sm - ctx.save_for_backward(q, k, v, out, softmax_lse, block_mask, bar_idx, bar_cnt) - ctx.granularity = granularity - ctx.deterministic = deterministic - ctx.softmax_scale = softmax_scale - return (out, softmax_lse, None) if return_softmax else out + block_num = tl.load(block_cnt + batch_idx * stride_2cz + qo_head_idx * stride_2ch + start_m * stride_2cm) + if block_num <= 0: + return - @staticmethod - def backward(ctx, dout, *args): - q, k, v, out, softmax_lse, block_mask, bar_idx, bar_cnt = ctx.saved_tensors - # Block Mask - dq, dk, dv = block_attn_bwd( - dout, q, k, v, out, - softmax_lse, ctx.softmax_scale, - block_mask, - granularity=ctx.granularity, - deterministic=ctx.deterministic, - causal=True, - ) + block_idx_ptr = block_idx + batch_idx * stride_2iz + qo_head_idx * stride_2ih + start_m * stride_2im - # Bar Mask - dq, dk, dv = bar_attn_bwd( - dout, q, k, v, out, dq, dk, dv, - softmax_lse, ctx.softmax_scale, - bar_idx, bar_cnt, - granularity=ctx.granularity, - deterministic=ctx.deterministic, - step=0, - ) - return dq, dk, dv, None, None, None, None, None, None + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + # 1/ln2 = lne/ln2 = log2(e) => 2^(x / ln2) = 2^(x * log2(e)) = (2^(log2(e)))^x = e^x + qk_scale = sm_scale * 1.44269504 -def minference_flash_attn_qkvpacked_func( - qkv: torch.Tensor, # [batch_size, num_tokens, 3, num_heads, head_dim] - v_size: List[int], # [num_heads] - s_size: List[int], # [num_heads] - dropout_p: int = 0.0, - softmax_scale: float = None, - granularity: int = 128, - causal: bool = True, - window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window - alibi_slopes: Tuple[float, float] = None, - deterministic: bool = False, - return_attn_probs: bool = False, -): - assert dropout_p == 0 - assert causal - assert window_size == (-1, -1) - assert alibi_slopes is None - return MInferenceAttnFunc.apply( - qkv[:, :, 0], - qkv[:, :, 1], - qkv[:, :, 2], - v_size, - s_size, - softmax_scale, - granularity, - return_attn_probs, - deterministic, - ) + # load q: it will stay in SRAM throughout + q = tl.load(q_ptrs) + q = (q * qk_scale).to(Q.type.element_ty) + + if CAUSAL: + block_split = block_num - 2 + else: + block_split = block_num + # Block + for start_n in range(0, block_split): + block_off = tl.load(block_idx_ptr + start_n * stride_2in) * BLOCK_N -def minference_flash_attn_kvpacked_func( - q: 
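# ----------------------------------------------------------------------------
# Illustrative aside (not from the patch): the 1.44269504 constant in the Triton
# forward kernel above is log2(e). Pre-scaling Q by sm_scale * log2(e) lets the inner
# loop use exp2 instead of exp (see the kernel comment) while still computing
# softmax(sm_scale * QK^T), since 2^(x * log2(e)) = e^x; the LSE is mapped back to
# the natural log with ln(2) = 0.69314718 when it is written out.
import math
x = 1.7
assert abs(2.0 ** (x * 1.44269504) - math.exp(x)) < 1e-6   # quick numeric sanity check
# ----------------------------------------------------------------------------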
torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] - kv: torch.Tensor, # [batch_size, num_tokens, 2, num_kv_heads, head_dim] - v_size: List[int], # [num_qo_heads] - s_size: List[int], # [num_qo_heads] - dropout_p: int = 0.0, - softmax_scale: float = None, - granularity: int = 128, - causal: bool = True, - window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window - alibi_slopes: Tuple[float, float] = None, - deterministic: bool = False, - return_attn_probs: bool = False, -): - assert dropout_p == 0 - assert causal - assert window_size == (-1, -1) - assert alibi_slopes is None - return MInferenceAttnFunc.apply( - q, - kv[:, :, 0], - kv[:, :, 1], - v_size, - s_size, - softmax_scale, - granularity, - return_attn_probs, - deterministic, - ) + # -- load k, v -- + k = tl.load(k_ptrs + block_off * stride_kn + offs_n[None, :] * stride_kn) + v = tl.load(v_ptrs + block_off * stride_vn + offs_n[:, None] * stride_vn) + # -- compute qk -- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = qk + tl.dot(q, k) -def minference_flash_attn_func( - q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] - k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] - v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] - v_size: List[int], # [num_qo_heads] - s_size: List[int], # [num_qo_heads] - dropout_p: int = 0.0, - softmax_scale: float = None, - granularity: int = 128, - causal: bool = True, - window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window - alibi_slopes: Tuple[float, float] = None, - deterministic: bool = False, - return_attn_probs: bool = False, -): - assert dropout_p == 0 - assert causal - assert window_size == (-1, -1) - assert alibi_slopes is None - return MInferenceAttnFunc.apply( - q, - k, - v, - v_size, - s_size, - softmax_scale, - granularity, - return_attn_probs, - deterministic, - ) + # -- compute scaling constant -- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.math.exp2(m_i - m_i_new) + p = tl.math.exp2(qk - m_i_new[:, None]) + + # -- scale and update acc -- + acc_scale = l_i * 0 + alpha # workaround some compiler bug + acc = acc * acc_scale[:, None] + acc = acc + tl.dot(p.to(Q.type.element_ty), v) + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new -def _build_mask_local( - q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] - k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] - v_size: List[int], - s_size: List[int], - num_tokens: int, - granularity: int, - world_size: int = 1, - rank: int = 0, -): - with torch.no_grad(): - block_mask, bar_idx, bar_cnt = build_index_local(q, k, v_size, s_size, num_tokens, granularity, world_size, rank) - batch_size, num_tokens, num_heads, head_dim = q.shape - num_blocks = block_mask.shape[-1] - num_tokens_pad = num_blocks * granularity - # Block Mask - mask = block_mask.unsqueeze(3).unsqueeze(5).repeat((1, 1, 1, granularity, 1, granularity)) - mask = mask.reshape((batch_size, num_heads, num_tokens_pad, num_tokens_pad)) - # Bar Mask - for batch_idx in range(batch_size): - for head_idx in range(num_heads): - for row_idx in range(num_blocks): - row_u = row_idx * granularity - row_d = row_u + granularity - bar_l = bar_cnt[batch_idx, head_idx, row_idx, rank] - bar_r = bar_cnt[batch_idx, head_idx, row_idx, rank + 1] - for col_idx in bar_idx[batch_idx, head_idx, row_idx, bar_l:bar_r]: - mask[batch_idx, head_idx, row_u:row_d, col_idx] = True - # Causal Mask - arange = torch.arange(0, 
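# Reference sketch (illustrative only): the block loop above is the standard
# streaming-softmax update over KV blocks. Written here with natural exp for
# readability (the kernel works in base 2); sizes are arbitrary assumptions and
# this loop visits every block, whereas the kernel visits only block_idx entries.
import torch

torch.manual_seed(0)
M, N, D, BLOCK_N = 16, 256, 32, 64
q, k, v = torch.randn(M, D), torch.randn(N, D), torch.randn(N, D)
scale = D ** -0.5

m_i = torch.full((M,), float("-inf"))      # running row max
l_i = torch.zeros(M)                       # running softmax denominator
acc = torch.zeros(M, D)                    # running unnormalized output
for start in range(0, N, BLOCK_N):
    kb, vb = k[start:start + BLOCK_N], v[start:start + BLOCK_N]
    qk = (q @ kb.T) * scale
    m_new = torch.maximum(m_i, qk.max(dim=-1).values)
    alpha = torch.exp(m_i - m_new)         # rescales the previous partial sums
    p = torch.exp(qk - m_new[:, None])
    acc = acc * alpha[:, None] + p @ vb
    l_i = l_i * alpha + p.sum(-1)
    m_i = m_new
out = acc / l_i[:, None]

ref = torch.softmax((q @ k.T) * scale, dim=-1) @ v
assert torch.allclose(out, ref, atol=1e-4)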
num_tokens_pad, dtype=torch.int32, device=q.device) - mask.masked_fill_(arange[None, None, :, None] < arange[None, None, None, :], False) - return mask[:, :, :num_tokens, :num_tokens] - - -def _torch_sparse_attn_func( + # Block (Causal) + for start_n in range(max(block_split, 0), block_num): + block_off = tl.load(block_idx_ptr + start_n * stride_2in) * BLOCK_N + + # -- load k, v -- + k = tl.load(k_ptrs + block_off * stride_kn + offs_n[None, :] * stride_kn) + v = tl.load(v_ptrs + block_off * stride_vn + offs_n[:, None] * stride_vn) + + # -- compute qk -- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.where(offs_m[:, None] >= offs_n[None, :] + block_off, qk, float("-inf")) + qk = qk + tl.dot(q, k) + + # -- compute scaling constant -- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.math.exp2(m_i - m_i_new) + p = tl.math.exp2(qk - m_i_new[:, None]) + + # -- scale and update acc -- + acc_scale = l_i * 0 + alpha # workaround some compiler bug + acc = acc * acc_scale[:, None] + acc = acc + tl.dot(p.to(Q.type.element_ty), v) + + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + + # write back O and LSE + acc_1 = acc / l_i[:, None] + s_1 = m_i * 0.69314718 + tl.math.log(l_i) + acc_0 = tl.load(o_ptrs).to(tl.float32) + s_0 = tl.load(lse_ptrs) + + overflow_mask = (s_0 - s_1) < 88.0 + + theta = tl.math.exp(s_0 - s_1) + alpha_0 = 1 / (1 + 1 / theta) + alpha_1 = 1 / (1 + theta) + acc = alpha_0[:, None] * acc_0 + alpha_1[:, None] * acc_1 + s = s_1 - tl.math.log(alpha_1) + + tl.store(o_ptrs, acc.to(Out.type.element_ty)) + tl.store(lse_ptrs, s, mask=overflow_mask) + +def triton_block_attn_fwd( q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] - v_size: List[int], # [num_qo_heads] - s_size: List[int], # [num_qo_heads] - dropout_p: int = 0.0, - softmax_scale: float = None, - granularity: int = 128, + softmax_scale: float, + block_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, num_blocks] + block_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks] + granularity: int, + step: int = 0, causal: bool = True, - window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window - alibi_slopes: Tuple[float, float] = None, - deterministic: bool = False, - return_attn_probs: bool = False, ): - assert dropout_p == 0 - assert causal - assert window_size == (-1, -1) - assert alibi_slopes is None - assert not deterministic batch_size, num_tokens, num_qo_heads, head_dim = q.shape num_kv_heads = k.shape[2] - group_size = num_qo_heads // num_kv_heads - softmax_scale = head_dim ** -0.5 if softmax_scale is None else softmax_scale - mask = _build_mask_local(q, k, v_size, s_size, num_tokens, granularity) - k = k.repeat_interleave(group_size, dim=2) - v = v.repeat_interleave(group_size, dim=2) - p = torch.einsum('bmhd, bnhd -> bhmn', q * softmax_scale, k) - p = torch.where(mask, p, -torch.inf).to(torch.float32) - m = torch.max(p, dim=-1, keepdim=True).values.to(torch.float32) - p = torch.exp(p - m) - l = torch.sum(p, dim=-1, keepdim=True) - p = (p / l).to(q.dtype) - o = torch.einsum('bhmn, bnhd -> bmhd', p, v) - o = o.reshape((batch_size, num_tokens, num_qo_heads, head_dim)) - if return_attn_probs: - lse = m + l.log() - return o, lse.squeeze(-1), None - return o - - -def _torch_sparse_attn_qkvpacked_func( - qkv: torch.Tensor, # [batch_size, num_tokens, 3, num_heads, head_dim] - v_size: List[int], 
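# Reference sketch (illustrative only): the write-back above merges this pass's
# partial result (acc_1, s_1) with whatever is already stored in O / LSE
# (acc_0, s_0). The weights alpha_0 / alpha_1 are the softmax of the two LSE
# values, and the merged LSE equals logaddexp(s_0, s_1); arbitrary numbers are
# enough to check both facts.
import torch

torch.manual_seed(0)
acc_0, acc_1 = torch.randn(4, 8), torch.randn(4, 8)
s_0, s_1 = torch.randn(4), torch.randn(4)

theta = torch.exp(s_0 - s_1)
alpha_0 = 1 / (1 + 1 / theta)              # = exp(s_0) / (exp(s_0) + exp(s_1))
alpha_1 = 1 / (1 + theta)                  # = exp(s_1) / (exp(s_0) + exp(s_1))
merged = alpha_0[:, None] * acc_0 + alpha_1[:, None] * acc_1
s = s_1 - torch.log(alpha_1)

w0 = torch.exp(s_0) / (torch.exp(s_0) + torch.exp(s_1))
assert torch.allclose(merged, w0[:, None] * acc_0 + (1 - w0)[:, None] * acc_1, atol=1e-6)
assert torch.allclose(s, torch.logaddexp(s_0, s_1), atol=1e-6)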
# [num_heads] - s_size: List[int], # [num_heads] + num_blocks = block_idx.shape[2] + + o = torch.zeros_like(q) + lse = torch.zeros((batch_size, num_qo_heads, num_tokens), dtype=torch.float32, device=q.device) - torch.inf + + _triton_block_attn_fwd_kernel[(num_blocks, num_qo_heads, batch_size)]( + q, k, v, softmax_scale, + block_cnt, block_idx, + o, lse, + q.stride(0), q.stride(2), q.stride(1), q.stride(3), + k.stride(0), k.stride(2), k.stride(1), k.stride(3), + v.stride(0), v.stride(2), v.stride(1), v.stride(3), + o.stride(0), o.stride(2), o.stride(1), o.stride(3), + block_cnt.stride(0), block_cnt.stride(1), block_cnt.stride(2), + block_idx.stride(0), block_idx.stride(1), block_idx.stride(2), block_idx.stride(3), + lse.stride(0), lse.stride(1), lse.stride(2), + num_qo_heads, num_kv_heads, num_tokens, + BLOCK_M=granularity, BLOCK_N=64, BLOCK_DMODEL=head_dim, CAUSAL=causal, + num_warps=4, num_stages=2, + ) + return o, lse + +@triton.jit +def _triton_block_attn_bwd_kernel( + Q, K, V, O, + DQ, DK, DV, DO, + sm_scale, + block_cnt, # [BATCH, N_Q_HEADS, NUM_ROWS] + block_idx, # [BATCH, N_Q_HEADS, NUM_ROWS, NUM_COLS] + softmax_lse, # [BATCH, N_HEADS, N_CTX] + stride_qz, stride_qh, stride_qm, stride_qd, + stride_kz, stride_kh, stride_kn, stride_kd, + stride_vz, stride_vh, stride_vn, stride_vd, + stride_oz, stride_oh, stride_om, stride_od, + stride_dqz, stride_dqh, stride_dqm, stride_dqd, + stride_dkz, stride_dkh, stride_dkn, stride_dkd, + stride_dvz, stride_dvh, stride_dvn, stride_dvd, + stride_doz, stride_doh, stride_dom, stride_dod, + stride_2cz, stride_2ch, stride_2cm, + stride_2iz, stride_2ih, stride_2im, stride_2in, + stride_sz, stride_sh, stride_sm, + num_qo_heads, num_kv_heads, num_tokens, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + CAUSAL: tl.constexpr, +): + start_m = tl.program_id(0) + qo_head_idx = tl.program_id(1) + batch_idx = tl.program_id(2) + kv_head_idx = qo_head_idx // (num_qo_heads // num_kv_heads) + + if start_m * BLOCK_M >= num_tokens: + return + + qk_scale = sm_scale * 1.44269504 + + # offset pointers for batch/head + Q += batch_idx * stride_qz + qo_head_idx * stride_qh + K += batch_idx * stride_kz + kv_head_idx * stride_kh + V += batch_idx * stride_vz + kv_head_idx * stride_vh + O += batch_idx * stride_oz + qo_head_idx * stride_oh + DQ += batch_idx * stride_dqz + qo_head_idx * stride_dqh + DK += batch_idx * stride_dkz + kv_head_idx * stride_dkh + DV += batch_idx * stride_dvz + kv_head_idx * stride_dvh + DO += batch_idx * stride_doz + qo_head_idx * stride_doh + + # loop over rows + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + + # initialize pointers to value-like data + q_ptrs = Q + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd + k_ptrs = K + offs_d[None, :] * stride_kd + v_ptrs = V + offs_d[None, :] * stride_vd + o_ptrs = O + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od + dq_ptrs = DQ + offs_m[:, None] * stride_dqm + offs_d[None, :] * stride_dqd + dk_ptrs = DK + offs_d[None, :] * stride_dkd + dv_ptrs = DV + offs_d[None, :] * stride_dvd + do_ptrs = DO + offs_m[:, None] * stride_dom + offs_d[None, :] * stride_dod + l_ptrs = softmax_lse + batch_idx * stride_sz + qo_head_idx * stride_sh + offs_m * stride_sm + + block_num = tl.load(block_cnt + batch_idx * stride_2cz + qo_head_idx * stride_2ch + start_m * stride_2cm) + block_idx_ptr = block_idx + batch_idx * stride_2iz + qo_head_idx * stride_2ih + start_m * stride_2im + + o = 
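# Reference sketch (illustrative only): the launcher above passes strides in
# (z, h, m, d) order even though the tensors are laid out
# [batch, tokens, heads, dim], i.e. it forwards stride(0), stride(2),
# stride(1), stride(3). Addressing purely by strides keeps the kernel
# layout-agnostic; the same pointer arithmetic expressed in torch:
import torch

q = torch.randn(2, 10, 4, 8)               # [batch, tokens, heads, dim], contiguous
sz, sm, sh, sd = q.stride()                # strides in dimension order
b, t, h, d = 1, 7, 2, 5
flat = q.reshape(-1)
# element (b, t, h, d) located from strides alone, as the kernel does
assert flat[b * sz + h * sh + t * sm + d * sd] == q[b, t, h, d]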
tl.load(o_ptrs).to(tl.float32) + do = tl.load(do_ptrs).to(tl.float32) + d_i = tl.sum(o * do, axis=1) + + q = tl.load(q_ptrs) + do = do.to(DO.dtype.element_ty) + l_i = tl.load(l_ptrs) * 1.44269504 + + dq = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + if CAUSAL: + block_split = block_num - 2 + else: + block_split = block_num + + # Block + for start_n in range(0, block_split): + block_off = tl.load(block_idx_ptr + start_n * stride_2in) * BLOCK_N + + # -- load k, v -- + k = tl.load(k_ptrs + block_off * stride_kn + offs_n[:, None] * stride_kn) + v = tl.load(v_ptrs + block_off * stride_vn + offs_n[:, None] * stride_vn) + + # Computer qk + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = qk + tl.dot(q, tl.trans(k)) + qk = qk * qk_scale + p = tl.math.exp2(qk - l_i[:, None]) + + # compute dv + dv_vals = tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do).to(tl.float32) + tl.atomic_add(dv_ptrs + block_off * stride_dvn + offs_n[:, None] * stride_dvn, dv_vals, sem="relaxed") + + # compute dp = dot(v, do) + dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - d_i[:, None] + dp = dp + tl.dot(do, tl.trans(v)) + + # compute ds = p * (dp - delta[:, None]) + ds = p * dp * sm_scale + + # compute dk = dot(ds.T, q) + dk_vals = tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q).to(tl.float32) + tl.atomic_add(dk_ptrs + block_off * stride_dkn + offs_n[:, None] * stride_dkn, dk_vals, sem="relaxed") + + # compute dq + dq = dq + tl.dot(ds.to(Q.dtype.element_ty), k) + + # Block (Causal) + for start_n in range(max(block_split, 0), block_num): + block_off = tl.load(block_idx_ptr + start_n * stride_2in) * BLOCK_N + + # -- load k, v -- + k = tl.load(k_ptrs + block_off * stride_kn + offs_n[:, None] * stride_kn) + v = tl.load(v_ptrs + block_off * stride_vn + offs_n[:, None] * stride_vn) + + # Computer qk + qk = tl.where(offs_m[:, None] >= offs_n[None, :] + block_off, float(0.), float("-inf")) + qk = qk + tl.dot(q, tl.trans(k)) + qk = qk * qk_scale + p = tl.math.exp2(qk - l_i[:, None]) + + # compute dv + dv_vals = tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do).to(tl.float32) + tl.atomic_add(dv_ptrs + block_off * stride_dvn + offs_n[:, None] * stride_dvn, dv_vals, sem="relaxed") + + # compute dp = dot(v, do) + dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - d_i[:, None] + dp = dp + tl.dot(do, tl.trans(v)) + + # compute ds = p * (dp - delta[:, None]) + ds = p * dp * sm_scale + + # compute dk = dot(ds.T, q) + dk_vals = tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q).to(tl.float32) + tl.atomic_add(dk_ptrs + block_off * stride_dkn + offs_n[:, None] * stride_dkn, dk_vals, sem="relaxed") + + # compute dq + dq = dq + tl.dot(ds.to(Q.dtype.element_ty), k) + + dq_old = tl.load(dq_ptrs).to(tl.float32) + tl.store(dq_ptrs, (dq_old + dq).to(DQ.dtype.element_ty)) + + +def triton_block_attn_bwd( + grad: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + o: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + softmax_lse: torch.Tensor, # [batch_size, num_qo_heads, num_tokens] + softmax_scale: float, + block_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, num_blocks] + block_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks] + granularity: int, + deterministic: bool, + step: int = 0, + causal: bool = True, +): + assert not deterministic + batch_size, num_tokens, 
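# Reference sketch (illustrative only): the per-block gradients computed above,
# checked against autograd on a tiny dense, non-causal case. p is recomputed
# from the saved LSE, d_i = sum(o * do) per row, and dq/dk/dv follow the usual
# FlashAttention backward identities; sizes are arbitrary assumptions.
import torch

torch.manual_seed(0)
M, N, D = 8, 16, 32
scale = D ** -0.5
q = torch.randn(M, D, requires_grad=True)
k = torch.randn(N, D, requires_grad=True)
v = torch.randn(N, D, requires_grad=True)
do = torch.randn(M, D)

scores = (q @ k.T) * scale
o = torch.softmax(scores, dim=-1) @ v
o.backward(do)

with torch.no_grad():
    lse = torch.logsumexp(scores, dim=-1)
    p = torch.exp(scores - lse[:, None])   # recomputed attention probabilities
    d_i = (o * do).sum(-1)                 # the delta term
    dv = p.T @ do
    dp = do @ v.T - d_i[:, None]
    ds = p * dp * scale
    dk = ds.T @ q
    dq = ds @ k

assert torch.allclose(dq, q.grad, atol=1e-4)
assert torch.allclose(dk, k.grad, atol=1e-4)
assert torch.allclose(dv, v.grad, atol=1e-4)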
num_qo_heads, head_dim = q.shape + num_kv_heads = k.shape[2] + num_blocks = block_idx.shape[2] + + dq = torch.zeros_like(q) + dk = torch.zeros_like(k, dtype=torch.float32) + dv = torch.zeros_like(v, dtype=torch.float32) + + _triton_block_attn_bwd_kernel[(num_blocks, num_qo_heads, batch_size)]( + q, k, v, o, dq, dk, dv, grad, softmax_scale, + block_cnt, block_idx, softmax_lse, + q.stride(0), q.stride(2), q.stride(1), q.stride(3), + k.stride(0), k.stride(2), k.stride(1), k.stride(3), + v.stride(0), v.stride(2), v.stride(1), v.stride(3), + o.stride(0), o.stride(2), o.stride(1), o.stride(3), + dq.stride(0), dq.stride(2), dq.stride(1), dq.stride(3), + dk.stride(0), dk.stride(2), dk.stride(1), dk.stride(3), + dv.stride(0), dv.stride(2), dv.stride(1), dv.stride(3), + grad.stride(0), grad.stride(2), grad.stride(1), grad.stride(3), + block_cnt.stride(0), block_cnt.stride(1), block_cnt.stride(2), + block_idx.stride(0), block_idx.stride(1), block_idx.stride(2), block_idx.stride(3), + softmax_lse.stride(0), softmax_lse.stride(1), softmax_lse.stride(2), + num_qo_heads, num_kv_heads, num_tokens, + BLOCK_M=granularity, BLOCK_N=64, BLOCK_DMODEL=head_dim, CAUSAL=causal, + num_warps=4, num_stages=2, + ) + return dq, dk.to(dq.dtype), dv.to(dq.dtype) + + +@triton.jit +def _triton_block_bar_attn_fwd_kernel( + Q, K, V, sm_scale, + bar_cnt, # [BATCH, N_Q_HEADS, NUM_ROWS, WORLD_SIZE + 1] + bar_idx, # [BATCH, N_Q_HEADS, NUM_ROWS, NNZ_V] + block_cnt, # [BATCH, N_Q_HEADS, NUM_ROWS] + block_idx, # [BATCH, N_Q_HEADS, NUM_ROWS, NUM_COLS] + Out, # [BATCH, N_Q_HEADS, N_CTX, D_HEAD] + softmax_lse, # [BATCH, N_Q_HEADS, N_CTX] + stride_qz, stride_qh, stride_qm, stride_qd, + stride_kz, stride_kh, stride_kn, stride_kd, + stride_vz, stride_vh, stride_vn, stride_vd, + stride_oz, stride_oh, stride_om, stride_od, + stride_1cz, stride_1ch, stride_1cm, stride_1cr, + stride_1iz, stride_1ih, stride_1im, stride_1in, + stride_2cz, stride_2ch, stride_2cm, + stride_2iz, stride_2ih, stride_2im, stride_2in, + stride_sz, stride_sh, stride_sm, + step, num_qo_heads, num_kv_heads, num_tokens, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + CAUSAL: tl.constexpr, +): + start_m = tl.program_id(0) + qo_head_idx = tl.program_id(1) + batch_idx = tl.program_id(2) + kv_head_idx = qo_head_idx // (num_qo_heads // num_kv_heads) + + if start_m * BLOCK_M >= num_tokens: + return + + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + + qo_offset = batch_idx * stride_qz + qo_head_idx * stride_qh + kv_offset = batch_idx * stride_kz + kv_head_idx * stride_kh + + q_ptrs = Q + qo_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd + k_ptrs = K + kv_offset + offs_d[:, None] * stride_kd + v_ptrs = V + kv_offset + offs_d[None, :] * stride_vd + o_ptrs = Out + qo_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od + + lse_ptrs = softmax_lse + batch_idx * stride_sz + qo_head_idx * stride_sh + offs_m * stride_sm + + bar_l = tl.load(bar_cnt + batch_idx * stride_1cz + qo_head_idx * stride_1ch + start_m * stride_1cm + step * stride_1cr) + bar_r = tl.load(bar_cnt + batch_idx * stride_1cz + qo_head_idx * stride_1ch + start_m * stride_1cm + (step + 1) * stride_1cr) + bar_idx_ptr = bar_idx + batch_idx * stride_1iz + qo_head_idx * stride_1ih + start_m * stride_1im + + block_num = tl.load(block_cnt + batch_idx * stride_2cz + qo_head_idx * stride_2ch + start_m * stride_2cm) + block_idx_ptr = block_idx + batch_idx * 
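# Reference sketch (illustrative only): the launcher above allocates dk/dv in
# float32 because many query-row programs atomically add partial gradients into
# the same KV rows, and it casts back to the parameter dtype only once at the
# end. A CPU analogue of that accumulation pattern with index_add_; the bf16
# loop is shown only to illustrate the rounding error a low-precision
# accumulator would pick up.
import torch

torch.manual_seed(0)
num_kv, D, num_contribs = 64, 32, 512
cols = torch.randint(0, num_kv, (num_contribs,))
partial = torch.randn(num_contribs, D, dtype=torch.bfloat16)

dv_fp32 = torch.zeros(num_kv, D, dtype=torch.float32)
dv_fp32.index_add_(0, cols, partial.float())      # accumulate in fp32
dv = dv_fp32.to(torch.bfloat16)                   # single cast at the end

dv_bf16 = torch.zeros(num_kv, D, dtype=torch.bfloat16)
for c, p in zip(cols, partial):                   # naive bf16 accumulation
    dv_bf16[c] += p
print((dv.float() - dv_bf16.float()).abs().max())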
stride_2iz + qo_head_idx * stride_2ih + start_m * stride_2im + + if (bar_l >= bar_r) and (block_num <= 0): + return + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + # 1/ln2 = lne/ln2 = log2(e) => 2^(x / ln2) = 2^(x * log2(e)) = (2^(log2(e)))^x = e^x + qk_scale = sm_scale * 1.44269504 + + # load q: it will stay in SRAM throughout + q = tl.load(q_ptrs) + q = (q * qk_scale).to(Q.type.element_ty) + + if CAUSAL: + block_split = block_num - 2 + else: + block_split = block_num + + # Block + for start_n in range(0, block_split): + block_off = tl.load(block_idx_ptr + start_n * stride_2in) * BLOCK_N + + # -- load k, v -- + k = tl.load(k_ptrs + block_off * stride_kn + offs_n[None, :] * stride_kn) + v = tl.load(v_ptrs + block_off * stride_vn + offs_n[:, None] * stride_vn) + + # -- compute qk -- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = qk + tl.dot(q, k) + + # -- compute scaling constant -- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.math.exp2(m_i - m_i_new) + p = tl.math.exp2(qk - m_i_new[:, None]) + + # -- scale and update acc -- + acc_scale = l_i * 0 + alpha # workaround some compiler bug + acc = acc * acc_scale[:, None] + acc = acc + tl.dot(p.to(Q.type.element_ty), v) + + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + + # Block (Causal) + for start_n in range(max(block_split, 0), block_num): + block_off = tl.load(block_idx_ptr + start_n * stride_2in) * BLOCK_N + + # -- load k, v -- + k = tl.load(k_ptrs + block_off * stride_kn + offs_n[None, :] * stride_kn) + v = tl.load(v_ptrs + block_off * stride_vn + offs_n[:, None] * stride_vn) + + # -- compute qk -- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.where(offs_m[:, None] >= offs_n[None, :] + block_off, qk, float("-inf")) + qk = qk + tl.dot(q, k) + + # -- compute scaling constant -- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.math.exp2(m_i - m_i_new) + p = tl.math.exp2(qk - m_i_new[:, None]) + + # -- scale and update acc -- + acc_scale = l_i * 0 + alpha # workaround some compiler bug + acc = acc * acc_scale[:, None] + acc = acc + tl.dot(p.to(Q.type.element_ty), v) + + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + + # Bar + for start_n in range(bar_l, bar_r, BLOCK_N): + n_mask = start_n + offs_n < bar_r + cols = tl.load(bar_idx_ptr + (start_n + offs_n) * stride_1in, mask=n_mask, other=0) + + # -- load k, v -- + k = tl.load(k_ptrs + cols[None, :] * stride_kn) + v = tl.load(v_ptrs + cols[:, None] * stride_vn) + + # -- compute qk -- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.where(n_mask[None, :], qk, float("-inf")) + qk = qk + tl.dot(q, k) + + # -- compute scaling constant -- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.math.exp2(m_i - m_i_new) + p = tl.math.exp2(qk - m_i_new[:, None]) + + # -- scale and update acc -- + acc_scale = l_i * 0 + alpha # workaround some compiler bug + acc = acc * acc_scale[:, None] + acc = acc + tl.dot(p.to(Q.type.element_ty), v) + + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + + # write back O and LSE + acc_1 = acc / l_i[:, None] + s_1 = m_i * 0.69314718 + tl.math.log(l_i) + acc_0 = tl.load(o_ptrs).to(tl.float32) + s_0 = 
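# Reference sketch (illustrative only): the "Bar" loop above extends the running
# softmax over individually indexed KV positions (a padded index buffer plus a
# validity mask), rather than contiguous 64-wide blocks. A dense torch analogue
# of one such masked, gathered score tile; indices and counts are arbitrary
# assumptions.
import torch

torch.manual_seed(0)
M, N, D = 16, 256, 32
q, k = torch.randn(M, D), torch.randn(N, D)
scale = D ** -0.5

cols = torch.randint(0, N, (64,))          # padded bar_idx entries for this row block
n_mask = torch.arange(64) < 40             # only bar_cnt-many entries are valid

qk = (q @ k[cols].T) * scale               # gather only the selected keys
qk = qk.masked_fill(~n_mask[None, :], float("-inf"))   # padding contributes nothing
# ...from here the (m_i, l_i, acc) update proceeds exactly as in the block loops.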
tl.load(lse_ptrs) + + overflow_mask = (s_0 - s_1) < 88.0 + + theta = tl.math.exp(s_0 - s_1) + alpha_0 = 1 / (1 + 1 / theta) + alpha_1 = 1 / (1 + theta) + acc = alpha_0[:, None] * acc_0 + alpha_1[:, None] * acc_1 + s = s_1 - tl.math.log(alpha_1) + + tl.store(o_ptrs, acc.to(Out.type.element_ty)) + tl.store(lse_ptrs, s, mask=overflow_mask) + + +def block_bar_attn_fwd( + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + o: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + lse: torch.Tensor, # [batch_size, num_qo_heads, num_tokens] + softmax_scale: float, + bar_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, max_v_size] + bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] + block_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, num_blocks] + block_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks] + granularity: int, + step: int = 0, + causal: bool = True, +): + batch_size, num_tokens, num_qo_heads, head_dim = q.shape + num_kv_heads = k.shape[2] + num_blocks = bar_idx.shape[2] + if o is None: + o = torch.zeros_like(q) + lse = torch.zeros((batch_size, num_qo_heads, num_tokens), dtype=torch.float32, device=q.device) - torch.inf + _triton_block_bar_attn_fwd_kernel[(num_blocks, num_qo_heads, batch_size)]( + q, k, v, softmax_scale, bar_cnt, bar_idx, block_cnt, block_idx, o, lse, + q.stride(0), q.stride(2), q.stride(1), q.stride(3), + k.stride(0), k.stride(2), k.stride(1), k.stride(3), + v.stride(0), v.stride(2), v.stride(1), v.stride(3), + o.stride(0), o.stride(2), o.stride(1), o.stride(3), + bar_cnt.stride(0), bar_cnt.stride(1), bar_cnt.stride(2), bar_cnt.stride(3), + bar_idx.stride(0), bar_idx.stride(1), bar_idx.stride(2), bar_idx.stride(3), + block_cnt.stride(0), block_cnt.stride(1), block_cnt.stride(2), + block_idx.stride(0), block_idx.stride(1), block_idx.stride(2), block_idx.stride(3), + lse.stride(0), lse.stride(1), lse.stride(2), + step, num_qo_heads, num_kv_heads, num_tokens, + BLOCK_M=granularity, BLOCK_N=64, BLOCK_DMODEL=head_dim, CAUSAL=causal, + num_warps=4, num_stages=2, + ) + return o, lse + + +@triton.jit +def _triton_block_bar_attn_bwd_kernel( + Q, K, V, O, + DQ, DK, DV, DO, + sm_scale, + bar_cnt, # [BATCH, N_Q_HEADS, NUM_ROWS, WORLD_SIZE + 1] + bar_idx, # [BATCH, N_Q_HEADS, NUM_ROWS, NNZ_V] + block_cnt, # [BATCH, N_Q_HEADS, NUM_ROWS] + block_idx, # [BATCH, N_Q_HEADS, NUM_ROWS, NUM_COLS] + softmax_lse, # [BATCH, N_HEADS, N_CTX] + stride_qz, stride_qh, stride_qm, stride_qd, + stride_kz, stride_kh, stride_kn, stride_kd, + stride_vz, stride_vh, stride_vn, stride_vd, + stride_oz, stride_oh, stride_om, stride_od, + stride_dqz, stride_dqh, stride_dqm, stride_dqd, + stride_dkz, stride_dkh, stride_dkn, stride_dkd, + stride_dvz, stride_dvh, stride_dvn, stride_dvd, + stride_doz, stride_doh, stride_dom, stride_dod, + stride_1cz, stride_1ch, stride_1cm, stride_1cr, + stride_1iz, stride_1ih, stride_1im, stride_1in, + stride_2cz, stride_2ch, stride_2cm, + stride_2iz, stride_2ih, stride_2im, stride_2in, + stride_sz, stride_sh, stride_sm, + step, num_qo_heads, num_kv_heads, num_tokens, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + CAUSAL: tl.constexpr, +): + start_m = tl.program_id(0) + qo_head_idx = tl.program_id(1) + batch_idx = tl.program_id(2) + kv_head_idx = qo_head_idx // (num_qo_heads // num_kv_heads) + + if start_m * 
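# Reference sketch (illustrative only): when no previous pass exists, the wrapper
# above seeds O with zeros and LSE with -inf. With s_0 = -inf the merge in the
# kernel epilogue degenerates to "keep this pass unchanged", so the first call
# needs no special casing.
import torch

acc_0 = torch.zeros(4, 8)
s_0 = torch.full((4,), float("-inf"))
acc_1 = torch.randn(4, 8)                  # this pass's normalized partial output
s_1 = torch.randn(4)                       # this pass's log-sum-exp

theta = torch.exp(s_0 - s_1)               # exactly 0
alpha_0 = 1 / (1 + 1 / theta)              # exactly 0
alpha_1 = 1 / (1 + theta)                  # exactly 1
merged = alpha_0[:, None] * acc_0 + alpha_1[:, None] * acc_1
s = s_1 - torch.log(alpha_1)

assert torch.equal(merged, acc_1) and torch.equal(s, s_1)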
BLOCK_M >= num_tokens: + return + + qk_scale = sm_scale * 1.44269504 + + # offset pointers for batch/head + Q += batch_idx * stride_qz + qo_head_idx * stride_qh + K += batch_idx * stride_kz + kv_head_idx * stride_kh + V += batch_idx * stride_vz + kv_head_idx * stride_vh + O += batch_idx * stride_oz + qo_head_idx * stride_oh + DQ += batch_idx * stride_dqz + qo_head_idx * stride_dqh + DK += batch_idx * stride_dkz + kv_head_idx * stride_dkh + DV += batch_idx * stride_dvz + kv_head_idx * stride_dvh + DO += batch_idx * stride_doz + qo_head_idx * stride_doh + + # loop over rows + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + + # initialize pointers to value-like data + q_ptrs = Q + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd + k_ptrs = K + offs_d[None, :] * stride_kd + v_ptrs = V + offs_d[None, :] * stride_vd + o_ptrs = O + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od + dq_ptrs = DQ + offs_m[:, None] * stride_dqm + offs_d[None, :] * stride_dqd + dk_ptrs = DK + offs_d[None, :] * stride_dkd + dv_ptrs = DV + offs_d[None, :] * stride_dvd + do_ptrs = DO + offs_m[:, None] * stride_dom + offs_d[None, :] * stride_dod + + l_ptrs = softmax_lse + batch_idx * stride_sz + qo_head_idx * stride_sh + offs_m * stride_sm + + bar_l = tl.load(bar_cnt + batch_idx * stride_1cz + qo_head_idx * stride_1ch + start_m * stride_1cm + step * stride_1cr) + bar_r = tl.load(bar_cnt + batch_idx * stride_1cz + qo_head_idx * stride_1ch + start_m * stride_1cm + (step + 1) * stride_1cr) + bar_idx_ptr = bar_idx + batch_idx * stride_1iz + qo_head_idx * stride_1ih + start_m * stride_1im + + block_num = tl.load(block_cnt + batch_idx * stride_2cz + qo_head_idx * stride_2ch + start_m * stride_2cm) + block_idx_ptr = block_idx + batch_idx * stride_2iz + qo_head_idx * stride_2ih + start_m * stride_2im + + if (bar_l >= bar_r) and (block_num <= 0): + return + + o = tl.load(o_ptrs).to(tl.float32) + do = tl.load(do_ptrs).to(tl.float32) + d_i = tl.sum(o * do, axis=1) + + q = tl.load(q_ptrs) + do = do.to(DO.dtype.element_ty) + l_i = tl.load(l_ptrs) * 1.44269504 + + dq = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + if CAUSAL: + block_split = block_num - 2 + else: + block_split = block_num + + # Block + for start_n in range(0, block_split): + block_off = tl.load(block_idx_ptr + start_n * stride_2in) * BLOCK_N + + # -- load k, v -- + k = tl.load(k_ptrs + block_off * stride_kn + offs_n[:, None] * stride_kn) + v = tl.load(v_ptrs + block_off * stride_vn + offs_n[:, None] * stride_vn) + + # Computer qk + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = qk + tl.dot(q, tl.trans(k)) + qk = qk * qk_scale + p = tl.math.exp2(qk - l_i[:, None]) + + # compute dv + dv_vals = tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do).to(tl.float32) + tl.atomic_add(dv_ptrs + block_off * stride_dvn + offs_n[:, None] * stride_dvn, dv_vals, sem="relaxed") + + # compute dp = dot(v, do) + dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - d_i[:, None] + dp = dp + tl.dot(do, tl.trans(v)) + + # compute ds = p * (dp - delta[:, None]) + ds = p * dp * sm_scale + + # compute dk = dot(ds.T, q) + dk_vals = tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q).to(tl.float32) + tl.atomic_add(dk_ptrs + block_off * stride_dkn + offs_n[:, None] * stride_dkn, dk_vals, sem="relaxed") + + # compute dq + dq = dq + tl.dot(ds.to(Q.dtype.element_ty), k) + + # Block (Causal) + for start_n in range(max(block_split, 0), block_num): + block_off = tl.load(block_idx_ptr + start_n * 
stride_2in) * BLOCK_N + + # -- load k, v -- + k = tl.load(k_ptrs + block_off * stride_kn + offs_n[:, None] * stride_kn) + v = tl.load(v_ptrs + block_off * stride_vn + offs_n[:, None] * stride_vn) + + # Computer qk + qk = tl.where(offs_m[:, None] >= offs_n[None, :] + block_off, float(0.), float("-inf")) + qk = qk + tl.dot(q, tl.trans(k)) + qk = qk * qk_scale + p = tl.math.exp2(qk - l_i[:, None]) + + # compute dv + dv_vals = tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do).to(tl.float32) + tl.atomic_add(dv_ptrs + block_off * stride_dvn + offs_n[:, None] * stride_dvn, dv_vals, sem="relaxed") + + # compute dp = dot(v, do) + dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - d_i[:, None] + dp = dp + tl.dot(do, tl.trans(v)) + + # compute ds = p * (dp - delta[:, None]) + ds = p * dp * sm_scale + + # compute dk = dot(ds.T, q) + dk_vals = tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q).to(tl.float32) + tl.atomic_add(dk_ptrs + block_off * stride_dkn + offs_n[:, None] * stride_dkn, dk_vals, sem="relaxed") + + # compute dq + dq = dq + tl.dot(ds.to(Q.dtype.element_ty), k) + + # Bar + for start_n in range(bar_l, bar_r, BLOCK_N): + n_mask = start_n + offs_n < bar_r + cols = tl.load(bar_idx_ptr + (start_n + offs_n) * stride_1in, mask=n_mask, other=0) + + # -- load k, v -- + k = tl.load(k_ptrs + cols[:, None] * stride_kn) + v = tl.load(v_ptrs + cols[:, None] * stride_vn) + + # Computer qk + qk = tl.where(n_mask[None, :], float(0.), float("-inf")) + qk = qk + tl.dot(q, tl.trans(k)) + qk = qk * qk_scale + p = tl.math.exp2(qk - l_i[:, None]) + + # compute dv + dv_vals = tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do).to(tl.float32) + tl.atomic_add(dv_ptrs + cols[:, None] * stride_dvn, dv_vals, mask=n_mask[:, None], sem="relaxed") + + # compute dp = dot(v, do) + dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - d_i[:, None] + dp = dp + tl.dot(do, tl.trans(v)) + + # compute ds = p * (dp - delta[:, None]) + ds = p * dp * sm_scale + + # compute dk = dot(ds.T, q) + dk_vals = tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q).to(tl.float32) + tl.atomic_add(dk_ptrs + cols[:, None] * stride_dkn, dk_vals, mask=n_mask[:, None], sem="relaxed") + + # compute dq + dq = dq + tl.dot(ds.to(Q.dtype.element_ty), k) + + dq_old = tl.load(dq_ptrs).to(tl.float32) + tl.store(dq_ptrs, (dq_old + dq).to(DQ.dtype.element_ty)) + + +def block_bar_attn_bwd( + grad: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + o: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + dq: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + dk: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + dv: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + softmax_lse: torch.Tensor, # [batch_size, num_qo_heads, num_tokens] + softmax_scale: float, + bar_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, max_v_size] + bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] + block_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, num_blocks] + block_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks] + granularity: int, + deterministic: bool, + step: int = 0, + causal: bool = True, +): + assert not deterministic + batch_size, num_tokens, num_qo_heads, head_dim = q.shape + num_kv_heads = k.shape[2] + num_blocks = bar_idx.shape[2] + dq 
= torch.zeros_like(q) if dq is None else dq + dk = torch.zeros_like(k, dtype=torch.float32) if dk is None else dk.to(torch.float32) + dv = torch.zeros_like(v, dtype=torch.float32) if dv is None else dv.to(torch.float32) + _triton_block_bar_attn_bwd_kernel[(num_blocks, num_qo_heads, batch_size)]( + q, k, v, o, dq, dk, dv, grad, softmax_scale, + bar_cnt, bar_idx, block_cnt, block_idx, softmax_lse, + q.stride(0), q.stride(2), q.stride(1), q.stride(3), + k.stride(0), k.stride(2), k.stride(1), k.stride(3), + v.stride(0), v.stride(2), v.stride(1), v.stride(3), + o.stride(0), o.stride(2), o.stride(1), o.stride(3), + dq.stride(0), dq.stride(2), dq.stride(1), dq.stride(3), + dk.stride(0), dk.stride(2), dk.stride(1), dk.stride(3), + dv.stride(0), dv.stride(2), dv.stride(1), dv.stride(3), + grad.stride(0), grad.stride(2), grad.stride(1), grad.stride(3), + bar_cnt.stride(0), bar_cnt.stride(1), bar_cnt.stride(2), bar_cnt.stride(3), + bar_idx.stride(0), bar_idx.stride(1), bar_idx.stride(2), bar_idx.stride(3), + block_cnt.stride(0), block_cnt.stride(1), block_cnt.stride(2), + block_idx.stride(0), block_idx.stride(1), block_idx.stride(2), block_idx.stride(3), + softmax_lse.stride(0), softmax_lse.stride(1), softmax_lse.stride(2), + step, num_qo_heads, num_kv_heads, num_tokens, + BLOCK_M=granularity, BLOCK_N=64, BLOCK_DMODEL=head_dim, CAUSAL=causal, + num_warps=4, num_stages=2, + ) + return dq, dk.to(dq.dtype), dv.to(dq.dtype) + + +# --------------------------------------------------------------------------------- +# Attention Classes +class MInferenceAttnFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + v_size, + s_size, + softmax_scale, + granularity, + return_softmax, + deterministic, + ): + batch_size, num_tokens, num_qo_heads, head_dim = q.shape + if softmax_scale is None: + softmax_scale = head_dim ** (-0.5) + + block_mask, bar_idx, bar_cnt = build_index_local(q, k, v_size, s_size, num_tokens, granularity) + + # Block Mask + out, softmax_lse = block_attn_fwd( + q, k, v, softmax_scale, + block_mask, + granularity=granularity, + causal=True, + ) + # Bar Mask + out, softmax_lse = bar_attn_fwd( + q, k, v, out, softmax_lse, softmax_scale, + bar_idx, bar_cnt, + granularity=granularity, + step=0, + ) + + ctx.save_for_backward(q, k, v, out, softmax_lse, block_mask, bar_idx, bar_cnt) + ctx.granularity = granularity + ctx.deterministic = deterministic + ctx.softmax_scale = softmax_scale + return (out, softmax_lse, None) if return_softmax else out + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse, block_mask, bar_idx, bar_cnt = ctx.saved_tensors + # Block Mask + dq, dk, dv = block_attn_bwd( + dout, q, k, v, out, + softmax_lse, ctx.softmax_scale, + block_mask, + granularity=ctx.granularity, + deterministic=ctx.deterministic, + causal=True, + ) + + # Bar Mask + dq, dk, dv = bar_attn_bwd( + dout, q, k, v, out, dq, dk, dv, + softmax_lse, ctx.softmax_scale, + bar_idx, bar_cnt, + granularity=ctx.granularity, + deterministic=ctx.deterministic, + step=0, + ) + return dq, dk, dv, None, None, None, None, None, None + + + +class MInferenceAttnTritonFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + v_size, + s_size, + softmax_scale, + granularity, + return_softmax, + deterministic, + ): + batch_size, num_tokens, num_qo_heads, head_dim = q.shape + if softmax_scale is None: + softmax_scale = head_dim ** (-0.5) + + block_mask, bar_idx, bar_cnt = build_index_local(q, k, v_size, s_size, num_tokens, granularity) + block_idx, 
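# Reference sketch (illustrative only, not this patch's API): in the backward
# above, the six trailing Nones line up with the non-tensor forward arguments
# (v_size, s_size, softmax_scale, granularity, return_softmax, deterministic),
# which receive no gradient. A minimal Function with the same interface shape,
# using a hypothetical scaled matmul purely for orientation:
import torch

class _ScaledMatmul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b, scale, return_extra):
        ctx.save_for_backward(a, b)
        ctx.scale = scale
        out = (a @ b) * scale
        return (out, None) if return_extra else out

    @staticmethod
    def backward(ctx, dout, *args):
        a, b = ctx.saved_tensors
        da = dout @ b.T * ctx.scale
        db = a.T @ dout * ctx.scale
        return da, db, None, None          # one slot per forward argument

a = torch.randn(4, 5, requires_grad=True)
b = torch.randn(5, 3, requires_grad=True)
_ScaledMatmul.apply(a, b, 0.5, False).sum().backward()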
block_cnt = convert_blockmask(block_mask, block_size_M=granularity, block_size_N=64) + + out, softmax_lse = block_bar_attn_fwd( + q, k, v, None, None, softmax_scale, + bar_idx, bar_cnt, block_idx, block_cnt, + granularity=granularity, + step=0, + ) + + ctx.save_for_backward(q, k, v, out, softmax_lse, block_idx, block_cnt, bar_idx, bar_cnt) + ctx.granularity = granularity + ctx.deterministic = deterministic + ctx.softmax_scale = softmax_scale + return (out, softmax_lse, None) if return_softmax else out + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse, block_idx, block_cnt, bar_idx, bar_cnt = ctx.saved_tensors + + # Bar Mask + dq, dk, dv = block_bar_attn_bwd( + dout, q, k, v, out, None, None, None, + softmax_lse, ctx.softmax_scale, + bar_idx, bar_cnt, block_idx, block_cnt, + granularity=ctx.granularity, + deterministic=ctx.deterministic, + step=0, + ) + + return dq, dk, dv, None, None, None, None, None, None + +# --------------------------------------------------------------------------------- +# Wrapped Attention Functions +# -------------------------------------------- +# CUDA-Based +def minference_flash_attn_func( + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v_size: List[int], # [num_qo_heads] + s_size: List[int], # [num_qo_heads] dropout_p: int = 0.0, softmax_scale: float = None, granularity: int = 128, @@ -677,26 +1391,53 @@ def _torch_sparse_attn_qkvpacked_func( deterministic: bool = False, return_attn_probs: bool = False, ): - return _torch_sparse_attn_func( - qkv[:, :, 0], - qkv[:, :, 1], - qkv[:, :, 2], + assert dropout_p == 0 + assert causal + assert window_size == (-1, -1) + assert alibi_slopes is None + return MInferenceAttnFunc.apply( + q, + k, + v, v_size, s_size, - dropout_p, softmax_scale, granularity, - causal, - window_size, - alibi_slopes, - deterministic, return_attn_probs, + deterministic, + ) + + +def minference_flash_attn_qkvpacked_func( + qkv: torch.Tensor, # [batch_size, num_tokens, 3, num_heads, head_dim] + *args, **kwargs +): + return minference_flash_attn_func( + qkv[:, :, 0], # q + qkv[:, :, 1], # k + qkv[:, :, 2], # v + *args, **kwargs ) -def _torch_sparse_attn_kvpacked_func( +def minference_flash_attn_kvpacked_func( q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] - kv: torch.Tensor, # [batch_size, num_tokens, 2, num_kv_heads, head_dim] + kv: torch.Tensor, # [batch_size, num_tokens, 2, num_kv_heads, head_dim]\ + *args, **kwargs +): + return minference_flash_attn_func( + q, + kv[:, :, 0], # k + kv[:, :, 1], # v + *args, **kwargs + ) + +# -------------------------------------------- +# Triton-Based +def minference_flash_attn_triton_func( + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] v_size: List[int], # [num_qo_heads] s_size: List[int], # [num_qo_heads] dropout_p: int = 0.0, @@ -708,18 +1449,42 @@ def _torch_sparse_attn_kvpacked_func( deterministic: bool = False, return_attn_probs: bool = False, ): - return _torch_sparse_attn_func( + assert dropout_p == 0 + assert causal + assert window_size == (-1, -1) + assert alibi_slopes is None + return MInferenceAttnTritonFunc.apply( q, - kv[:, :, 0], - kv[:, :, 1], + k, + v, v_size, s_size, - dropout_p, softmax_scale, granularity, - causal, - 
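# Usage sketch (assumptions flagged inline): how the wrappers above are meant to
# be called, following the shapes documented in their signatures. The import
# path, the budget values in v_size / s_size (per-head vertical / slash counts),
# and the dtype are illustrative guesses, not settings taken from this patch.
import torch
# from minference.ops.minference_attn_triton import minference_flash_attn_triton_func  # assumed module path

device = "cuda" if torch.cuda.is_available() else "cpu"   # the kernels require a GPU
batch_size, num_tokens, num_heads, head_dim = 1, 4096, 8, 64
q = torch.randn(batch_size, num_tokens, num_heads, head_dim,
                device=device, dtype=torch.bfloat16, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)
v_size = [512] * num_heads       # illustrative per-head vertical budget
s_size = [2048] * num_heads      # illustrative per-head slash budget

# out = minference_flash_attn_triton_func(q, k, v, v_size, s_size,
#                                         granularity=128, causal=True)
# out.sum().backward()           # dq/dk/dv come from the sparse Triton backward kernels above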
window_size, - alibi_slopes, - deterministic, return_attn_probs, + deterministic, + ) + +def minference_flash_attn_triton_qkvpacked_func( + qkv: torch.Tensor, # [batch_size, num_tokens, 3, num_heads, head_dim] + *args, **kwargs +): + return minference_flash_attn_triton_func( + qkv[:, :, 0], # q + qkv[:, :, 1], # k + qkv[:, :, 2], # v + *args, **kwargs + ) + + +def minference_flash_attn_triton_kvpacked_func( + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + kv: torch.Tensor, # [batch_size, num_tokens, 2, num_kv_heads, head_dim] + *args, **kwargs +): + return minference_flash_attn_triton_func( + q, + kv[:, :, 0], # k + kv[:, :, 1], # v + *args, **kwargs ) diff --git a/minference/ops/pit_sparse_flash_attention_v3_triton.py b/minference/ops/pit_sparse_flash_attention_v3_triton.py deleted file mode 100644 index 068085b..0000000 --- a/minference/ops/pit_sparse_flash_attention_v3_triton.py +++ /dev/null @@ -1,1076 +0,0 @@ -import torch -import triton -import triton.language as tl -from typing import List, Tuple - -from .op_utils.vertical_slash_utils import ( - build_index_local, _build_mask_local, convert_blockmask, -) - - -@triton.jit -def _triton_block_attn_fwd_kernel( - Q, K, V, sm_scale, - block_cnt, # [BATCH, N_Q_HEADS, NUM_ROWS] - block_idx, # [BATCH, N_Q_HEADS, NUM_ROWS, NUM_COLS] - Out, # [BATCH, N_Q_HEADS, N_CTX, D_HEAD] - softmax_lse, # [BATCH, N_Q_HEADS, N_CTX] - stride_qz, stride_qh, stride_qm, stride_qd, - stride_kz, stride_kh, stride_kn, stride_kd, - stride_vz, stride_vh, stride_vn, stride_vd, - stride_oz, stride_oh, stride_om, stride_od, - stride_2cz, stride_2ch, stride_2cm, - stride_2iz, stride_2ih, stride_2im, stride_2in, - stride_sz, stride_sh, stride_sm, - num_qo_heads, num_kv_heads, num_tokens, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - CAUSAL: tl.constexpr, -): - start_m = tl.program_id(0) - qo_head_idx = tl.program_id(1) - batch_idx = tl.program_id(2) - kv_head_idx = qo_head_idx // (num_qo_heads // num_kv_heads) - - if start_m * BLOCK_M >= num_tokens: - return - - # initialize offsets - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - - qo_offset = batch_idx * stride_qz + qo_head_idx * stride_qh - kv_offset = batch_idx * stride_kz + kv_head_idx * stride_kh - - q_ptrs = Q + qo_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd - k_ptrs = K + kv_offset + offs_d[:, None] * stride_kd - v_ptrs = V + kv_offset + offs_d[None, :] * stride_vd - o_ptrs = Out + qo_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od - lse_ptrs = softmax_lse + batch_idx * stride_sz + qo_head_idx * stride_sh + offs_m * stride_sm - - block_num = tl.load(block_cnt + batch_idx * stride_2cz + qo_head_idx * stride_2ch + start_m * stride_2cm) - if block_num <= 0: - return - - block_idx_ptr = block_idx + batch_idx * stride_2iz + qo_head_idx * stride_2ih + start_m * stride_2im - - # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - # scale sm_scale by log_2(e) and use - # 2^x instead of exp in the loop because CSE and LICM - # don't work as expected with `exp` in the loop - # 1/ln2 = lne/ln2 = log2(e) => 2^(x / ln2) = 2^(x * log2(e)) = (2^(log2(e)))^x = e^x - qk_scale = sm_scale * 1.44269504 - - # load q: it will stay in SRAM throughout - q = tl.load(q_ptrs) - q = (q * qk_scale).to(Q.type.element_ty) 
- - if CAUSAL: - block_split = block_num - 2 - else: - block_split = block_num - - # Block - for start_n in range(0, block_split): - block_off = tl.load(block_idx_ptr + start_n * stride_2in) * BLOCK_N - - # -- load k, v -- - k = tl.load(k_ptrs + block_off * stride_kn + offs_n[None, :] * stride_kn) - v = tl.load(v_ptrs + block_off * stride_vn + offs_n[:, None] * stride_vn) - - # -- compute qk -- - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk = qk + tl.dot(q, k) - - # -- compute scaling constant -- - m_i_new = tl.maximum(m_i, tl.max(qk, 1)) - alpha = tl.math.exp2(m_i - m_i_new) - p = tl.math.exp2(qk - m_i_new[:, None]) - - # -- scale and update acc -- - acc_scale = l_i * 0 + alpha # workaround some compiler bug - acc = acc * acc_scale[:, None] - acc = acc + tl.dot(p.to(Q.type.element_ty), v) - - # -- update m_i and l_i -- - l_i = l_i * alpha + tl.sum(p, 1) - m_i = m_i_new - - # Block (Causal) - for start_n in range(max(block_split, 0), block_num): - block_off = tl.load(block_idx_ptr + start_n * stride_2in) * BLOCK_N - - # -- load k, v -- - k = tl.load(k_ptrs + block_off * stride_kn + offs_n[None, :] * stride_kn) - v = tl.load(v_ptrs + block_off * stride_vn + offs_n[:, None] * stride_vn) - - # -- compute qk -- - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk = tl.where(offs_m[:, None] >= offs_n[None, :] + block_off, qk, float("-inf")) - qk = qk + tl.dot(q, k) - - # -- compute scaling constant -- - m_i_new = tl.maximum(m_i, tl.max(qk, 1)) - alpha = tl.math.exp2(m_i - m_i_new) - p = tl.math.exp2(qk - m_i_new[:, None]) - - # -- scale and update acc -- - acc_scale = l_i * 0 + alpha # workaround some compiler bug - acc = acc * acc_scale[:, None] - acc = acc + tl.dot(p.to(Q.type.element_ty), v) - - # -- update m_i and l_i -- - l_i = l_i * alpha + tl.sum(p, 1) - m_i = m_i_new - - # write back O and LSE - acc_1 = acc / l_i[:, None] - s_1 = m_i * 0.69314718 + tl.math.log(l_i) - acc_0 = tl.load(o_ptrs).to(tl.float32) - s_0 = tl.load(lse_ptrs) - - overflow_mask = (s_0 - s_1) < 88.0 - - theta = tl.math.exp(s_0 - s_1) - alpha_0 = 1 / (1 + 1 / theta) - alpha_1 = 1 / (1 + theta) - acc = alpha_0[:, None] * acc_0 + alpha_1[:, None] * acc_1 - s = s_1 - tl.math.log(alpha_1) - - tl.store(o_ptrs, acc.to(Out.type.element_ty)) - tl.store(lse_ptrs, s, mask=overflow_mask) - -def triton_block_attn_fwd( - q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] - k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] - v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] - softmax_scale: float, - block_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, num_blocks] - block_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks] - granularity: int, - step: int = 0, - causal: bool = True, -): - batch_size, num_tokens, num_qo_heads, head_dim = q.shape - num_kv_heads = k.shape[2] - num_blocks = block_idx.shape[2] - - o = torch.zeros_like(q) - lse = torch.zeros((batch_size, num_qo_heads, num_tokens), dtype=torch.float32, device=q.device) - torch.inf - - _triton_block_attn_fwd_kernel[(num_blocks, num_qo_heads, batch_size)]( - q, k, v, softmax_scale, - block_cnt, block_idx, - o, lse, - q.stride(0), q.stride(2), q.stride(1), q.stride(3), - k.stride(0), k.stride(2), k.stride(1), k.stride(3), - v.stride(0), v.stride(2), v.stride(1), v.stride(3), - o.stride(0), o.stride(2), o.stride(1), o.stride(3), - block_cnt.stride(0), block_cnt.stride(1), block_cnt.stride(2), - block_idx.stride(0), block_idx.stride(1), block_idx.stride(2), block_idx.stride(3), - 
lse.stride(0), lse.stride(1), lse.stride(2), - num_qo_heads, num_kv_heads, num_tokens, - BLOCK_M=granularity, BLOCK_N=64, BLOCK_DMODEL=head_dim, CAUSAL=causal, - num_warps=4, num_stages=2, - ) - return o, lse - -@triton.jit -def _triton_block_attn_bwd_kernel( - Q, K, V, O, - DQ, DK, DV, DO, - sm_scale, - block_cnt, # [BATCH, N_Q_HEADS, NUM_ROWS] - block_idx, # [BATCH, N_Q_HEADS, NUM_ROWS, NUM_COLS] - softmax_lse, # [BATCH, N_HEADS, N_CTX] - stride_qz, stride_qh, stride_qm, stride_qd, - stride_kz, stride_kh, stride_kn, stride_kd, - stride_vz, stride_vh, stride_vn, stride_vd, - stride_oz, stride_oh, stride_om, stride_od, - stride_dqz, stride_dqh, stride_dqm, stride_dqd, - stride_dkz, stride_dkh, stride_dkn, stride_dkd, - stride_dvz, stride_dvh, stride_dvn, stride_dvd, - stride_doz, stride_doh, stride_dom, stride_dod, - stride_2cz, stride_2ch, stride_2cm, - stride_2iz, stride_2ih, stride_2im, stride_2in, - stride_sz, stride_sh, stride_sm, - num_qo_heads, num_kv_heads, num_tokens, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - CAUSAL: tl.constexpr, -): - start_m = tl.program_id(0) - qo_head_idx = tl.program_id(1) - batch_idx = tl.program_id(2) - kv_head_idx = qo_head_idx // (num_qo_heads // num_kv_heads) - - if start_m * BLOCK_M >= num_tokens: - return - - qk_scale = sm_scale * 1.44269504 - - # offset pointers for batch/head - Q += batch_idx * stride_qz + qo_head_idx * stride_qh - K += batch_idx * stride_kz + kv_head_idx * stride_kh - V += batch_idx * stride_vz + kv_head_idx * stride_vh - O += batch_idx * stride_oz + qo_head_idx * stride_oh - DQ += batch_idx * stride_dqz + qo_head_idx * stride_dqh - DK += batch_idx * stride_dkz + kv_head_idx * stride_dkh - DV += batch_idx * stride_dvz + kv_head_idx * stride_dvh - DO += batch_idx * stride_doz + qo_head_idx * stride_doh - - # loop over rows - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_n = tl.arange(0, BLOCK_N) - - # initialize pointers to value-like data - q_ptrs = Q + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd - k_ptrs = K + offs_d[None, :] * stride_kd - v_ptrs = V + offs_d[None, :] * stride_vd - o_ptrs = O + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od - dq_ptrs = DQ + offs_m[:, None] * stride_dqm + offs_d[None, :] * stride_dqd - dk_ptrs = DK + offs_d[None, :] * stride_dkd - dv_ptrs = DV + offs_d[None, :] * stride_dvd - do_ptrs = DO + offs_m[:, None] * stride_dom + offs_d[None, :] * stride_dod - l_ptrs = softmax_lse + batch_idx * stride_sz + qo_head_idx * stride_sh + offs_m * stride_sm - - block_num = tl.load(block_cnt + batch_idx * stride_2cz + qo_head_idx * stride_2ch + start_m * stride_2cm) - block_idx_ptr = block_idx + batch_idx * stride_2iz + qo_head_idx * stride_2ih + start_m * stride_2im - - o = tl.load(o_ptrs).to(tl.float32) - do = tl.load(do_ptrs).to(tl.float32) - d_i = tl.sum(o * do, axis=1) - - q = tl.load(q_ptrs) - do = do.to(DO.dtype.element_ty) - l_i = tl.load(l_ptrs) * 1.44269504 - - dq = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - if CAUSAL: - block_split = block_num - 2 - else: - block_split = block_num - - # Block - for start_n in range(0, block_split): - block_off = tl.load(block_idx_ptr + start_n * stride_2in) * BLOCK_N - - # -- load k, v -- - k = tl.load(k_ptrs + block_off * stride_kn + offs_n[:, None] * stride_kn) - v = tl.load(v_ptrs + block_off * stride_vn + offs_n[:, None] * stride_vn) - - # Computer qk - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk = qk + tl.dot(q, tl.trans(k)) - 
qk = qk * qk_scale - p = tl.math.exp2(qk - l_i[:, None]) - - # compute dv - dv_vals = tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do).to(tl.float32) - tl.atomic_add(dv_ptrs + block_off * stride_dvn + offs_n[:, None] * stride_dvn, dv_vals, sem="relaxed") - - # compute dp = dot(v, do) - dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - d_i[:, None] - dp = dp + tl.dot(do, tl.trans(v)) - - # compute ds = p * (dp - delta[:, None]) - ds = p * dp * sm_scale - - # compute dk = dot(ds.T, q) - dk_vals = tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q).to(tl.float32) - tl.atomic_add(dk_ptrs + block_off * stride_dkn + offs_n[:, None] * stride_dkn, dk_vals, sem="relaxed") - - # compute dq - dq = dq + tl.dot(ds.to(Q.dtype.element_ty), k) - - # Block (Causal) - for start_n in range(max(block_split, 0), block_num): - block_off = tl.load(block_idx_ptr + start_n * stride_2in) * BLOCK_N - - # -- load k, v -- - k = tl.load(k_ptrs + block_off * stride_kn + offs_n[:, None] * stride_kn) - v = tl.load(v_ptrs + block_off * stride_vn + offs_n[:, None] * stride_vn) - - # Computer qk - qk = tl.where(offs_m[:, None] >= offs_n[None, :] + block_off, float(0.), float("-inf")) - qk = qk + tl.dot(q, tl.trans(k)) - qk = qk * qk_scale - p = tl.math.exp2(qk - l_i[:, None]) - - # compute dv - dv_vals = tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do).to(tl.float32) - tl.atomic_add(dv_ptrs + block_off * stride_dvn + offs_n[:, None] * stride_dvn, dv_vals, sem="relaxed") - - # compute dp = dot(v, do) - dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - d_i[:, None] - dp = dp + tl.dot(do, tl.trans(v)) - - # compute ds = p * (dp - delta[:, None]) - ds = p * dp * sm_scale - - # compute dk = dot(ds.T, q) - dk_vals = tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q).to(tl.float32) - tl.atomic_add(dk_ptrs + block_off * stride_dkn + offs_n[:, None] * stride_dkn, dk_vals, sem="relaxed") - - # compute dq - dq = dq + tl.dot(ds.to(Q.dtype.element_ty), k) - - dq_old = tl.load(dq_ptrs).to(tl.float32) - tl.store(dq_ptrs, (dq_old + dq).to(DQ.dtype.element_ty)) - - -def triton_block_attn_bwd( - grad: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] - q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] - k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] - v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] - o: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] - softmax_lse: torch.Tensor, # [batch_size, num_qo_heads, num_tokens] - softmax_scale: float, - block_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, num_blocks] - block_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks] - granularity: int, - deterministic: bool, - step: int = 0, - causal: bool = True, -): - assert not deterministic - batch_size, num_tokens, num_qo_heads, head_dim = q.shape - num_kv_heads = k.shape[2] - num_blocks = block_idx.shape[2] - - dq = torch.zeros_like(q) - dk = torch.zeros_like(k, dtype=torch.float32) - dv = torch.zeros_like(v, dtype=torch.float32) - - _triton_block_attn_bwd_kernel[(num_blocks, num_qo_heads, batch_size)]( - q, k, v, o, dq, dk, dv, grad, softmax_scale, - block_cnt, block_idx, softmax_lse, - q.stride(0), q.stride(2), q.stride(1), q.stride(3), - k.stride(0), k.stride(2), k.stride(1), k.stride(3), - v.stride(0), v.stride(2), v.stride(1), v.stride(3), - o.stride(0), o.stride(2), o.stride(1), o.stride(3), - dq.stride(0), dq.stride(2), dq.stride(1), dq.stride(3), - dk.stride(0), dk.stride(2), dk.stride(1), dk.stride(3), - dv.stride(0), dv.stride(2), 
dv.stride(1), dv.stride(3), - grad.stride(0), grad.stride(2), grad.stride(1), grad.stride(3), - block_cnt.stride(0), block_cnt.stride(1), block_cnt.stride(2), - block_idx.stride(0), block_idx.stride(1), block_idx.stride(2), block_idx.stride(3), - softmax_lse.stride(0), softmax_lse.stride(1), softmax_lse.stride(2), - num_qo_heads, num_kv_heads, num_tokens, - BLOCK_M=granularity, BLOCK_N=64, BLOCK_DMODEL=head_dim, CAUSAL=causal, - num_warps=4, num_stages=2, - ) - return dq, dk.to(dq.dtype), dv.to(dq.dtype) - - -@triton.jit -def _triton_block_bar_attn_fwd_kernel( - Q, K, V, sm_scale, - bar_cnt, # [BATCH, N_Q_HEADS, NUM_ROWS, WORLD_SIZE + 1] - bar_idx, # [BATCH, N_Q_HEADS, NUM_ROWS, NNZ_V] - block_cnt, # [BATCH, N_Q_HEADS, NUM_ROWS] - block_idx, # [BATCH, N_Q_HEADS, NUM_ROWS, NUM_COLS] - Out, # [BATCH, N_Q_HEADS, N_CTX, D_HEAD] - softmax_lse, # [BATCH, N_Q_HEADS, N_CTX] - stride_qz, stride_qh, stride_qm, stride_qd, - stride_kz, stride_kh, stride_kn, stride_kd, - stride_vz, stride_vh, stride_vn, stride_vd, - stride_oz, stride_oh, stride_om, stride_od, - stride_1cz, stride_1ch, stride_1cm, stride_1cr, - stride_1iz, stride_1ih, stride_1im, stride_1in, - stride_2cz, stride_2ch, stride_2cm, - stride_2iz, stride_2ih, stride_2im, stride_2in, - stride_sz, stride_sh, stride_sm, - step, num_qo_heads, num_kv_heads, num_tokens, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - CAUSAL: tl.constexpr, -): - start_m = tl.program_id(0) - qo_head_idx = tl.program_id(1) - batch_idx = tl.program_id(2) - kv_head_idx = qo_head_idx // (num_qo_heads // num_kv_heads) - - if start_m * BLOCK_M >= num_tokens: - return - - # initialize offsets - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - - qo_offset = batch_idx * stride_qz + qo_head_idx * stride_qh - kv_offset = batch_idx * stride_kz + kv_head_idx * stride_kh - - q_ptrs = Q + qo_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd - k_ptrs = K + kv_offset + offs_d[:, None] * stride_kd - v_ptrs = V + kv_offset + offs_d[None, :] * stride_vd - o_ptrs = Out + qo_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od - - lse_ptrs = softmax_lse + batch_idx * stride_sz + qo_head_idx * stride_sh + offs_m * stride_sm - - bar_l = tl.load(bar_cnt + batch_idx * stride_1cz + qo_head_idx * stride_1ch + start_m * stride_1cm + step * stride_1cr) - bar_r = tl.load(bar_cnt + batch_idx * stride_1cz + qo_head_idx * stride_1ch + start_m * stride_1cm + (step + 1) * stride_1cr) - bar_idx_ptr = bar_idx + batch_idx * stride_1iz + qo_head_idx * stride_1ih + start_m * stride_1im - - block_num = tl.load(block_cnt + batch_idx * stride_2cz + qo_head_idx * stride_2ch + start_m * stride_2cm) - block_idx_ptr = block_idx + batch_idx * stride_2iz + qo_head_idx * stride_2ih + start_m * stride_2im - - if (bar_l >= bar_r) and (block_num <= 0): - return - - # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - # scale sm_scale by log_2(e) and use - # 2^x instead of exp in the loop because CSE and LICM - # don't work as expected with `exp` in the loop - # 1/ln2 = lne/ln2 = log2(e) => 2^(x / ln2) = 2^(x * log2(e)) = (2^(log2(e)))^x = e^x - qk_scale = sm_scale * 1.44269504 - - # load q: it will stay in SRAM throughout - q = tl.load(q_ptrs) - q = (q * qk_scale).to(Q.type.element_ty) - - if CAUSAL: - block_split = block_num - 2 - 
else: - block_split = block_num - - # Block - for start_n in range(0, block_split): - block_off = tl.load(block_idx_ptr + start_n * stride_2in) * BLOCK_N - - # -- load k, v -- - k = tl.load(k_ptrs + block_off * stride_kn + offs_n[None, :] * stride_kn) - v = tl.load(v_ptrs + block_off * stride_vn + offs_n[:, None] * stride_vn) - - # -- compute qk -- - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk = qk + tl.dot(q, k) - - # -- compute scaling constant -- - m_i_new = tl.maximum(m_i, tl.max(qk, 1)) - alpha = tl.math.exp2(m_i - m_i_new) - p = tl.math.exp2(qk - m_i_new[:, None]) - - # -- scale and update acc -- - acc_scale = l_i * 0 + alpha # workaround some compiler bug - acc = acc * acc_scale[:, None] - acc = acc + tl.dot(p.to(Q.type.element_ty), v) - - # -- update m_i and l_i -- - l_i = l_i * alpha + tl.sum(p, 1) - m_i = m_i_new - - # Block (Causal) - for start_n in range(max(block_split, 0), block_num): - block_off = tl.load(block_idx_ptr + start_n * stride_2in) * BLOCK_N - - # -- load k, v -- - k = tl.load(k_ptrs + block_off * stride_kn + offs_n[None, :] * stride_kn) - v = tl.load(v_ptrs + block_off * stride_vn + offs_n[:, None] * stride_vn) - - # -- compute qk -- - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk = tl.where(offs_m[:, None] >= offs_n[None, :] + block_off, qk, float("-inf")) - qk = qk + tl.dot(q, k) - - # -- compute scaling constant -- - m_i_new = tl.maximum(m_i, tl.max(qk, 1)) - alpha = tl.math.exp2(m_i - m_i_new) - p = tl.math.exp2(qk - m_i_new[:, None]) - - # -- scale and update acc -- - acc_scale = l_i * 0 + alpha # workaround some compiler bug - acc = acc * acc_scale[:, None] - acc = acc + tl.dot(p.to(Q.type.element_ty), v) - - # -- update m_i and l_i -- - l_i = l_i * alpha + tl.sum(p, 1) - m_i = m_i_new - - # Bar - for start_n in range(bar_l, bar_r, BLOCK_N): - n_mask = start_n + offs_n < bar_r - cols = tl.load(bar_idx_ptr + (start_n + offs_n) * stride_1in, mask=n_mask, other=0) - - # -- load k, v -- - k = tl.load(k_ptrs + cols[None, :] * stride_kn) - v = tl.load(v_ptrs + cols[:, None] * stride_vn) - - # -- compute qk -- - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk = tl.where(n_mask[None, :], qk, float("-inf")) - qk = qk + tl.dot(q, k) - - # -- compute scaling constant -- - m_i_new = tl.maximum(m_i, tl.max(qk, 1)) - alpha = tl.math.exp2(m_i - m_i_new) - p = tl.math.exp2(qk - m_i_new[:, None]) - - # -- scale and update acc -- - acc_scale = l_i * 0 + alpha # workaround some compiler bug - acc = acc * acc_scale[:, None] - acc = acc + tl.dot(p.to(Q.type.element_ty), v) - - # -- update m_i and l_i -- - l_i = l_i * alpha + tl.sum(p, 1) - m_i = m_i_new - - # write back O and LSE - acc_1 = acc / l_i[:, None] - s_1 = m_i * 0.69314718 + tl.math.log(l_i) - acc_0 = tl.load(o_ptrs).to(tl.float32) - s_0 = tl.load(lse_ptrs) - - overflow_mask = (s_0 - s_1) < 88.0 - - theta = tl.math.exp(s_0 - s_1) - alpha_0 = 1 / (1 + 1 / theta) - alpha_1 = 1 / (1 + theta) - acc = alpha_0[:, None] * acc_0 + alpha_1[:, None] * acc_1 - s = s_1 - tl.math.log(alpha_1) - - tl.store(o_ptrs, acc.to(Out.type.element_ty)) - tl.store(lse_ptrs, s, mask=overflow_mask) - - -def block_bar_attn_fwd( - q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] - k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] - v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] - o: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] - lse: torch.Tensor, # [batch_size, num_qo_heads, num_tokens] - softmax_scale: float, - bar_idx: torch.Tensor, # 
-    bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1]
-    block_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, num_blocks]
-    block_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks]
-    granularity: int,
-    step: int = 0,
-    causal: bool = True,
-):
-    batch_size, num_tokens, num_qo_heads, head_dim = q.shape
-    num_kv_heads = k.shape[2]
-    num_blocks = bar_idx.shape[2]
-    if o is None:
-        o = torch.zeros_like(q)
-        lse = torch.zeros((batch_size, num_qo_heads, num_tokens), dtype=torch.float32, device=q.device) - torch.inf
-    _triton_block_bar_attn_fwd_kernel[(num_blocks, num_qo_heads, batch_size)](
-        q, k, v, softmax_scale, bar_cnt, bar_idx, block_cnt, block_idx, o, lse,
-        q.stride(0), q.stride(2), q.stride(1), q.stride(3),
-        k.stride(0), k.stride(2), k.stride(1), k.stride(3),
-        v.stride(0), v.stride(2), v.stride(1), v.stride(3),
-        o.stride(0), o.stride(2), o.stride(1), o.stride(3),
-        bar_cnt.stride(0), bar_cnt.stride(1), bar_cnt.stride(2), bar_cnt.stride(3),
-        bar_idx.stride(0), bar_idx.stride(1), bar_idx.stride(2), bar_idx.stride(3),
-        block_cnt.stride(0), block_cnt.stride(1), block_cnt.stride(2),
-        block_idx.stride(0), block_idx.stride(1), block_idx.stride(2), block_idx.stride(3),
-        lse.stride(0), lse.stride(1), lse.stride(2),
-        step, num_qo_heads, num_kv_heads, num_tokens,
-        BLOCK_M=granularity, BLOCK_N=64, BLOCK_DMODEL=head_dim, CAUSAL=causal,
-        num_warps=4, num_stages=2,
-    )
-    return o, lse
-
-
-@triton.jit
-def _triton_block_bar_attn_bwd_kernel(
-    Q, K, V, O,
-    DQ, DK, DV, DO,
-    sm_scale,
-    bar_cnt, # [BATCH, N_Q_HEADS, NUM_ROWS, WORLD_SIZE + 1]
-    bar_idx, # [BATCH, N_Q_HEADS, NUM_ROWS, NNZ_V]
-    block_cnt, # [BATCH, N_Q_HEADS, NUM_ROWS]
-    block_idx, # [BATCH, N_Q_HEADS, NUM_ROWS, NUM_COLS]
-    softmax_lse, # [BATCH, N_HEADS, N_CTX]
-    stride_qz, stride_qh, stride_qm, stride_qd,
-    stride_kz, stride_kh, stride_kn, stride_kd,
-    stride_vz, stride_vh, stride_vn, stride_vd,
-    stride_oz, stride_oh, stride_om, stride_od,
-    stride_dqz, stride_dqh, stride_dqm, stride_dqd,
-    stride_dkz, stride_dkh, stride_dkn, stride_dkd,
-    stride_dvz, stride_dvh, stride_dvn, stride_dvd,
-    stride_doz, stride_doh, stride_dom, stride_dod,
-    stride_1cz, stride_1ch, stride_1cm, stride_1cr,
-    stride_1iz, stride_1ih, stride_1im, stride_1in,
-    stride_2cz, stride_2ch, stride_2cm,
-    stride_2iz, stride_2ih, stride_2im, stride_2in,
-    stride_sz, stride_sh, stride_sm,
-    step, num_qo_heads, num_kv_heads, num_tokens,
-    BLOCK_M: tl.constexpr,
-    BLOCK_N: tl.constexpr,
-    BLOCK_DMODEL: tl.constexpr,
-    CAUSAL: tl.constexpr,
-):
-    start_m = tl.program_id(0)
-    qo_head_idx = tl.program_id(1)
-    batch_idx = tl.program_id(2)
-    kv_head_idx = qo_head_idx // (num_qo_heads // num_kv_heads)
-
-    if start_m * BLOCK_M >= num_tokens:
-        return
-
-    qk_scale = sm_scale * 1.44269504
-
-    # offset pointers for batch/head
-    Q += batch_idx * stride_qz + qo_head_idx * stride_qh
-    K += batch_idx * stride_kz + kv_head_idx * stride_kh
-    V += batch_idx * stride_vz + kv_head_idx * stride_vh
-    O += batch_idx * stride_oz + qo_head_idx * stride_oh
-    DQ += batch_idx * stride_dqz + qo_head_idx * stride_dqh
-    DK += batch_idx * stride_dkz + kv_head_idx * stride_dkh
-    DV += batch_idx * stride_dvz + kv_head_idx * stride_dvh
-    DO += batch_idx * stride_doz + qo_head_idx * stride_doh
-
-    # loop over rows
-    offs_d = tl.arange(0, BLOCK_DMODEL)
-    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
-    offs_n = tl.arange(0, BLOCK_N)
-
-    # initialize pointers to value-like data
-    q_ptrs = Q + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd
-    k_ptrs = K + offs_d[None, :] * stride_kd
-    v_ptrs = V + offs_d[None, :] * stride_vd
-    o_ptrs = O + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od
-    dq_ptrs = DQ + offs_m[:, None] * stride_dqm + offs_d[None, :] * stride_dqd
-    dk_ptrs = DK + offs_d[None, :] * stride_dkd
-    dv_ptrs = DV + offs_d[None, :] * stride_dvd
-    do_ptrs = DO + offs_m[:, None] * stride_dom + offs_d[None, :] * stride_dod
-
-    l_ptrs = softmax_lse + batch_idx * stride_sz + qo_head_idx * stride_sh + offs_m * stride_sm
-
-    bar_l = tl.load(bar_cnt + batch_idx * stride_1cz + qo_head_idx * stride_1ch + start_m * stride_1cm + step * stride_1cr)
-    bar_r = tl.load(bar_cnt + batch_idx * stride_1cz + qo_head_idx * stride_1ch + start_m * stride_1cm + (step + 1) * stride_1cr)
-    bar_idx_ptr = bar_idx + batch_idx * stride_1iz + qo_head_idx * stride_1ih + start_m * stride_1im
-
-    block_num = tl.load(block_cnt + batch_idx * stride_2cz + qo_head_idx * stride_2ch + start_m * stride_2cm)
-    block_idx_ptr = block_idx + batch_idx * stride_2iz + qo_head_idx * stride_2ih + start_m * stride_2im
-
-    if (bar_l >= bar_r) and (block_num <= 0):
-        return
-
-    o = tl.load(o_ptrs).to(tl.float32)
-    do = tl.load(do_ptrs).to(tl.float32)
-    d_i = tl.sum(o * do, axis=1)
-
-    q = tl.load(q_ptrs)
-    do = do.to(DO.dtype.element_ty)
-    l_i = tl.load(l_ptrs) * 1.44269504
-
-    dq = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
-
-    if CAUSAL:
-        block_split = block_num - 2
-    else:
-        block_split = block_num
-
-    # Block
-    for start_n in range(0, block_split):
-        block_off = tl.load(block_idx_ptr + start_n * stride_2in) * BLOCK_N
-
-        # -- load k, v --
-        k = tl.load(k_ptrs + block_off * stride_kn + offs_n[:, None] * stride_kn)
-        v = tl.load(v_ptrs + block_off * stride_vn + offs_n[:, None] * stride_vn)
-
-        # Compute qk
-        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
-        qk = qk + tl.dot(q, tl.trans(k))
-        qk = qk * qk_scale
-        p = tl.math.exp2(qk - l_i[:, None])
-
-        # compute dv
-        dv_vals = tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do).to(tl.float32)
-        tl.atomic_add(dv_ptrs + block_off * stride_dvn + offs_n[:, None] * stride_dvn, dv_vals, sem="relaxed")
-
-        # compute dp = dot(v, do)
-        dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - d_i[:, None]
-        dp = dp + tl.dot(do, tl.trans(v))
-
-        # compute ds = p * (dp - delta[:, None])
-        ds = p * dp * sm_scale
-
-        # compute dk = dot(ds.T, q)
-        dk_vals = tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q).to(tl.float32)
-        tl.atomic_add(dk_ptrs + block_off * stride_dkn + offs_n[:, None] * stride_dkn, dk_vals, sem="relaxed")
-
-        # compute dq
-        dq = dq + tl.dot(ds.to(Q.dtype.element_ty), k)
-
-    # Block (Causal)
-    for start_n in range(max(block_split, 0), block_num):
-        block_off = tl.load(block_idx_ptr + start_n * stride_2in) * BLOCK_N
-
-        # -- load k, v --
-        k = tl.load(k_ptrs + block_off * stride_kn + offs_n[:, None] * stride_kn)
-        v = tl.load(v_ptrs + block_off * stride_vn + offs_n[:, None] * stride_vn)
-
-        # Compute qk
-        qk = tl.where(offs_m[:, None] >= offs_n[None, :] + block_off, float(0.), float("-inf"))
-        qk = qk + tl.dot(q, tl.trans(k))
-        qk = qk * qk_scale
-        p = tl.math.exp2(qk - l_i[:, None])
-
-        # compute dv
-        dv_vals = tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do).to(tl.float32)
-        tl.atomic_add(dv_ptrs + block_off * stride_dvn + offs_n[:, None] * stride_dvn, dv_vals, sem="relaxed")
-
-        # compute dp = dot(v, do)
-        dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - d_i[:, None]
-        dp = dp + tl.dot(do, tl.trans(v))
-
-        # compute ds = p * (dp - delta[:, None])
-        ds = p * dp * sm_scale
-
-        # compute dk = dot(ds.T, q)
-        dk_vals = tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q).to(tl.float32)
-        tl.atomic_add(dk_ptrs + block_off * stride_dkn + offs_n[:, None] * stride_dkn, dk_vals, sem="relaxed")
-
-        # compute dq
-        dq = dq + tl.dot(ds.to(Q.dtype.element_ty), k)
-
-    # Bar
-    for start_n in range(bar_l, bar_r, BLOCK_N):
-        n_mask = start_n + offs_n < bar_r
-        cols = tl.load(bar_idx_ptr + (start_n + offs_n) * stride_1in, mask=n_mask, other=0)
-
-        # -- load k, v --
-        k = tl.load(k_ptrs + cols[:, None] * stride_kn)
-        v = tl.load(v_ptrs + cols[:, None] * stride_vn)
-
-        # Compute qk
-        qk = tl.where(n_mask[None, :], float(0.), float("-inf"))
-        qk = qk + tl.dot(q, tl.trans(k))
-        qk = qk * qk_scale
-        p = tl.math.exp2(qk - l_i[:, None])
-
-        # compute dv
-        dv_vals = tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do).to(tl.float32)
-        tl.atomic_add(dv_ptrs + cols[:, None] * stride_dvn, dv_vals, mask=n_mask[:, None], sem="relaxed")
-
-        # compute dp = dot(v, do)
-        dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - d_i[:, None]
-        dp = dp + tl.dot(do, tl.trans(v))
-
-        # compute ds = p * (dp - delta[:, None])
-        ds = p * dp * sm_scale
-
-        # compute dk = dot(ds.T, q)
-        dk_vals = tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q).to(tl.float32)
-        tl.atomic_add(dk_ptrs + cols[:, None] * stride_dkn, dk_vals, mask=n_mask[:, None], sem="relaxed")
-
-        # compute dq
-        dq = dq + tl.dot(ds.to(Q.dtype.element_ty), k)
-
-    dq_old = tl.load(dq_ptrs).to(tl.float32)
-    tl.store(dq_ptrs, (dq_old + dq).to(DQ.dtype.element_ty))
-
-
-def block_bar_attn_bwd(
-    grad: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim]
-    q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim]
-    k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim]
-    v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim]
-    o: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim]
-    dq: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim]
-    dk: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim]
-    dv: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim]
-    softmax_lse: torch.Tensor, # [batch_size, num_qo_heads, num_tokens]
-    softmax_scale: float,
-    bar_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, max_v_size]
-    bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1]
-    block_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, num_blocks]
-    block_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks]
-    granularity: int,
-    deterministic: bool,
-    step: int = 0,
-    causal: bool = True,
-):
-    assert not deterministic
-    batch_size, num_tokens, num_qo_heads, head_dim = q.shape
-    num_kv_heads = k.shape[2]
-    num_blocks = bar_idx.shape[2]
-    dq = torch.zeros_like(q) if dq is None else dq
-    dk = torch.zeros_like(k, dtype=torch.float32) if dk is None else dk.to(torch.float32)
-    dv = torch.zeros_like(v, dtype=torch.float32) if dv is None else dv.to(torch.float32)
-    _triton_block_bar_attn_bwd_kernel[(num_blocks, num_qo_heads, batch_size)](
-        q, k, v, o, dq, dk, dv, grad, softmax_scale,
-        bar_cnt, bar_idx, block_cnt, block_idx, softmax_lse,
-        q.stride(0), q.stride(2), q.stride(1), q.stride(3),
-        k.stride(0), k.stride(2), k.stride(1), k.stride(3),
-        v.stride(0), v.stride(2), v.stride(1), v.stride(3),
-        o.stride(0), o.stride(2), o.stride(1), o.stride(3),
-        dq.stride(0), dq.stride(2), dq.stride(1), dq.stride(3),
-        dk.stride(0), dk.stride(2), dk.stride(1), dk.stride(3),
-        dv.stride(0), dv.stride(2), dv.stride(1), dv.stride(3),
-        grad.stride(0), grad.stride(2), grad.stride(1), grad.stride(3),
-        bar_cnt.stride(0), bar_cnt.stride(1), bar_cnt.stride(2), bar_cnt.stride(3),
-        bar_idx.stride(0), bar_idx.stride(1), bar_idx.stride(2), bar_idx.stride(3),
-        block_cnt.stride(0), block_cnt.stride(1), block_cnt.stride(2),
-        block_idx.stride(0), block_idx.stride(1), block_idx.stride(2), block_idx.stride(3),
-        softmax_lse.stride(0), softmax_lse.stride(1), softmax_lse.stride(2),
-        step, num_qo_heads, num_kv_heads, num_tokens,
-        BLOCK_M=granularity, BLOCK_N=64, BLOCK_DMODEL=head_dim, CAUSAL=causal,
-        num_warps=4, num_stages=2,
-    )
-    return dq, dk.to(dq.dtype), dv.to(dq.dtype)
-
-
-class MInferenceAttnTritonFunc(torch.autograd.Function):
-    @staticmethod
-    def forward(
-        ctx,
-        q,
-        k,
-        v,
-        v_size,
-        s_size,
-        softmax_scale,
-        granularity,
-        return_softmax,
-        deterministic,
-    ):
-        batch_size, num_tokens, num_qo_heads, head_dim = q.shape
-        if softmax_scale is None:
-            softmax_scale = head_dim ** (-0.5)
-
-        block_mask, bar_idx, bar_cnt = build_index_local(q, k, v_size, s_size, num_tokens, granularity)
-        block_idx, block_cnt = convert_blockmask(block_mask, block_size_M=granularity, block_size_N=64)
-
-        out, softmax_lse = block_bar_attn_fwd(
-            q, k, v, None, None, softmax_scale,
-            bar_idx, bar_cnt, block_idx, block_cnt,
-            granularity=granularity,
-            step=0,
-        )
-
-        ctx.save_for_backward(q, k, v, out, softmax_lse, block_idx, block_cnt, bar_idx, bar_cnt)
-        ctx.granularity = granularity
-        ctx.deterministic = deterministic
-        ctx.softmax_scale = softmax_scale
-        return (out, softmax_lse, None) if return_softmax else out
-
-    @staticmethod
-    def backward(ctx, dout, *args):
-        q, k, v, out, softmax_lse, block_idx, block_cnt, bar_idx, bar_cnt = ctx.saved_tensors
-
-        # Bar Mask
-        dq, dk, dv = block_bar_attn_bwd(
-            dout, q, k, v, out, None, None, None,
-            softmax_lse, ctx.softmax_scale,
-            bar_idx, bar_cnt, block_idx, block_cnt,
-            granularity=ctx.granularity,
-            deterministic=ctx.deterministic,
-            step=0,
-        )
-
-        return dq, dk, dv, None, None, None, None, None, None
-
-def minference_flash_attn_triton_qkvpacked_func(
-    qkv: torch.Tensor, # [batch_size, num_tokens, 3, num_heads, head_dim]
-    v_size: List[int], # [num_heads]
-    s_size: List[int], # [num_heads]
-    dropout_p: int = 0.0,
-    softmax_scale: float = None,
-    granularity: int = 128,
-    causal: bool = True,
-    window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window
-    alibi_slopes: Tuple[float, float] = None,
-    deterministic: bool = False,
-    return_attn_probs: bool = False,
-):
-    assert dropout_p == 0
-    assert causal
-    assert window_size == (-1, -1)
-    assert alibi_slopes is None
-    return MInferenceAttnTritonFunc.apply(
-        qkv[:, :, 0],
-        qkv[:, :, 1],
-        qkv[:, :, 2],
-        v_size,
-        s_size,
-        softmax_scale,
-        granularity,
-        return_attn_probs,
-        deterministic,
-    )
-
-
-def minference_flash_attn_triton_kvpacked_func(
-    q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim]
-    kv: torch.Tensor, # [batch_size, num_tokens, 2, num_kv_heads, head_dim]
-    v_size: List[int], # [num_qo_heads]
-    s_size: List[int], # [num_qo_heads]
-    dropout_p: int = 0.0,
-    softmax_scale: float = None,
-    granularity: int = 128,
-    causal: bool = True,
-    window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window
-    alibi_slopes: Tuple[float, float] = None,
-    deterministic: bool = False,
-    return_attn_probs: bool = False,
-):
-    assert dropout_p == 0
-    assert causal
-    assert window_size == (-1, -1)
-    assert alibi_slopes is None
-    return MInferenceAttnTritonFunc.apply(
-        q,
-        kv[:, :, 0],
-        kv[:, :, 1],
-        v_size,
-        s_size,
-        softmax_scale,
-        granularity,
-        return_attn_probs,
-        deterministic,
-    )
-
-
-def minference_flash_attn_triton_func(
-    q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim]
-    k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim]
-    v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim]
-    v_size: List[int], # [num_qo_heads]
-    s_size: List[int], # [num_qo_heads]
-    dropout_p: int = 0.0,
-    softmax_scale: float = None,
-    granularity: int = 128,
-    causal: bool = True,
-    window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window
-    alibi_slopes: Tuple[float, float] = None,
-    deterministic: bool = False,
-    return_attn_probs: bool = False,
-):
-    assert dropout_p == 0
-    assert causal
-    assert window_size == (-1, -1)
-    assert alibi_slopes is None
-    return MInferenceAttnTritonFunc.apply(
-        q,
-        k,
-        v,
-        v_size,
-        s_size,
-        softmax_scale,
-        granularity,
-        return_attn_probs,
-        deterministic,
-    )
-
-
-def _torch_sparse_attn_func(
-    q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim]
-    k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim]
-    v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim]
-    v_size: List[int], # [num_qo_heads]
-    s_size: List[int], # [num_qo_heads]
-    dropout_p: int = 0.0,
-    softmax_scale: float = None,
-    granularity: int = 128,
-    causal: bool = True,
-    window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window
-    alibi_slopes: Tuple[float, float] = None,
-    deterministic: bool = False,
-    return_attn_probs: bool = False,
-):
-    assert dropout_p == 0
-    assert causal
-    assert window_size == (-1, -1)
-    assert alibi_slopes is None
-    assert not deterministic
-    batch_size, num_tokens, num_qo_heads, head_dim = q.shape
-    num_kv_heads = k.shape[2]
-    group_size = num_qo_heads // num_kv_heads
-    softmax_scale = head_dim ** -0.5 if softmax_scale is None else softmax_scale
-    mask = _build_mask_local(q, k, v_size, s_size, num_tokens, granularity)
-    k = k.repeat_interleave(group_size, dim=2)
-    v = v.repeat_interleave(group_size, dim=2)
-    p = torch.einsum('bmhd, bnhd -> bhmn', q * softmax_scale, k)
-    p = torch.where(mask, p, -torch.inf).to(torch.float32)
-    m = torch.max(p, dim=-1, keepdim=True).values.to(torch.float32)
-    p = torch.exp(p - m)
-    l = torch.sum(p, dim=-1, keepdim=True)
-    p = (p / l).to(q.dtype)
-    o = torch.einsum('bhmn, bnhd -> bmhd', p, v)
-    o = o.reshape((batch_size, num_tokens, num_qo_heads, head_dim))
-    if return_attn_probs:
-        lse = m + l.log()
-        return o, lse.squeeze(-1), None
-    return o
-
-
-def _torch_sparse_attn_qkvpacked_func(
-    qkv: torch.Tensor, # [batch_size, num_tokens, 3, num_heads, head_dim]
-    v_size: List[int], # [num_heads]
-    s_size: List[int], # [num_heads]
-    dropout_p: int = 0.0,
-    softmax_scale: float = None,
-    granularity: int = 128,
-    causal: bool = True,
-    window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window
-    alibi_slopes: Tuple[float, float] = None,
-    deterministic: bool = False,
-    return_attn_probs: bool = False,
-):
-    return _torch_sparse_attn_func(
-        qkv[:, :, 0],
-        qkv[:, :, 1],
-        qkv[:, :, 2],
-        v_size,
-        s_size,
-        dropout_p,
-        softmax_scale,
-        granularity,
-        causal,
-        window_size,
-        alibi_slopes,
-        deterministic,
-        return_attn_probs,
-    )
-
-
-def _torch_sparse_attn_kvpacked_func(
-    q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim]
-    kv: torch.Tensor, # [batch_size, num_tokens, 2, num_kv_heads, head_dim]
-    v_size: List[int], # [num_qo_heads]
-    s_size: List[int], # [num_qo_heads]
-    dropout_p: int = 0.0,
-    softmax_scale: float = None,
-    granularity: int = 128,
-    causal: bool = True,
-    window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window
-    alibi_slopes: Tuple[float, float] = None,
-    deterministic: bool = False,
-    return_attn_probs: bool = False,
-):
-    return _torch_sparse_attn_func(
-        q,
-        kv[:, :, 0],
-        kv[:, :, 1],
-        v_size,
-        s_size,
-        dropout_p,
-        softmax_scale,
-        granularity,
-        causal,
-        window_size,
-        alibi_slopes,
-        deterministic,
-        return_attn_probs,
-    )
-
diff --git a/mtraining/attn_funcs/minfer_func.py b/mtraining/attn_funcs/minfer_func.py
index 32cac5c..c7025c6 100644
--- a/mtraining/attn_funcs/minfer_func.py
+++ b/mtraining/attn_funcs/minfer_func.py
@@ -19,11 +19,11 @@
 from nnscaler.ir.operator import IRFwOperation
 from minference.ops.utils import use_triton
-from minference.ops.pit_sparse_flash_attention_v3 import minference_flash_attn_func
-from minference.ops.pit_sparse_flash_attention_v3_triton import minference_flash_attn_triton_func
+from minference.ops.pit_sparse_flash_attention_v3 import (
+    minference_flash_attn_func, minference_flash_attn_triton_func
+)
 from minference.dist_ops import (
-    minfer_stripe_func, minfer_stripe_triton_func,
-    minfer_zigzag_func, minfer_dr_stripe_func, minfer_dr_stripe_triton_func,
+    minfer_stripe_func, minfer_zigzag_func, minfer_dr_stripe_func,
 )
@@ -88,8 +88,6 @@ def minfer_stripe_op(
     q_len: int,
     head_dim: int,
     layer_idx: int,
-
-
     pattern_dict: Dict[int, Tuple[str, int, int, int]],
     attn_dropout: float=0.,
     granularity: int = 128,
@@ -108,38 +106,22 @@
     v_sizes = [pattern_dict[head_indices[idx].item()][1] for idx in range(query_states.size(1))]
     s_sizes = [pattern_dict[head_indices[idx].item()][2] for idx in range(query_states.size(1))]
-    if not use_triton():
-        attn_output = minfer_stripe_func(
-            query_states.transpose(1, 2).contiguous(),
-            key_states.transpose(1, 2).contiguous(),
-            value_states.transpose(1, 2).contiguous(),
-            v_sizes, s_sizes,
-            layer_idx,
-            attn_dropout,
-            softmax_scale=None,
-            granularity=granularity,
-            causal=True,
-            window_size=(-1, -1),
-            deterministic=False,
-            return_attn_probs=False,
-            group=group,
-        ) # expect: b {q_anno} l^ vd^'
-    else:
-        attn_output = minfer_stripe_triton_func(
-            query_states.transpose(1, 2).contiguous(),
-            key_states.transpose(1, 2).contiguous(),
-            value_states.transpose(1, 2).contiguous(),
-            v_sizes, s_sizes,
-            layer_idx,
-            attn_dropout,
-            softmax_scale=None,
-            granularity=granularity,
-            causal=True,
-            window_size=(-1, -1),
-            deterministic=False,
-            return_attn_probs=False,
-            group=group,
-        )
+    attn_output = minfer_stripe_func(
+        query_states.transpose(1, 2).contiguous(),
+        key_states.transpose(1, 2).contiguous(),
+        value_states.transpose(1, 2).contiguous(),
+        v_sizes, s_sizes,
+        layer_idx,
+        attn_dropout,
+        softmax_scale=None,
+        granularity=granularity,
+        causal=True,
+        window_size=(-1, -1),
+        deterministic=False,
+        return_attn_probs=False,
+        group=group,
+    ) # expect: b {q_anno} l^ vd^'
+
     return attn_output.contiguous()
@@ -224,39 +206,21 @@ def minfer_dr_stripe_op(
     v_sizes = [pattern_dict[head_indices[idx].item()][1] for idx in range(query_states.size(1))]
     s_sizes = [pattern_dict[head_indices[idx].item()][2] for idx in range(query_states.size(1))]
-    if not use_triton():
-        attn_output = minfer_dr_stripe_func(
-            query_states.transpose(1, 2).contiguous(),
-            key_states.transpose(1, 2).contiguous(),
-            value_states.transpose(1, 2).contiguous(),
-            v_sizes, s_sizes,
-            layer_idx,
-            attn_dropout,
-            softmax_scale=None,
-            granularity=granularity,
-            causal=True,
-            window_size=(-1, -1),
-            deterministic=False,
-            return_attn_probs=False,
-            group=group,
-        ) # expect: b {q_anno} l^ vd^'
-    else:
-        attn_output = minfer_dr_stripe_triton_func(
-            query_states.transpose(1, 2).contiguous(),
-            key_states.transpose(1, 2).contiguous(),
-            value_states.transpose(1, 2).contiguous(),
-            v_sizes, s_sizes,
-            layer_idx,
-            attn_dropout,
-            softmax_scale=None,
-            granularity=granularity,
-            causal=True,
-            window_size=(-1, -1),
-            deterministic=False,
-            return_attn_probs=False,
-            group=group,
-        )
-
+    attn_output = minfer_dr_stripe_func(
+        query_states.transpose(1, 2).contiguous(),
+        key_states.transpose(1, 2).contiguous(),
+        value_states.transpose(1, 2).contiguous(),
+        v_sizes, s_sizes,
+        layer_idx,
+        attn_dropout,
+        softmax_scale=None,
+        granularity=granularity,
+        causal=True,
+        window_size=(-1, -1),
+        deterministic=False,
+        return_attn_probs=False,
+        group=group,
+    ) # expect: b {q_anno} l^ vd^'
     return attn_output.contiguous()
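Note: after this change, minfer_stripe_op and minfer_dr_stripe_op always route through the non-triton dist_ops entry points. The snippet below is a minimal illustrative sketch of the resulting call shape, not part of the patch itself; names such as query_states, key_states, value_states, v_sizes, s_sizes, layer_idx, attn_dropout, granularity, and group are assumed to be in scope exactly as in minfer_stripe_op above.

    # Illustrative sketch only -- mirrors the consolidated call added in the hunk above.
    from minference.dist_ops import minfer_stripe_func

    # query/key/value arrive as [batch, num_heads, seq_len, head_dim];
    # the op takes [batch, seq_len, num_heads, head_dim], hence transpose(1, 2).
    attn_output = minfer_stripe_func(
        query_states.transpose(1, 2).contiguous(),
        key_states.transpose(1, 2).contiguous(),
        value_states.transpose(1, 2).contiguous(),
        v_sizes, s_sizes,        # per-head vertical / slash sizes taken from pattern_dict
        layer_idx,
        attn_dropout,
        softmax_scale=None,
        granularity=granularity,
        causal=True,
        window_size=(-1, -1),    # -1 means infinite context window
        deterministic=False,
        return_attn_probs=False,
        group=group,             # process group used by the distributed striped attention
    )
    attn_output = attn_output.contiguous()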