diff --git a/.gitignore b/.gitignore index 16033cd..ae5d0fc 100644 --- a/.gitignore +++ b/.gitignore @@ -415,3 +415,4 @@ build/ *.egg-info/ *.so dist +*.eggs/ \ No newline at end of file diff --git a/minference/configs/Phi-3-mini-4k-instruct-LongRoPE-128k.json b/minference/configs/Phi-3-mini-4k-instruct-LongRoPE-128k.json new file mode 100644 index 0000000..deab2fa --- /dev/null +++ b/minference/configs/Phi-3-mini-4k-instruct-LongRoPE-128k.json @@ -0,0 +1,6210 @@ +[ + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.3332791030406952 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.4416683614253998 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.4923180937767029 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.3880477547645569 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.38015398383140564 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.34974828362464905 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.4002125859260559 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.8959001898765564 + ], + "8": [ + "vertical_and_slash", + 1000, + 6096, + 0.44538143277168274 + ], + "9": [ + "vertical_and_slash", + 1000, + 6096, + 0.364785760641098 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.4276016652584076 + ], + "11": [ + "vertical_and_slash", + 1000, + 6096, + 0.3688332438468933 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.46328139305114746 + ], + "13": [ + "vertical_and_slash", + 1000, + 6096, + 0.3348214626312256 + ], + "14": [ + "vertical_and_slash", + 1000, + 6096, + 0.392171710729599 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.34772568941116333 + ], + "16": [ + "vertical_and_slash", + 1000, + 6096, + 0.4066217541694641 + ], + "17": [ + "vertical_and_slash", + 1000, + 6096, + 0.360747754573822 + ], + "18": [ + "vertical_and_slash", + 1000, + 6096, + 0.34830111265182495 + ], + "19": [ + "vertical_and_slash", + 1000, + 6096, + 0.47614070773124695 + ], + "20": [ + "vertical_and_slash", + 1000, + 6096, + 0.47739607095718384 + ], + "21": [ + "vertical_and_slash", + 1000, + 6096, + 0.376796156167984 + ], + "22": [ + "vertical_and_slash", + 1000, + 6096, + 0.5052061080932617 + ], + "23": [ + "vertical_and_slash", + 1000, + 6096, + 0.7984429001808167 + ], + "24": [ + "vertical_and_slash", + 1000, + 6096, + 0.4089720547199249 + ], + "25": [ + "vertical_and_slash", + 1000, + 6096, + 0.512876033782959 + ], + "26": [ + "vertical_and_slash", + 1000, + 6096, + 0.45735129714012146 + ], + "27": [ + "vertical_and_slash", + 1000, + 6096, + 0.8220791220664978 + ], + "28": [ + "vertical_and_slash", + 1000, + 6096, + 0.9416881203651428 + ], + "29": [ + "vertical_and_slash", + 1000, + 6096, + 0.4253023564815521 + ], + "30": [ + "vertical_and_slash", + 1000, + 6096, + 0.6658170819282532 + ], + "31": [ + "vertical_and_slash", + 1000, + 6096, + 0.5853910446166992 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.5987299084663391 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.5018280148506165 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.4191092550754547 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.5811007618904114 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.5408463478088379 + ], + "5": [ + "vertical_and_slash", + 3500, + 100, + 0.47298333048820496 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.786112368106842 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.6012060642242432 + ], + "8": [ + "vertical_and_slash", + 1000, + 
6096, + 0.5791159272193909 + ], + "9": [ + "vertical_and_slash", + 1000, + 6096, + 0.9051483869552612 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.58222496509552 + ], + "11": [ + "vertical_and_slash", + 1000, + 6096, + 0.48059725761413574 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.7023431062698364 + ], + "13": [ + "vertical_and_slash", + 1000, + 6096, + 0.43349939584732056 + ], + "14": [ + "vertical_and_slash", + 1000, + 6096, + 0.5206228494644165 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.6523349285125732 + ], + "16": [ + "vertical_and_slash", + 1000, + 6096, + 0.605652391910553 + ], + "17": [ + "vertical_and_slash", + 1000, + 6096, + 0.4064832925796509 + ], + "18": [ + "vertical_and_slash", + 1000, + 6096, + 0.5884730219841003 + ], + "19": [ + "vertical_and_slash", + 1000, + 6096, + 0.8767980337142944 + ], + "20": [ + "vertical_and_slash", + 1000, + 6096, + 0.9045116305351257 + ], + "21": [ + "vertical_and_slash", + 1000, + 6096, + 0.7437348365783691 + ], + "22": [ + "vertical_and_slash", + 1000, + 6096, + 0.48772233724594116 + ], + "23": [ + "vertical_and_slash", + 1000, + 6096, + 0.46409618854522705 + ], + "24": [ + "vertical_and_slash", + 1000, + 6096, + 0.49005016684532166 + ], + "25": [ + "vertical_and_slash", + 1000, + 6096, + 0.6720733046531677 + ], + "26": [ + "vertical_and_slash", + 1000, + 6096, + 0.9864497184753418 + ], + "27": [ + "vertical_and_slash", + 1000, + 6096, + 0.8722768425941467 + ], + "28": [ + "vertical_and_slash", + 1000, + 6096, + 0.4877939820289612 + ], + "29": [ + "vertical_and_slash", + 1000, + 6096, + 0.575534462928772 + ], + "30": [ + "vertical_and_slash", + 1000, + 6096, + 0.42707663774490356 + ], + "31": [ + "vertical_and_slash", + 1000, + 6096, + 0.6121442317962646 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.7372177839279175 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.7894041538238525 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.5279066562652588 + ], + "3": [ + "vertical_and_slash", + 3500, + 100, + 0.4761664569377899 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.4725611209869385 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.8257285952568054 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.5990859866142273 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.6145595908164978 + ], + "8": [ + "vertical_and_slash", + 1000, + 6096, + 0.7928207516670227 + ], + "9": [ + "vertical_and_slash", + 1000, + 6096, + 0.707308292388916 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.7120367884635925 + ], + "11": [ + "vertical_and_slash", + 1000, + 6096, + 0.687991201877594 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.673988401889801 + ], + "13": [ + "vertical_and_slash", + 1000, + 6096, + 0.8554307222366333 + ], + "14": [ + "vertical_and_slash", + 1000, + 6096, + 0.755358874797821 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.7254285216331482 + ], + "16": [ + "vertical_and_slash", + 1000, + 6096, + 0.7025570273399353 + ], + "17": [ + "vertical_and_slash", + 1000, + 6096, + 0.9539948105812073 + ], + "18": [ + "vertical_and_slash", + 1000, + 6096, + 0.6157228946685791 + ], + "19": [ + "vertical_and_slash", + 1000, + 6096, + 0.733711838722229 + ], + "20": [ + "vertical_and_slash", + 1000, + 6096, + 0.5779848694801331 + ], + "21": [ + "vertical_and_slash", + 1000, + 6096, + 0.6323117017745972 + ], + "22": [ + "vertical_and_slash", + 1000, + 6096, + 0.5607789754867554 + ], + "23": [ + "vertical_and_slash", 
+ 1000, + 6096, + 0.6604608297348022 + ], + "24": [ + "vertical_and_slash", + 1000, + 6096, + 0.8311918377876282 + ], + "25": [ + "vertical_and_slash", + 1000, + 6096, + 0.5358595848083496 + ], + "26": [ + "vertical_and_slash", + 1000, + 6096, + 0.8005751967430115 + ], + "27": [ + "vertical_and_slash", + 1000, + 6096, + 0.5391202569007874 + ], + "28": [ + "vertical_and_slash", + 1000, + 6096, + 0.7308750152587891 + ], + "29": [ + "vertical_and_slash", + 1000, + 6096, + 0.6740477085113525 + ], + "30": [ + "vertical_and_slash", + 1000, + 6096, + 0.5843774676322937 + ], + "31": [ + "vertical_and_slash", + 1000, + 6096, + 0.9640750885009766 + ] + }, + { + "0": [ + "vertical_and_slash", + 100, + 800, + 1.0 + ], + "1": [ + "vertical_and_slash", + 100, + 800, + 1.0 + ], + "2": [ + "vertical_and_slash", + 3500, + 100, + 0.9977744221687317 + ], + "3": [ + "vertical_and_slash", + 100, + 800, + 1.0 + ], + "4": [ + "vertical_and_slash", + 100, + 800, + 1.0 + ], + "5": [ + "vertical_and_slash", + 30, + 800, + 0.997894287109375 + ], + "6": [ + "vertical_and_slash", + 100, + 800, + 1.0 + ], + "7": [ + "vertical_and_slash", + 500, + 700, + 0.9970235824584961 + ], + "8": [ + "vertical_and_slash", + 100, + 800, + 0.9453125 + ], + "9": [ + "vertical_and_slash", + 30, + 800, + 0.9917697310447693 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.9911014437675476 + ], + "11": [ + "vertical_and_slash", + 30, + 800, + 0.9628551602363586 + ], + "12": [ + "vertical_and_slash", + 30, + 800, + 0.9096153378486633 + ], + "13": [ + "vertical_and_slash", + 30, + 800, + 0.9951955080032349 + ], + "14": [ + "vertical_and_slash", + 500, + 700, + 0.9851846098899841 + ], + "15": [ + "vertical_and_slash", + 30, + 800, + 0.996168315410614 + ], + "16": [ + "vertical_and_slash", + 100, + 800, + 0.96484375 + ], + "17": [ + "vertical_and_slash", + 1000, + 6096, + 0.9974663853645325 + ], + "18": [ + "vertical_and_slash", + 1000, + 6096, + 0.995905339717865 + ], + "19": [ + "vertical_and_slash", + 30, + 800, + 0.998296320438385 + ], + "20": [ + "vertical_and_slash", + 500, + 700, + 0.9968391060829163 + ], + "21": [ + "vertical_and_slash", + 3500, + 100, + 0.9973934888839722 + ], + "22": [ + "vertical_and_slash", + 30, + 800, + 0.998639702796936 + ], + "23": [ + "vertical_and_slash", + 30, + 800, + 0.9961214661598206 + ], + "24": [ + "vertical_and_slash", + 100, + 800, + 0.97265625 + ], + "25": [ + "vertical_and_slash", + 30, + 800, + 0.9963303804397583 + ], + "26": [ + "vertical_and_slash", + 100, + 800, + 0.96484375 + ], + "27": [ + "vertical_and_slash", + 100, + 800, + 1.0 + ], + "28": [ + "vertical_and_slash", + 30, + 800, + 0.9944915175437927 + ], + "29": [ + "vertical_and_slash", + 100, + 800, + 0.9453125 + ], + "30": [ + "vertical_and_slash", + 3500, + 100, + 0.7727152705192566 + ], + "31": [ + "vertical_and_slash", + 3500, + 100, + 0.9912976622581482 + ] + }, + { + "0": [ + "vertical_and_slash", + 100, + 800, + 0.9765625 + ], + "1": [ + "vertical_and_slash", + 3500, + 100, + 0.9993456602096558 + ], + "2": [ + "vertical_and_slash", + 3500, + 100, + 1.0000362396240234 + ], + "3": [ + "vertical_and_slash", + 500, + 700, + 0.9966914057731628 + ], + "4": [ + "vertical_and_slash", + 30, + 800, + 0.8796641826629639 + ], + "5": [ + "vertical_and_slash", + 100, + 800, + 1.0 + ], + "6": [ + "vertical_and_slash", + 100, + 800, + 1.0 + ], + "7": [ + "vertical_and_slash", + 3500, + 100, + 0.9982466697692871 + ], + "8": [ + "vertical_and_slash", + 100, + 800, + 0.96875 + ], + "9": [ + "vertical_and_slash", + 100, + 800, + 1.0 + ], + 
"10": [ + "vertical_and_slash", + 30, + 800, + 0.9937134385108948 + ], + "11": [ + "vertical_and_slash", + 100, + 800, + 1.0 + ], + "12": [ + "vertical_and_slash", + 3500, + 100, + 0.9967017769813538 + ], + "13": [ + "vertical_and_slash", + 3500, + 100, + 0.9959616661071777 + ], + "14": [ + "vertical_and_slash", + 30, + 800, + 0.9792157411575317 + ], + "15": [ + "vertical_and_slash", + 100, + 800, + 1.0 + ], + "16": [ + "vertical_and_slash", + 3500, + 100, + 0.9846665859222412 + ], + "17": [ + "vertical_and_slash", + 1000, + 6096, + 0.9296243190765381 + ], + "18": [ + "vertical_and_slash", + 30, + 800, + 0.9900624752044678 + ], + "19": [ + "vertical_and_slash", + 100, + 800, + 1.0 + ], + "20": [ + "vertical_and_slash", + 30, + 800, + 0.9983267188072205 + ], + "21": [ + "vertical_and_slash", + 100, + 800, + 1.0 + ], + "22": [ + "vertical_and_slash", + 1000, + 6096, + 0.9900224804878235 + ], + "23": [ + "vertical_and_slash", + 3500, + 100, + 0.997448205947876 + ], + "24": [ + "vertical_and_slash", + 30, + 800, + 0.9974275827407837 + ], + "25": [ + "vertical_and_slash", + 1000, + 6096, + 0.9921515583992004 + ], + "26": [ + "vertical_and_slash", + 30, + 800, + 0.9961837530136108 + ], + "27": [ + "vertical_and_slash", + 100, + 800, + 0.99609375 + ], + "28": [ + "vertical_and_slash", + 3500, + 100, + 0.9986920952796936 + ], + "29": [ + "vertical_and_slash", + 3500, + 100, + 0.9839296936988831 + ], + "30": [ + "vertical_and_slash", + 3500, + 100, + 0.9926077127456665 + ], + "31": [ + "vertical_and_slash", + 100, + 800, + 1.0 + ] + }, + { + "0": [ + "vertical_and_slash", + 100, + 800, + 0.99609375 + ], + "1": [ + "vertical_and_slash", + 500, + 700, + 0.9920685291290283 + ], + "2": [ + "vertical_and_slash", + 500, + 700, + 0.9854402542114258 + ], + "3": [ + "vertical_and_slash", + 3500, + 100, + 0.9949761033058167 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.8777214288711548 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.9941083192825317 + ], + "6": [ + "vertical_and_slash", + 500, + 700, + 0.9981333613395691 + ], + "7": [ + "vertical_and_slash", + 500, + 700, + 0.9964705109596252 + ], + "8": [ + "vertical_and_slash", + 3500, + 100, + 0.9998161196708679 + ], + "9": [ + "vertical_and_slash", + 3500, + 100, + 0.9989191293716431 + ], + "10": [ + "vertical_and_slash", + 3500, + 100, + 0.9850229620933533 + ], + "11": [ + "vertical_and_slash", + 500, + 700, + 0.9929664731025696 + ], + "12": [ + "vertical_and_slash", + 500, + 700, + 0.9970735311508179 + ], + "13": [ + "vertical_and_slash", + 1000, + 6096, + 0.9854613542556763 + ], + "14": [ + "vertical_and_slash", + 100, + 800, + 1.0 + ], + "15": [ + "vertical_and_slash", + 3500, + 100, + 0.998539924621582 + ], + "16": [ + "vertical_and_slash", + 3500, + 100, + 0.9988572597503662 + ], + "17": [ + "vertical_and_slash", + 100, + 800, + 1.0 + ], + "18": [ + "vertical_and_slash", + 100, + 800, + 1.0 + ], + "19": [ + "vertical_and_slash", + 3500, + 100, + 0.9937215447425842 + ], + "20": [ + "vertical_and_slash", + 3500, + 100, + 0.9997386336326599 + ], + "21": [ + "vertical_and_slash", + 500, + 700, + 0.9956527352333069 + ], + "22": [ + "vertical_and_slash", + 100, + 800, + 1.0 + ], + "23": [ + "vertical_and_slash", + 500, + 700, + 0.9980509877204895 + ], + "24": [ + "vertical_and_slash", + 500, + 700, + 0.9594733119010925 + ], + "25": [ + "vertical_and_slash", + 500, + 700, + 0.9973872900009155 + ], + "26": [ + "vertical_and_slash", + 100, + 800, + 1.0 + ], + "27": [ + "vertical_and_slash", + 1000, + 6096, + 0.8722200989723206 + ], + 
"28": [ + "vertical_and_slash", + 3500, + 100, + 0.9793213605880737 + ], + "29": [ + "vertical_and_slash", + 500, + 700, + 0.9947999119758606 + ], + "30": [ + "vertical_and_slash", + 500, + 700, + 0.999260425567627 + ], + "31": [ + "vertical_and_slash", + 3500, + 100, + 0.991969883441925 + ] + }, + { + "0": [ + "vertical_and_slash", + 30, + 800, + 0.9714643955230713 + ], + "1": [ + "vertical_and_slash", + 500, + 700, + 0.9889455437660217 + ], + "2": [ + "vertical_and_slash", + 3500, + 100, + 0.9993591904640198 + ], + "3": [ + "vertical_and_slash", + 100, + 800, + 1.0 + ], + "4": [ + "vertical_and_slash", + 500, + 700, + 0.9979987740516663 + ], + "5": [ + "vertical_and_slash", + 3500, + 100, + 0.9925215244293213 + ], + "6": [ + "vertical_and_slash", + 30, + 800, + 0.9974944591522217 + ], + "7": [ + "vertical_and_slash", + 100, + 800, + 1.0 + ], + "8": [ + "vertical_and_slash", + 1000, + 6096, + 0.733367919921875 + ], + "9": [ + "vertical_and_slash", + 500, + 700, + 0.9970583319664001 + ], + "10": [ + "vertical_and_slash", + 30, + 800, + 0.997306227684021 + ], + "11": [ + "vertical_and_slash", + 100, + 800, + 1.0 + ], + "12": [ + "vertical_and_slash", + 100, + 800, + 0.99609375 + ], + "13": [ + "vertical_and_slash", + 3500, + 100, + 0.9741608500480652 + ], + "14": [ + "vertical_and_slash", + 1000, + 6096, + 0.9867183566093445 + ], + "15": [ + "vertical_and_slash", + 500, + 700, + 0.9974902868270874 + ], + "16": [ + "vertical_and_slash", + 1000, + 6096, + 0.8923004865646362 + ], + "17": [ + "vertical_and_slash", + 30, + 800, + 0.9976084232330322 + ], + "18": [ + "vertical_and_slash", + 3500, + 100, + 0.9930179119110107 + ], + "19": [ + "vertical_and_slash", + 3500, + 100, + 0.9756211638450623 + ], + "20": [ + "vertical_and_slash", + 3500, + 100, + 0.9988991022109985 + ], + "21": [ + "vertical_and_slash", + 3500, + 100, + 0.9600447416305542 + ], + "22": [ + "vertical_and_slash", + 500, + 700, + 0.9966569542884827 + ], + "23": [ + "vertical_and_slash", + 1000, + 6096, + 0.9705502986907959 + ], + "24": [ + "vertical_and_slash", + 3500, + 100, + 0.996631383895874 + ], + "25": [ + "vertical_and_slash", + 1000, + 6096, + 0.7431671023368835 + ], + "26": [ + "vertical_and_slash", + 3500, + 100, + 0.9833155274391174 + ], + "27": [ + "vertical_and_slash", + 3500, + 100, + 0.995357096195221 + ], + "28": [ + "vertical_and_slash", + 3500, + 100, + 0.9988372921943665 + ], + "29": [ + "vertical_and_slash", + 1000, + 6096, + 0.6221311688423157 + ], + "30": [ + "vertical_and_slash", + 1000, + 6096, + 0.962378740310669 + ], + "31": [ + "vertical_and_slash", + 100, + 800, + 1.0 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.8013869524002075 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.9569835662841797 + ], + "2": [ + "vertical_and_slash", + 30, + 800, + 0.9765567779541016 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.9659812450408936 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.893607497215271 + ], + "5": [ + "vertical_and_slash", + 3500, + 100, + 0.9978155493736267 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.9035964608192444 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.9964177012443542 + ], + "8": [ + "vertical_and_slash", + 500, + 700, + 0.9885393381118774 + ], + "9": [ + "vertical_and_slash", + 3500, + 100, + 0.9989112019538879 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.9980699419975281 + ], + "11": [ + "vertical_and_slash", + 1000, + 6096, + 0.9978247284889221 + ], + "12": [ + "vertical_and_slash", + 1000, + 
6096, + 0.9864377379417419 + ], + "13": [ + "vertical_and_slash", + 30, + 800, + 0.9934346675872803 + ], + "14": [ + "vertical_and_slash", + 1000, + 6096, + 0.8716491460800171 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.9965083003044128 + ], + "16": [ + "vertical_and_slash", + 1000, + 6096, + 0.9953898191452026 + ], + "17": [ + "vertical_and_slash", + 30, + 800, + 0.906114399433136 + ], + "18": [ + "vertical_and_slash", + 1000, + 6096, + 0.9365297555923462 + ], + "19": [ + "vertical_and_slash", + 30, + 800, + 0.9918379783630371 + ], + "20": [ + "vertical_and_slash", + 500, + 700, + 0.998047411441803 + ], + "21": [ + "vertical_and_slash", + 500, + 700, + 0.9964088797569275 + ], + "22": [ + "vertical_and_slash", + 1000, + 6096, + 0.9905833601951599 + ], + "23": [ + "vertical_and_slash", + 500, + 700, + 0.9906750917434692 + ], + "24": [ + "vertical_and_slash", + 100, + 800, + 1.0 + ], + "25": [ + "vertical_and_slash", + 500, + 700, + 0.9951276183128357 + ], + "26": [ + "vertical_and_slash", + 100, + 800, + 0.99609375 + ], + "27": [ + "vertical_and_slash", + 500, + 700, + 0.9964181780815125 + ], + "28": [ + "vertical_and_slash", + 1000, + 6096, + 0.9454283118247986 + ], + "29": [ + "vertical_and_slash", + 500, + 700, + 0.98893141746521 + ], + "30": [ + "vertical_and_slash", + 3500, + 100, + 0.9951130747795105 + ], + "31": [ + "vertical_and_slash", + 1000, + 6096, + 0.997049868106842 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.706658124923706 + ], + "1": [ + "vertical_and_slash", + 100, + 750, + 0.9980204701423645 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.6844134330749512 + ], + "3": [ + "vertical_and_slash", + 500, + 700, + 0.9990592002868652 + ], + "4": [ + "vertical_and_slash", + 3500, + 100, + 0.8842496871948242 + ], + "5": [ + "vertical_and_slash", + 3500, + 100, + 0.9998780488967896 + ], + "6": [ + "vertical_and_slash", + 500, + 700, + 0.9981557130813599 + ], + "7": [ + "vertical_and_slash", + 500, + 700, + 0.9960853457450867 + ], + "8": [ + "vertical_and_slash", + 3500, + 100, + 0.9998636841773987 + ], + "9": [ + "vertical_and_slash", + 3500, + 100, + 0.9711626172065735 + ], + "10": [ + "vertical_and_slash", + 500, + 700, + 0.9983590245246887 + ], + "11": [ + "vertical_and_slash", + 3500, + 100, + 0.9998638033866882 + ], + "12": [ + "vertical_and_slash", + 100, + 750, + 0.8798794150352478 + ], + "13": [ + "vertical_and_slash", + 3500, + 100, + 0.9990912079811096 + ], + "14": [ + "vertical_and_slash", + 3500, + 100, + 0.9776530265808105 + ], + "15": [ + "vertical_and_slash", + 3500, + 100, + 0.9992070198059082 + ], + "16": [ + "vertical_and_slash", + 3500, + 100, + 0.999956488609314 + ], + "17": [ + "vertical_and_slash", + 3500, + 100, + 0.9974350929260254 + ], + "18": [ + "vertical_and_slash", + 3500, + 100, + 0.9929025173187256 + ], + "19": [ + "vertical_and_slash", + 3500, + 100, + 0.9999449253082275 + ], + "20": [ + "vertical_and_slash", + 3500, + 100, + 0.9973562359809875 + ], + "21": [ + "vertical_and_slash", + 3500, + 100, + 0.9999447464942932 + ], + "22": [ + "vertical_and_slash", + 500, + 700, + 0.9998283982276917 + ], + "23": [ + "vertical_and_slash", + 500, + 700, + 0.9996301531791687 + ], + "24": [ + "vertical_and_slash", + 500, + 700, + 0.99957674741745 + ], + "25": [ + "vertical_and_slash", + 1000, + 6096, + 0.8067362904548645 + ], + "26": [ + "vertical_and_slash", + 3500, + 100, + 0.9952501654624939 + ], + "27": [ + "vertical_and_slash", + 500, + 700, + 0.987522304058075 + ], + "28": [ + "vertical_and_slash", + 500, + 
700, + 0.9994516968727112 + ], + "29": [ + "vertical_and_slash", + 500, + 700, + 0.9991733431816101 + ], + "30": [ + "vertical_and_slash", + 500, + 700, + 0.9994316697120667 + ], + "31": [ + "vertical_and_slash", + 500, + 700, + 0.9992942214012146 + ] + }, + { + "0": [ + "vertical_and_slash", + 500, + 700, + 0.9960434436798096 + ], + "1": [ + "vertical_and_slash", + 500, + 700, + 0.9954621195793152 + ], + "2": [ + "vertical_and_slash", + 3500, + 100, + 0.9846372604370117 + ], + "3": [ + "vertical_and_slash", + 3500, + 100, + 0.866407036781311 + ], + "4": [ + "vertical_and_slash", + 500, + 700, + 0.9988003373146057 + ], + "5": [ + "vertical_and_slash", + 500, + 700, + 0.9943948984146118 + ], + "6": [ + "vertical_and_slash", + 500, + 700, + 0.9850127696990967 + ], + "7": [ + "vertical_and_slash", + 3500, + 100, + 0.9895413517951965 + ], + "8": [ + "vertical_and_slash", + 3500, + 100, + 0.8989573121070862 + ], + "9": [ + "vertical_and_slash", + 500, + 700, + 0.9728281497955322 + ], + "10": [ + "vertical_and_slash", + 3500, + 100, + 0.9660746455192566 + ], + "11": [ + "vertical_and_slash", + 3500, + 100, + 0.9935301542282104 + ], + "12": [ + "vertical_and_slash", + 500, + 700, + 0.9088181257247925 + ], + "13": [ + "vertical_and_slash", + 100, + 750, + 0.9997004270553589 + ], + "14": [ + "vertical_and_slash", + 500, + 700, + 0.9992449879646301 + ], + "15": [ + "vertical_and_slash", + 500, + 700, + 0.9945877194404602 + ], + "16": [ + "vertical_and_slash", + 500, + 700, + 0.9919179677963257 + ], + "17": [ + "vertical_and_slash", + 100, + 750, + 0.9887120127677917 + ], + "18": [ + "vertical_and_slash", + 3500, + 100, + 0.8854116797447205 + ], + "19": [ + "vertical_and_slash", + 3500, + 100, + 0.9977339506149292 + ], + "20": [ + "vertical_and_slash", + 500, + 700, + 0.9949434995651245 + ], + "21": [ + "vertical_and_slash", + 3500, + 100, + 0.9963594675064087 + ], + "22": [ + "vertical_and_slash", + 1000, + 6096, + 0.8057686686515808 + ], + "23": [ + "vertical_and_slash", + 1000, + 6096, + 0.7062434554100037 + ], + "24": [ + "vertical_and_slash", + 3500, + 100, + 0.714179515838623 + ], + "25": [ + "vertical_and_slash", + 3500, + 100, + 0.9638258814811707 + ], + "26": [ + "vertical_and_slash", + 500, + 700, + 0.9974403381347656 + ], + "27": [ + "vertical_and_slash", + 100, + 750, + 0.9426991939544678 + ], + "28": [ + "vertical_and_slash", + 500, + 700, + 0.9981192946434021 + ], + "29": [ + "vertical_and_slash", + 500, + 700, + 0.997894287109375 + ], + "30": [ + "vertical_and_slash", + 500, + 700, + 0.9979783892631531 + ], + "31": [ + "vertical_and_slash", + 3500, + 100, + 0.9956291317939758 + ] + }, + { + "0": [ + "vertical_and_slash", + 500, + 700, + 0.9942532777786255 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.9382870197296143 + ], + "2": [ + "vertical_and_slash", + 500, + 700, + 0.9994586706161499 + ], + "3": [ + "vertical_and_slash", + 500, + 700, + 0.9935740232467651 + ], + "4": [ + "vertical_and_slash", + 500, + 700, + 0.9681561589241028 + ], + "5": [ + "vertical_and_slash", + 100, + 750, + 0.7282125949859619 + ], + "6": [ + "vertical_and_slash", + 500, + 700, + 0.9943098425865173 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.6868635416030884 + ], + "8": [ + "vertical_and_slash", + 3500, + 100, + 0.9987292885780334 + ], + "9": [ + "vertical_and_slash", + 3500, + 100, + 0.8973380923271179 + ], + "10": [ + "vertical_and_slash", + 3500, + 100, + 0.995192289352417 + ], + "11": [ + "vertical_and_slash", + 1000, + 6096, + 0.6858950257301331 + ], + "12": [ + 
"vertical_and_slash", + 500, + 700, + 0.9975180625915527 + ], + "13": [ + "vertical_and_slash", + 3500, + 100, + 0.9925879836082458 + ], + "14": [ + "vertical_and_slash", + 100, + 750, + 0.9210449457168579 + ], + "15": [ + "vertical_and_slash", + 3500, + 100, + 0.9997319579124451 + ], + "16": [ + "vertical_and_slash", + 3500, + 100, + 0.9917876720428467 + ], + "17": [ + "vertical_and_slash", + 1000, + 6096, + 0.6827043890953064 + ], + "18": [ + "vertical_and_slash", + 100, + 750, + 0.9853598475456238 + ], + "19": [ + "vertical_and_slash", + 500, + 700, + 0.9982296824455261 + ], + "20": [ + "vertical_and_slash", + 500, + 700, + 0.9968504905700684 + ], + "21": [ + "vertical_and_slash", + 3500, + 100, + 0.9502132534980774 + ], + "22": [ + "vertical_and_slash", + 3500, + 100, + 0.9775718450546265 + ], + "23": [ + "vertical_and_slash", + 1000, + 6096, + 0.7796767354011536 + ], + "24": [ + "vertical_and_slash", + 100, + 800, + 0.76953125 + ], + "25": [ + "vertical_and_slash", + 3500, + 100, + 0.9922406673431396 + ], + "26": [ + "vertical_and_slash", + 3500, + 100, + 0.9383352994918823 + ], + "27": [ + "vertical_and_slash", + 3500, + 100, + 0.8227579593658447 + ], + "28": [ + "vertical_and_slash", + 500, + 700, + 0.9977031946182251 + ], + "29": [ + "vertical_and_slash", + 500, + 700, + 0.9920535087585449 + ], + "30": [ + "vertical_and_slash", + 500, + 700, + 0.9953324198722839 + ], + "31": [ + "vertical_and_slash", + 1000, + 6096, + 0.6449344754219055 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.652643620967865 + ], + "1": [ + "vertical_and_slash", + 3500, + 100, + 0.9490343928337097 + ], + "2": [ + "vertical_and_slash", + 3500, + 100, + 0.6628352403640747 + ], + "3": [ + "vertical_and_slash", + 3500, + 100, + 0.9904823899269104 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.7459101676940918 + ], + "5": [ + "vertical_and_slash", + 100, + 750, + 0.8223817348480225 + ], + "6": [ + "vertical_and_slash", + 500, + 700, + 0.8061914443969727 + ], + "7": [ + "vertical_and_slash", + 500, + 700, + 0.9979249238967896 + ], + "8": [ + "vertical_and_slash", + 3500, + 100, + 0.956091582775116 + ], + "9": [ + "vertical_and_slash", + 500, + 700, + 0.9958360195159912 + ], + "10": [ + "vertical_and_slash", + 100, + 750, + 0.6650176644325256 + ], + "11": [ + "vertical_and_slash", + 500, + 700, + 0.946822464466095 + ], + "12": [ + "vertical_and_slash", + 3500, + 100, + 0.8161380887031555 + ], + "13": [ + "vertical_and_slash", + 500, + 700, + 0.9891030192375183 + ], + "14": [ + "vertical_and_slash", + 500, + 700, + 0.9935292601585388 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.7558044195175171 + ], + "16": [ + "vertical_and_slash", + 500, + 700, + 0.9654762744903564 + ], + "17": [ + "vertical_and_slash", + 1000, + 6096, + 0.57756507396698 + ], + "18": [ + "vertical_and_slash", + 500, + 700, + 0.9946773648262024 + ], + "19": [ + "vertical_and_slash", + 100, + 750, + 0.6871338486671448 + ], + "20": [ + "vertical_and_slash", + 500, + 700, + 0.9582348465919495 + ], + "21": [ + "vertical_and_slash", + 3500, + 100, + 0.8474758267402649 + ], + "22": [ + "vertical_and_slash", + 100, + 750, + 0.9841222763061523 + ], + "23": [ + "vertical_and_slash", + 3500, + 100, + 0.9883868098258972 + ], + "24": [ + "vertical_and_slash", + 500, + 700, + 0.8085432052612305 + ], + "25": [ + "vertical_and_slash", + 100, + 750, + 0.8961470127105713 + ], + "26": [ + "vertical_and_slash", + 3500, + 100, + 0.7980116009712219 + ], + "27": [ + "vertical_and_slash", + 500, + 700, + 0.9630904197692871 + ], 
+ "28": [ + "vertical_and_slash", + 100, + 750, + 0.9312359094619751 + ], + "29": [ + "vertical_and_slash", + 500, + 700, + 0.7192952036857605 + ], + "30": [ + "vertical_and_slash", + 500, + 700, + 0.9928464889526367 + ], + "31": [ + "vertical_and_slash", + 1000, + 6096, + 0.6763450503349304 + ] + }, + { + "0": [ + "vertical_and_slash", + 3500, + 100, + 0.9938258528709412 + ], + "1": [ + "vertical_and_slash", + 3500, + 100, + 0.9992395043373108 + ], + "2": [ + "vertical_and_slash", + 500, + 700, + 0.9985423684120178 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.5909091234207153 + ], + "4": [ + "vertical_and_slash", + 100, + 750, + 0.9045535326004028 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.5327474474906921 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.7697188258171082 + ], + "7": [ + "vertical_and_slash", + 500, + 700, + 0.9967195987701416 + ], + "8": [ + "vertical_and_slash", + 1000, + 6096, + 0.7490115761756897 + ], + "9": [ + "vertical_and_slash", + 500, + 700, + 0.9701148271560669 + ], + "10": [ + "vertical_and_slash", + 3500, + 100, + 0.9769464135169983 + ], + "11": [ + "vertical_and_slash", + 3500, + 100, + 0.9512021541595459 + ], + "12": [ + "vertical_and_slash", + 100, + 750, + 0.945134162902832 + ], + "13": [ + "vertical_and_slash", + 3500, + 100, + 0.9831423759460449 + ], + "14": [ + "vertical_and_slash", + 1000, + 6096, + 0.5774247050285339 + ], + "15": [ + "vertical_and_slash", + 100, + 750, + 0.992691695690155 + ], + "16": [ + "vertical_and_slash", + 3500, + 100, + 0.9788296222686768 + ], + "17": [ + "vertical_and_slash", + 500, + 700, + 0.9921766519546509 + ], + "18": [ + "vertical_and_slash", + 3500, + 100, + 0.9580382108688354 + ], + "19": [ + "vertical_and_slash", + 500, + 700, + 0.8484612703323364 + ], + "20": [ + "vertical_and_slash", + 500, + 700, + 0.9827266931533813 + ], + "21": [ + "vertical_and_slash", + 1000, + 6096, + 0.6860563158988953 + ], + "22": [ + "vertical_and_slash", + 500, + 700, + 0.9984534978866577 + ], + "23": [ + "vertical_and_slash", + 3500, + 100, + 0.7839770913124084 + ], + "24": [ + "vertical_and_slash", + 1000, + 6096, + 0.7702468633651733 + ], + "25": [ + "vertical_and_slash", + 500, + 700, + 0.9819672107696533 + ], + "26": [ + "vertical_and_slash", + 1000, + 6096, + 0.7506661415100098 + ], + "27": [ + "vertical_and_slash", + 500, + 700, + 0.8615745902061462 + ], + "28": [ + "vertical_and_slash", + 3500, + 100, + 0.9768449068069458 + ], + "29": [ + "vertical_and_slash", + 3500, + 100, + 0.9400221109390259 + ], + "30": [ + "vertical_and_slash", + 500, + 700, + 0.99903404712677 + ], + "31": [ + "vertical_and_slash", + 1000, + 6096, + 0.814863383769989 + ] + }, + { + "0": [ + "vertical_and_slash", + 500, + 700, + 0.9924616813659668 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.7962270975112915 + ], + "2": [ + "vertical_and_slash", + 100, + 750, + 0.8446550965309143 + ], + "3": [ + "vertical_and_slash", + 3500, + 100, + 0.986892819404602 + ], + "4": [ + "vertical_and_slash", + 500, + 700, + 0.9911617636680603 + ], + "5": [ + "vertical_and_slash", + 100, + 750, + 0.9796133637428284 + ], + "6": [ + "vertical_and_slash", + 500, + 700, + 0.9944987893104553 + ], + "7": [ + "vertical_and_slash", + 500, + 700, + 0.989579975605011 + ], + "8": [ + "vertical_and_slash", + 3500, + 100, + 0.9553632736206055 + ], + "9": [ + "vertical_and_slash", + 500, + 700, + 0.9968169331550598 + ], + "10": [ + "vertical_and_slash", + 3500, + 100, + 0.9912428855895996 + ], + "11": [ + "vertical_and_slash", + 500, + 700, + 
0.9946438074111938 + ], + "12": [ + "vertical_and_slash", + 100, + 750, + 0.960080087184906 + ], + "13": [ + "vertical_and_slash", + 3500, + 100, + 0.9448127150535583 + ], + "14": [ + "vertical_and_slash", + 3500, + 100, + 0.992832362651825 + ], + "15": [ + "vertical_and_slash", + 500, + 700, + 0.9532816410064697 + ], + "16": [ + "vertical_and_slash", + 1000, + 6096, + 0.9048861265182495 + ], + "17": [ + "vertical_and_slash", + 500, + 700, + 0.994652509689331 + ], + "18": [ + "vertical_and_slash", + 100, + 750, + 0.9620020389556885 + ], + "19": [ + "vertical_and_slash", + 1000, + 6096, + 0.6745208501815796 + ], + "20": [ + "vertical_and_slash", + 1000, + 6096, + 0.5609733462333679 + ], + "21": [ + "vertical_and_slash", + 3500, + 100, + 0.8874548673629761 + ], + "22": [ + "vertical_and_slash", + 1000, + 6096, + 0.6557503342628479 + ], + "23": [ + "vertical_and_slash", + 3500, + 100, + 0.6937400102615356 + ], + "24": [ + "vertical_and_slash", + 3500, + 100, + 0.9920935034751892 + ], + "25": [ + "vertical_and_slash", + 500, + 700, + 0.9974891543388367 + ], + "26": [ + "vertical_and_slash", + 1000, + 6096, + 0.771544873714447 + ], + "27": [ + "vertical_and_slash", + 500, + 700, + 0.9795010685920715 + ], + "28": [ + "vertical_and_slash", + 3500, + 100, + 0.9977417588233948 + ], + "29": [ + "vertical_and_slash", + 100, + 750, + 0.9876147508621216 + ], + "30": [ + "vertical_and_slash", + 3500, + 100, + 0.7052534222602844 + ], + "31": [ + "vertical_and_slash", + 100, + 750, + 0.9702804684638977 + ] + }, + { + "0": [ + "vertical_and_slash", + 100, + 750, + 0.8113299608230591 + ], + "1": [ + "vertical_and_slash", + 100, + 750, + 0.75982666015625 + ], + "2": [ + "vertical_and_slash", + 500, + 700, + 0.9722777009010315 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.6454917788505554 + ], + "4": [ + "vertical_and_slash", + 500, + 700, + 0.9591151475906372 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.9245743751525879 + ], + "6": [ + "vertical_and_slash", + 100, + 750, + 0.672450602054596 + ], + "7": [ + "vertical_and_slash", + 100, + 750, + 0.8690704703330994 + ], + "8": [ + "vertical_and_slash", + 100, + 800, + 0.77734375 + ], + "9": [ + "vertical_and_slash", + 100, + 750, + 0.9322983622550964 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.6070910692214966 + ], + "11": [ + "vertical_and_slash", + 100, + 750, + 0.8282676935195923 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.7981728315353394 + ], + "13": [ + "vertical_and_slash", + 100, + 750, + 0.9669613838195801 + ], + "14": [ + "vertical_and_slash", + 100, + 750, + 0.7259189486503601 + ], + "15": [ + "vertical_and_slash", + 100, + 750, + 0.7247616052627563 + ], + "16": [ + "vertical_and_slash", + 100, + 750, + 0.927749514579773 + ], + "17": [ + "vertical_and_slash", + 500, + 700, + 0.9843335747718811 + ], + "18": [ + "vertical_and_slash", + 1000, + 6096, + 0.939363956451416 + ], + "19": [ + "vertical_and_slash", + 500, + 700, + 0.9057216048240662 + ], + "20": [ + "vertical_and_slash", + 1000, + 6096, + 0.8851215839385986 + ], + "21": [ + "vertical_and_slash", + 500, + 700, + 0.9476007223129272 + ], + "22": [ + "vertical_and_slash", + 100, + 750, + 0.8492496013641357 + ], + "23": [ + "vertical_and_slash", + 100, + 750, + 0.9673789143562317 + ], + "24": [ + "vertical_and_slash", + 100, + 750, + 0.9385042190551758 + ], + "25": [ + "vertical_and_slash", + 100, + 750, + 0.9418122172355652 + ], + "26": [ + "vertical_and_slash", + 3500, + 100, + 0.890839159488678 + ], + "27": [ + "vertical_and_slash", + 500, + 
700, + 0.9094856977462769 + ], + "28": [ + "vertical_and_slash", + 100, + 750, + 0.9846156239509583 + ], + "29": [ + "vertical_and_slash", + 100, + 750, + 0.926313042640686 + ], + "30": [ + "vertical_and_slash", + 100, + 750, + 0.8919520974159241 + ], + "31": [ + "vertical_and_slash", + 1000, + 6096, + 0.7341398596763611 + ] + }, + { + "0": [ + "vertical_and_slash", + 500, + 700, + 0.9748796820640564 + ], + "1": [ + "vertical_and_slash", + 100, + 750, + 0.918652355670929 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.5971909165382385 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.7076131105422974 + ], + "4": [ + "vertical_and_slash", + 100, + 750, + 0.84599369764328 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.548804521560669 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.6747919917106628 + ], + "7": [ + "vertical_and_slash", + 100, + 750, + 0.9607553482055664 + ], + "8": [ + "vertical_and_slash", + 500, + 700, + 0.9529062509536743 + ], + "9": [ + "vertical_and_slash", + 100, + 750, + 0.9745543003082275 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.5900142788887024 + ], + "11": [ + "vertical_and_slash", + 1000, + 6096, + 0.6382952332496643 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.8348908424377441 + ], + "13": [ + "vertical_and_slash", + 100, + 750, + 0.8871970772743225 + ], + "14": [ + "vertical_and_slash", + 1000, + 6096, + 0.6412845849990845 + ], + "15": [ + "vertical_and_slash", + 3500, + 100, + 0.9772831201553345 + ], + "16": [ + "vertical_and_slash", + 1000, + 6096, + 0.7201339602470398 + ], + "17": [ + "vertical_and_slash", + 100, + 750, + 0.9578301310539246 + ], + "18": [ + "vertical_and_slash", + 100, + 750, + 0.97627192735672 + ], + "19": [ + "vertical_and_slash", + 100, + 750, + 0.802989661693573 + ], + "20": [ + "vertical_and_slash", + 500, + 700, + 0.9934332370758057 + ], + "21": [ + "vertical_and_slash", + 100, + 750, + 0.988477885723114 + ], + "22": [ + "vertical_and_slash", + 3500, + 100, + 0.977401614189148 + ], + "23": [ + "vertical_and_slash", + 1000, + 6096, + 0.8183449506759644 + ], + "24": [ + "vertical_and_slash", + 100, + 750, + 0.8428875803947449 + ], + "25": [ + "vertical_and_slash", + 1000, + 6096, + 0.9019399285316467 + ], + "26": [ + "vertical_and_slash", + 500, + 700, + 0.850577175617218 + ], + "27": [ + "vertical_and_slash", + 100, + 750, + 0.8797270655632019 + ], + "28": [ + "vertical_and_slash", + 100, + 750, + 0.9850448369979858 + ], + "29": [ + "vertical_and_slash", + 100, + 750, + 0.7459467649459839 + ], + "30": [ + "vertical_and_slash", + 1000, + 6096, + 0.6797484159469604 + ], + "31": [ + "vertical_and_slash", + 3500, + 100, + 0.7080639004707336 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.684343695640564 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.6700385808944702 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.6516374945640564 + ], + "3": [ + "vertical_and_slash", + 100, + 750, + 0.9504920840263367 + ], + "4": [ + "vertical_and_slash", + 100, + 750, + 0.9940043687820435 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.8019426465034485 + ], + "6": [ + "vertical_and_slash", + 3500, + 100, + 0.67058926820755 + ], + "7": [ + "vertical_and_slash", + 500, + 700, + 0.9486842751502991 + ], + "8": [ + "vertical_and_slash", + 3500, + 100, + 0.9951285719871521 + ], + "9": [ + "vertical_and_slash", + 100, + 750, + 0.7573692202568054 + ], + "10": [ + "vertical_and_slash", + 100, + 750, + 0.9824524521827698 + ], + "11": [ + 
"vertical_and_slash", + 1000, + 6096, + 0.6431483030319214 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.7401513457298279 + ], + "13": [ + "vertical_and_slash", + 500, + 700, + 0.9602683186531067 + ], + "14": [ + "vertical_and_slash", + 1000, + 6096, + 0.5100213289260864 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.620374321937561 + ], + "16": [ + "vertical_and_slash", + 100, + 750, + 0.9879629015922546 + ], + "17": [ + "vertical_and_slash", + 3500, + 100, + 0.9683711528778076 + ], + "18": [ + "vertical_and_slash", + 3500, + 100, + 0.9918504953384399 + ], + "19": [ + "vertical_and_slash", + 3500, + 100, + 0.9663371443748474 + ], + "20": [ + "vertical_and_slash", + 100, + 750, + 0.9958233833312988 + ], + "21": [ + "vertical_and_slash", + 1000, + 6096, + 0.731646716594696 + ], + "22": [ + "vertical_and_slash", + 3500, + 100, + 0.9967772364616394 + ], + "23": [ + "vertical_and_slash", + 1000, + 6096, + 0.7126162052154541 + ], + "24": [ + "vertical_and_slash", + 1000, + 6096, + 0.73592609167099 + ], + "25": [ + "vertical_and_slash", + 500, + 700, + 0.9757681488990784 + ], + "26": [ + "vertical_and_slash", + 100, + 750, + 0.821488082408905 + ], + "27": [ + "vertical_and_slash", + 500, + 700, + 0.940662145614624 + ], + "28": [ + "vertical_and_slash", + 100, + 750, + 0.7413780689239502 + ], + "29": [ + "vertical_and_slash", + 100, + 750, + 0.855651319026947 + ], + "30": [ + "vertical_and_slash", + 100, + 750, + 0.957796573638916 + ], + "31": [ + "vertical_and_slash", + 100, + 750, + 0.9116591215133667 + ] + }, + { + "0": [ + "vertical_and_slash", + 500, + 700, + 0.9933403730392456 + ], + "1": [ + "vertical_and_slash", + 3500, + 100, + 0.9730038642883301 + ], + "2": [ + "vertical_and_slash", + 3500, + 100, + 0.9607142210006714 + ], + "3": [ + "vertical_and_slash", + 500, + 700, + 0.9766934514045715 + ], + "4": [ + "vertical_and_slash", + 100, + 750, + 0.8421128988265991 + ], + "5": [ + "vertical_and_slash", + 500, + 700, + 0.9946009516716003 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.6585829257965088 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.8131377100944519 + ], + "8": [ + "vertical_and_slash", + 1000, + 6096, + 0.9548000693321228 + ], + "9": [ + "vertical_and_slash", + 100, + 750, + 0.9604054093360901 + ], + "10": [ + "vertical_and_slash", + 100, + 750, + 0.9921189546585083 + ], + "11": [ + "vertical_and_slash", + 3500, + 100, + 0.9823326468467712 + ], + "12": [ + "vertical_and_slash", + 100, + 750, + 0.9664893746376038 + ], + "13": [ + "vertical_and_slash", + 100, + 750, + 0.991500735282898 + ], + "14": [ + "vertical_and_slash", + 3500, + 100, + 0.9875308275222778 + ], + "15": [ + "vertical_and_slash", + 3500, + 100, + 0.9863171577453613 + ], + "16": [ + "vertical_and_slash", + 3500, + 100, + 0.9744452238082886 + ], + "17": [ + "vertical_and_slash", + 1000, + 6096, + 0.8115602731704712 + ], + "18": [ + "vertical_and_slash", + 500, + 700, + 0.8509863018989563 + ], + "19": [ + "vertical_and_slash", + 1000, + 6096, + 0.7890191674232483 + ], + "20": [ + "vertical_and_slash", + 3500, + 100, + 0.9685430526733398 + ], + "21": [ + "vertical_and_slash", + 500, + 700, + 0.9879798293113708 + ], + "22": [ + "vertical_and_slash", + 1000, + 6096, + 0.8662706017494202 + ], + "23": [ + "vertical_and_slash", + 3500, + 100, + 0.9373542666435242 + ], + "24": [ + "vertical_and_slash", + 100, + 750, + 0.9943336248397827 + ], + "25": [ + "vertical_and_slash", + 100, + 750, + 0.9302772283554077 + ], + "26": [ + "vertical_and_slash", + 500, + 700, + 
0.977992832660675 + ], + "27": [ + "vertical_and_slash", + 3500, + 100, + 0.932984471321106 + ], + "28": [ + "vertical_and_slash", + 100, + 750, + 0.9221320152282715 + ], + "29": [ + "vertical_and_slash", + 1000, + 6096, + 0.6031616926193237 + ], + "30": [ + "vertical_and_slash", + 1000, + 6096, + 0.9101414680480957 + ], + "31": [ + "vertical_and_slash", + 1000, + 6096, + 0.9133468270301819 + ] + }, + { + "0": [ + "vertical_and_slash", + 3500, + 100, + 0.9161027669906616 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.8264802694320679 + ], + "2": [ + "vertical_and_slash", + 100, + 750, + 0.9612565636634827 + ], + "3": [ + "vertical_and_slash", + 100, + 750, + 0.8154476881027222 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.870027482509613 + ], + "5": [ + "vertical_and_slash", + 500, + 700, + 0.8858229517936707 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.732231616973877 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.8734871745109558 + ], + "8": [ + "vertical_and_slash", + 1000, + 6096, + 0.7596957683563232 + ], + "9": [ + "vertical_and_slash", + 100, + 750, + 0.8638753890991211 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.8574302792549133 + ], + "11": [ + "vertical_and_slash", + 1000, + 6096, + 0.7212074398994446 + ], + "12": [ + "vertical_and_slash", + 3500, + 100, + 0.8646799921989441 + ], + "13": [ + "vertical_and_slash", + 1000, + 6096, + 0.6482542753219604 + ], + "14": [ + "vertical_and_slash", + 500, + 700, + 0.83883136510849 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.8585593104362488 + ], + "16": [ + "vertical_and_slash", + 1000, + 6096, + 0.9678143858909607 + ], + "17": [ + "vertical_and_slash", + 1000, + 6096, + 0.7909027934074402 + ], + "18": [ + "vertical_and_slash", + 1000, + 6096, + 0.8701796531677246 + ], + "19": [ + "vertical_and_slash", + 1000, + 6096, + 0.7488895654678345 + ], + "20": [ + "vertical_and_slash", + 1000, + 6096, + 0.8781315088272095 + ], + "21": [ + "vertical_and_slash", + 1000, + 6096, + 0.7022296190261841 + ], + "22": [ + "vertical_and_slash", + 1000, + 6096, + 0.7967407703399658 + ], + "23": [ + "vertical_and_slash", + 1000, + 6096, + 0.7607766389846802 + ], + "24": [ + "vertical_and_slash", + 100, + 750, + 0.5860172510147095 + ], + "25": [ + "vertical_and_slash", + 1000, + 6096, + 0.8311766982078552 + ], + "26": [ + "vertical_and_slash", + 1000, + 6096, + 0.9960868954658508 + ], + "27": [ + "vertical_and_slash", + 100, + 750, + 0.9676917195320129 + ], + "28": [ + "vertical_and_slash", + 3500, + 100, + 0.9716269373893738 + ], + "29": [ + "vertical_and_slash", + 3500, + 100, + 0.7922632098197937 + ], + "30": [ + "vertical_and_slash", + 1000, + 6096, + 0.9029964804649353 + ], + "31": [ + "vertical_and_slash", + 1000, + 6096, + 0.7617987990379333 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.7531226873397827 + ], + "1": [ + "vertical_and_slash", + 3500, + 100, + 0.9829939007759094 + ], + "2": [ + "vertical_and_slash", + 3500, + 100, + 0.9867648482322693 + ], + "3": [ + "vertical_and_slash", + 100, + 750, + 0.7495420575141907 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.8411062359809875 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.9238793253898621 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.8017065525054932 + ], + "7": [ + "vertical_and_slash", + 100, + 750, + 0.949862539768219 + ], + "8": [ + "vertical_and_slash", + 1000, + 6096, + 0.7137445211410522 + ], + "9": [ + "vertical_and_slash", + 1000, + 6096, + 0.8886378407478333 + 
], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.6042924523353577 + ], + "11": [ + "vertical_and_slash", + 500, + 700, + 0.9610161781311035 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.6790781617164612 + ], + "13": [ + "vertical_and_slash", + 1000, + 6096, + 0.8115764856338501 + ], + "14": [ + "vertical_and_slash", + 1000, + 6096, + 0.8262994885444641 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.5916704535484314 + ], + "16": [ + "vertical_and_slash", + 1000, + 6096, + 0.817267894744873 + ], + "17": [ + "vertical_and_slash", + 1000, + 6096, + 0.8543776869773865 + ], + "18": [ + "vertical_and_slash", + 1000, + 6096, + 0.875899612903595 + ], + "19": [ + "vertical_and_slash", + 1000, + 6096, + 0.804548442363739 + ], + "20": [ + "vertical_and_slash", + 1000, + 6096, + 0.7843059301376343 + ], + "21": [ + "vertical_and_slash", + 100, + 750, + 0.9910483360290527 + ], + "22": [ + "vertical_and_slash", + 100, + 750, + 0.7706062197685242 + ], + "23": [ + "vertical_and_slash", + 100, + 750, + 0.8355559706687927 + ], + "24": [ + "vertical_and_slash", + 500, + 700, + 0.8440632820129395 + ], + "25": [ + "vertical_and_slash", + 1000, + 6096, + 0.8883750438690186 + ], + "26": [ + "vertical_and_slash", + 1000, + 6096, + 0.5762335062026978 + ], + "27": [ + "vertical_and_slash", + 500, + 700, + 0.6402088403701782 + ], + "28": [ + "vertical_and_slash", + 3500, + 100, + 0.9510595202445984 + ], + "29": [ + "vertical_and_slash", + 1000, + 6096, + 0.6832257509231567 + ], + "30": [ + "vertical_and_slash", + 100, + 750, + 0.8940309882164001 + ], + "31": [ + "vertical_and_slash", + 1000, + 6096, + 0.6904930472373962 + ] + }, + { + "0": [ + "vertical_and_slash", + 3500, + 100, + 0.9712978005409241 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.715650200843811 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.9092068076133728 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.5512199401855469 + ], + "4": [ + "vertical_and_slash", + 3500, + 100, + 0.9947230815887451 + ], + "5": [ + "vertical_and_slash", + 100, + 750, + 0.942888617515564 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.8079208135604858 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.9934179186820984 + ], + "8": [ + "vertical_and_slash", + 1000, + 6096, + 0.9797422885894775 + ], + "9": [ + "vertical_and_slash", + 1000, + 6096, + 0.9305815100669861 + ], + "10": [ + "vertical_and_slash", + 500, + 700, + 0.988544762134552 + ], + "11": [ + "vertical_and_slash", + 3500, + 100, + 0.9259185791015625 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.9055561423301697 + ], + "13": [ + "vertical_and_slash", + 100, + 750, + 0.9950724244117737 + ], + "14": [ + "vertical_and_slash", + 3500, + 100, + 0.9936328530311584 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.7676860094070435 + ], + "16": [ + "vertical_and_slash", + 100, + 750, + 0.9986890554428101 + ], + "17": [ + "vertical_and_slash", + 3500, + 100, + 0.9774577021598816 + ], + "18": [ + "vertical_and_slash", + 100, + 750, + 0.9445663094520569 + ], + "19": [ + "vertical_and_slash", + 100, + 750, + 0.9051742553710938 + ], + "20": [ + "vertical_and_slash", + 500, + 700, + 0.9963957071304321 + ], + "21": [ + "vertical_and_slash", + 1000, + 6096, + 0.6996874809265137 + ], + "22": [ + "vertical_and_slash", + 100, + 750, + 0.9648709297180176 + ], + "23": [ + "vertical_and_slash", + 500, + 700, + 0.8904076814651489 + ], + "24": [ + "vertical_and_slash", + 1000, + 6096, + 0.6945720314979553 + ], + "25": [ + 
"vertical_and_slash", + 500, + 700, + 0.9790938496589661 + ], + "26": [ + "vertical_and_slash", + 3500, + 100, + 0.9644208550453186 + ], + "27": [ + "vertical_and_slash", + 500, + 700, + 0.9994762539863586 + ], + "28": [ + "vertical_and_slash", + 3500, + 100, + 0.8030149936676025 + ], + "29": [ + "vertical_and_slash", + 100, + 750, + 0.9580996632575989 + ], + "30": [ + "vertical_and_slash", + 500, + 700, + 0.9973891973495483 + ], + "31": [ + "vertical_and_slash", + 100, + 750, + 0.9492622017860413 + ] + }, + { + "0": [ + "vertical_and_slash", + 3500, + 100, + 0.9596408009529114 + ], + "1": [ + "vertical_and_slash", + 3500, + 100, + 0.9903782606124878 + ], + "2": [ + "vertical_and_slash", + 100, + 750, + 0.985963761806488 + ], + "3": [ + "vertical_and_slash", + 3500, + 100, + 0.8823582530021667 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.78907310962677 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.7436686754226685 + ], + "6": [ + "vertical_and_slash", + 3500, + 100, + 0.9900132417678833 + ], + "7": [ + "vertical_and_slash", + 500, + 700, + 0.9172103404998779 + ], + "8": [ + "vertical_and_slash", + 100, + 750, + 0.9941288232803345 + ], + "9": [ + "vertical_and_slash", + 500, + 700, + 0.9992177486419678 + ], + "10": [ + "vertical_and_slash", + 100, + 750, + 0.9879436492919922 + ], + "11": [ + "vertical_and_slash", + 500, + 700, + 0.9889773726463318 + ], + "12": [ + "vertical_and_slash", + 100, + 750, + 0.9961833953857422 + ], + "13": [ + "vertical_and_slash", + 3500, + 100, + 0.992790699005127 + ], + "14": [ + "vertical_and_slash", + 100, + 750, + 0.879410445690155 + ], + "15": [ + "vertical_and_slash", + 3500, + 100, + 0.9506082534790039 + ], + "16": [ + "vertical_and_slash", + 1000, + 6096, + 0.9580295085906982 + ], + "17": [ + "vertical_and_slash", + 3500, + 100, + 0.9992238879203796 + ], + "18": [ + "vertical_and_slash", + 1000, + 6096, + 0.8985172510147095 + ], + "19": [ + "vertical_and_slash", + 3500, + 100, + 0.9653815031051636 + ], + "20": [ + "vertical_and_slash", + 1000, + 6096, + 0.8116862177848816 + ], + "21": [ + "vertical_and_slash", + 1000, + 6096, + 0.8597212433815002 + ], + "22": [ + "vertical_and_slash", + 1000, + 6096, + 0.8310312032699585 + ], + "23": [ + "vertical_and_slash", + 100, + 750, + 0.9763169288635254 + ], + "24": [ + "vertical_and_slash", + 100, + 750, + 0.9825117588043213 + ], + "25": [ + "vertical_and_slash", + 1000, + 6096, + 0.9893282651901245 + ], + "26": [ + "vertical_and_slash", + 3500, + 100, + 0.9984802007675171 + ], + "27": [ + "vertical_and_slash", + 1000, + 6096, + 0.9936201572418213 + ], + "28": [ + "vertical_and_slash", + 100, + 750, + 0.966774582862854 + ], + "29": [ + "vertical_and_slash", + 3500, + 100, + 0.9496166110038757 + ], + "30": [ + "vertical_and_slash", + 3500, + 100, + 0.9060468077659607 + ], + "31": [ + "vertical_and_slash", + 100, + 750, + 0.9787821769714355 + ] + }, + { + "0": [ + "vertical_and_slash", + 3500, + 100, + 0.98621666431427 + ], + "1": [ + "vertical_and_slash", + 500, + 700, + 0.9871621131896973 + ], + "2": [ + "vertical_and_slash", + 500, + 700, + 0.998849630355835 + ], + "3": [ + "vertical_and_slash", + 100, + 750, + 0.9951006174087524 + ], + "4": [ + "vertical_and_slash", + 3500, + 100, + 0.9684059619903564 + ], + "5": [ + "vertical_and_slash", + 100, + 750, + 0.9988571405410767 + ], + "6": [ + "vertical_and_slash", + 3500, + 100, + 0.9973605275154114 + ], + "7": [ + "vertical_and_slash", + 3500, + 100, + 0.9547178745269775 + ], + "8": [ + "vertical_and_slash", + 3500, + 100, + 
0.9439932107925415 + ], + "9": [ + "vertical_and_slash", + 3500, + 100, + 0.761074960231781 + ], + "10": [ + "vertical_and_slash", + 3500, + 100, + 0.9970728158950806 + ], + "11": [ + "vertical_and_slash", + 500, + 700, + 0.9856483340263367 + ], + "12": [ + "vertical_and_slash", + 100, + 750, + 0.9948339462280273 + ], + "13": [ + "vertical_and_slash", + 3500, + 100, + 0.9889204502105713 + ], + "14": [ + "vertical_and_slash", + 100, + 750, + 0.991661548614502 + ], + "15": [ + "vertical_and_slash", + 100, + 750, + 0.9836190938949585 + ], + "16": [ + "vertical_and_slash", + 3500, + 100, + 0.9960607290267944 + ], + "17": [ + "vertical_and_slash", + 3500, + 100, + 0.9184228181838989 + ], + "18": [ + "vertical_and_slash", + 100, + 750, + 0.9532352685928345 + ], + "19": [ + "vertical_and_slash", + 3500, + 100, + 0.9791851043701172 + ], + "20": [ + "vertical_and_slash", + 500, + 700, + 0.9993872046470642 + ], + "21": [ + "vertical_and_slash", + 3500, + 100, + 0.8049466013908386 + ], + "22": [ + "vertical_and_slash", + 3500, + 100, + 0.9487335681915283 + ], + "23": [ + "vertical_and_slash", + 3500, + 100, + 0.9942803382873535 + ], + "24": [ + "vertical_and_slash", + 500, + 700, + 0.9795239567756653 + ], + "25": [ + "vertical_and_slash", + 3500, + 100, + 0.9796655774116516 + ], + "26": [ + "vertical_and_slash", + 100, + 750, + 0.9073926210403442 + ], + "27": [ + "vertical_and_slash", + 1000, + 6096, + 0.8279032707214355 + ], + "28": [ + "vertical_and_slash", + 100, + 750, + 0.928718090057373 + ], + "29": [ + "vertical_and_slash", + 3500, + 100, + 0.9749822616577148 + ], + "30": [ + "vertical_and_slash", + 3500, + 100, + 0.8956699371337891 + ], + "31": [ + "vertical_and_slash", + 3500, + 100, + 0.7362069487571716 + ] + }, + { + "0": [ + "vertical_and_slash", + 100, + 750, + 0.948788583278656 + ], + "1": [ + "vertical_and_slash", + 100, + 750, + 0.9491262435913086 + ], + "2": [ + "vertical_and_slash", + 100, + 750, + 0.8548357486724854 + ], + "3": [ + "vertical_and_slash", + 100, + 750, + 0.9972718954086304 + ], + "4": [ + "vertical_and_slash", + 3500, + 100, + 0.995877742767334 + ], + "5": [ + "vertical_and_slash", + 100, + 750, + 0.9962888360023499 + ], + "6": [ + "vertical_and_slash", + 3500, + 100, + 0.950820803642273 + ], + "7": [ + "vertical_and_slash", + 3500, + 100, + 0.940021276473999 + ], + "8": [ + "vertical_and_slash", + 100, + 750, + 0.97650146484375 + ], + "9": [ + "vertical_and_slash", + 500, + 700, + 0.9919331669807434 + ], + "10": [ + "vertical_and_slash", + 100, + 750, + 0.7328993082046509 + ], + "11": [ + "vertical_and_slash", + 3500, + 100, + 0.992123007774353 + ], + "12": [ + "vertical_and_slash", + 3500, + 100, + 0.9887773394584656 + ], + "13": [ + "vertical_and_slash", + 100, + 750, + 0.9808080792427063 + ], + "14": [ + "vertical_and_slash", + 100, + 750, + 0.9803031086921692 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.7681567072868347 + ], + "16": [ + "vertical_and_slash", + 100, + 750, + 0.9663677215576172 + ], + "17": [ + "vertical_and_slash", + 100, + 750, + 0.9978005886077881 + ], + "18": [ + "vertical_and_slash", + 100, + 750, + 0.9956686496734619 + ], + "19": [ + "vertical_and_slash", + 3500, + 100, + 0.9876529574394226 + ], + "20": [ + "vertical_and_slash", + 3500, + 100, + 0.9083545207977295 + ], + "21": [ + "vertical_and_slash", + 500, + 700, + 0.9825258851051331 + ], + "22": [ + "vertical_and_slash", + 500, + 700, + 0.9929331541061401 + ], + "23": [ + "vertical_and_slash", + 3500, + 100, + 0.9937247037887573 + ], + "24": [ + "vertical_and_slash", + 
3500, + 100, + 0.9849582314491272 + ], + "25": [ + "vertical_and_slash", + 3500, + 100, + 0.9303218722343445 + ], + "26": [ + "vertical_and_slash", + 100, + 750, + 0.9582874774932861 + ], + "27": [ + "vertical_and_slash", + 100, + 750, + 0.9650914669036865 + ], + "28": [ + "vertical_and_slash", + 3500, + 100, + 0.9943416714668274 + ], + "29": [ + "vertical_and_slash", + 3500, + 100, + 0.9150344133377075 + ], + "30": [ + "vertical_and_slash", + 100, + 750, + 0.8997454047203064 + ], + "31": [ + "vertical_and_slash", + 3500, + 100, + 0.9907521605491638 + ] + }, + { + "0": [ + "vertical_and_slash", + 100, + 750, + 0.9935974478721619 + ], + "1": [ + "vertical_and_slash", + 500, + 700, + 0.9944353103637695 + ], + "2": [ + "vertical_and_slash", + 100, + 750, + 0.9993149638175964 + ], + "3": [ + "vertical_and_slash", + 100, + 750, + 0.9968475103378296 + ], + "4": [ + "vertical_and_slash", + 100, + 750, + 0.7440210580825806 + ], + "5": [ + "vertical_and_slash", + 100, + 750, + 0.99021315574646 + ], + "6": [ + "vertical_and_slash", + 3500, + 100, + 0.9938840270042419 + ], + "7": [ + "vertical_and_slash", + 100, + 750, + 0.93567955493927 + ], + "8": [ + "vertical_and_slash", + 3500, + 100, + 0.9877127408981323 + ], + "9": [ + "vertical_and_slash", + 100, + 750, + 0.9873458743095398 + ], + "10": [ + "vertical_and_slash", + 500, + 700, + 0.9998465776443481 + ], + "11": [ + "vertical_and_slash", + 3500, + 100, + 0.9908269047737122 + ], + "12": [ + "vertical_and_slash", + 3500, + 100, + 0.9990859031677246 + ], + "13": [ + "vertical_and_slash", + 3500, + 100, + 0.8770437836647034 + ], + "14": [ + "vertical_and_slash", + 500, + 700, + 0.9989800453186035 + ], + "15": [ + "vertical_and_slash", + 100, + 750, + 0.9958862662315369 + ], + "16": [ + "vertical_and_slash", + 100, + 750, + 0.9739276766777039 + ], + "17": [ + "vertical_and_slash", + 3500, + 100, + 0.9918177723884583 + ], + "18": [ + "vertical_and_slash", + 100, + 750, + 0.9817363023757935 + ], + "19": [ + "vertical_and_slash", + 100, + 750, + 0.9980490207672119 + ], + "20": [ + "vertical_and_slash", + 500, + 700, + 0.9854499101638794 + ], + "21": [ + "vertical_and_slash", + 100, + 750, + 0.9956621527671814 + ], + "22": [ + "vertical_and_slash", + 3500, + 100, + 0.9646536111831665 + ], + "23": [ + "vertical_and_slash", + 3500, + 100, + 0.8399244546890259 + ], + "24": [ + "vertical_and_slash", + 3500, + 100, + 0.9599056243896484 + ], + "25": [ + "vertical_and_slash", + 100, + 750, + 0.9969561100006104 + ], + "26": [ + "vertical_and_slash", + 3500, + 100, + 0.8741656541824341 + ], + "27": [ + "vertical_and_slash", + 3500, + 100, + 0.9881818890571594 + ], + "28": [ + "vertical_and_slash", + 100, + 750, + 0.9986366629600525 + ], + "29": [ + "vertical_and_slash", + 100, + 750, + 0.8203835487365723 + ], + "30": [ + "vertical_and_slash", + 3500, + 100, + 0.916657567024231 + ], + "31": [ + "vertical_and_slash", + 100, + 750, + 0.9909099340438843 + ] + }, + { + "0": [ + "vertical_and_slash", + 100, + 750, + 0.9891338348388672 + ], + "1": [ + "vertical_and_slash", + 100, + 750, + 0.9982934594154358 + ], + "2": [ + "vertical_and_slash", + 3500, + 100, + 0.7269482016563416 + ], + "3": [ + "vertical_and_slash", + 100, + 750, + 0.97837895154953 + ], + "4": [ + "vertical_and_slash", + 3500, + 100, + 0.9985787868499756 + ], + "5": [ + "vertical_and_slash", + 100, + 750, + 0.9378053545951843 + ], + "6": [ + "vertical_and_slash", + 3500, + 100, + 0.8923203945159912 + ], + "7": [ + "vertical_and_slash", + 3500, + 100, + 0.9933837056159973 + ], + "8": [ + 
"vertical_and_slash", + 3500, + 100, + 0.9831008911132812 + ], + "9": [ + "vertical_and_slash", + 3500, + 100, + 0.9890069961547852 + ], + "10": [ + "vertical_and_slash", + 500, + 700, + 0.9977155923843384 + ], + "11": [ + "vertical_and_slash", + 100, + 750, + 0.9636794328689575 + ], + "12": [ + "vertical_and_slash", + 100, + 750, + 0.9993752837181091 + ], + "13": [ + "vertical_and_slash", + 100, + 750, + 0.9918390512466431 + ], + "14": [ + "vertical_and_slash", + 3500, + 100, + 0.9965466856956482 + ], + "15": [ + "vertical_and_slash", + 100, + 750, + 0.979774534702301 + ], + "16": [ + "vertical_and_slash", + 100, + 750, + 0.9794058799743652 + ], + "17": [ + "vertical_and_slash", + 100, + 750, + 0.998450517654419 + ], + "18": [ + "vertical_and_slash", + 3500, + 100, + 0.7310967445373535 + ], + "19": [ + "vertical_and_slash", + 100, + 750, + 0.9953096508979797 + ], + "20": [ + "vertical_and_slash", + 3500, + 100, + 0.9857947826385498 + ], + "21": [ + "vertical_and_slash", + 3500, + 100, + 0.987230658531189 + ], + "22": [ + "vertical_and_slash", + 500, + 700, + 0.9985311031341553 + ], + "23": [ + "vertical_and_slash", + 100, + 750, + 0.9923253655433655 + ], + "24": [ + "vertical_and_slash", + 100, + 750, + 0.9921882152557373 + ], + "25": [ + "vertical_and_slash", + 100, + 750, + 0.9417669773101807 + ], + "26": [ + "vertical_and_slash", + 100, + 750, + 0.9951248168945312 + ], + "27": [ + "vertical_and_slash", + 3500, + 100, + 0.9957342743873596 + ], + "28": [ + "vertical_and_slash", + 1000, + 6096, + 0.8214721083641052 + ], + "29": [ + "vertical_and_slash", + 3500, + 100, + 0.9924106001853943 + ], + "30": [ + "vertical_and_slash", + 500, + 700, + 0.9996931552886963 + ], + "31": [ + "vertical_and_slash", + 3500, + 100, + 0.9912320375442505 + ] + }, + { + "0": [ + "vertical_and_slash", + 500, + 700, + 0.9739575982093811 + ], + "1": [ + "vertical_and_slash", + 100, + 750, + 0.9948337078094482 + ], + "2": [ + "vertical_and_slash", + 500, + 700, + 0.9912586808204651 + ], + "3": [ + "vertical_and_slash", + 100, + 750, + 0.9957679510116577 + ], + "4": [ + "vertical_and_slash", + 3500, + 100, + 0.9885770678520203 + ], + "5": [ + "vertical_and_slash", + 100, + 750, + 0.9814969301223755 + ], + "6": [ + "vertical_and_slash", + 3500, + 100, + 0.9366181492805481 + ], + "7": [ + "vertical_and_slash", + 100, + 750, + 0.9948270320892334 + ], + "8": [ + "vertical_and_slash", + 100, + 750, + 0.9994694590568542 + ], + "9": [ + "vertical_and_slash", + 3500, + 100, + 0.9889045357704163 + ], + "10": [ + "vertical_and_slash", + 3500, + 100, + 0.9929386973381042 + ], + "11": [ + "vertical_and_slash", + 100, + 750, + 0.9881108999252319 + ], + "12": [ + "vertical_and_slash", + 500, + 700, + 0.9796789288520813 + ], + "13": [ + "vertical_and_slash", + 100, + 750, + 0.9889234304428101 + ], + "14": [ + "vertical_and_slash", + 3500, + 100, + 0.9873424768447876 + ], + "15": [ + "vertical_and_slash", + 3500, + 100, + 0.9901317358016968 + ], + "16": [ + "vertical_and_slash", + 3500, + 100, + 0.9893797039985657 + ], + "17": [ + "vertical_and_slash", + 100, + 750, + 0.9819779992103577 + ], + "18": [ + "vertical_and_slash", + 3500, + 100, + 0.989522397518158 + ], + "19": [ + "vertical_and_slash", + 3500, + 100, + 0.9819537997245789 + ], + "20": [ + "vertical_and_slash", + 100, + 750, + 0.9925962686538696 + ], + "21": [ + "vertical_and_slash", + 3500, + 100, + 0.9989944696426392 + ], + "22": [ + "vertical_and_slash", + 100, + 750, + 0.9997721314430237 + ], + "23": [ + "vertical_and_slash", + 100, + 750, + 0.9876223802566528 + ], 
+ "24": [ + "vertical_and_slash", + 3500, + 100, + 0.9952347874641418 + ], + "25": [ + "vertical_and_slash", + 3500, + 100, + 0.9843642711639404 + ], + "26": [ + "vertical_and_slash", + 3500, + 100, + 0.9960111975669861 + ], + "27": [ + "vertical_and_slash", + 3500, + 100, + 0.6954624652862549 + ], + "28": [ + "vertical_and_slash", + 3500, + 100, + 0.970451295375824 + ], + "29": [ + "vertical_and_slash", + 100, + 750, + 0.991379976272583 + ], + "30": [ + "vertical_and_slash", + 100, + 750, + 0.8738142848014832 + ], + "31": [ + "vertical_and_slash", + 3500, + 100, + 0.9786747694015503 + ] + }, + { + "0": [ + "vertical_and_slash", + 3500, + 100, + 0.997829258441925 + ], + "1": [ + "vertical_and_slash", + 500, + 700, + 0.990919291973114 + ], + "2": [ + "vertical_and_slash", + 100, + 750, + 0.9968075156211853 + ], + "3": [ + "vertical_and_slash", + 3500, + 100, + 0.9982627630233765 + ], + "4": [ + "vertical_and_slash", + 3500, + 100, + 0.9957785606384277 + ], + "5": [ + "vertical_and_slash", + 100, + 750, + 0.974946141242981 + ], + "6": [ + "vertical_and_slash", + 3500, + 100, + 0.995257556438446 + ], + "7": [ + "vertical_and_slash", + 100, + 750, + 0.8554062247276306 + ], + "8": [ + "vertical_and_slash", + 100, + 750, + 0.9880555272102356 + ], + "9": [ + "vertical_and_slash", + 500, + 700, + 0.9956945776939392 + ], + "10": [ + "vertical_and_slash", + 3500, + 100, + 0.9789673089981079 + ], + "11": [ + "vertical_and_slash", + 3500, + 100, + 0.991595447063446 + ], + "12": [ + "vertical_and_slash", + 100, + 750, + 0.9686179161071777 + ], + "13": [ + "vertical_and_slash", + 100, + 750, + 0.9943218231201172 + ], + "14": [ + "vertical_and_slash", + 3500, + 100, + 0.9565165042877197 + ], + "15": [ + "vertical_and_slash", + 100, + 750, + 0.9816807508468628 + ], + "16": [ + "vertical_and_slash", + 500, + 700, + 0.9969928860664368 + ], + "17": [ + "vertical_and_slash", + 100, + 750, + 0.9579240679740906 + ], + "18": [ + "vertical_and_slash", + 1000, + 6096, + 0.7692553400993347 + ], + "19": [ + "vertical_and_slash", + 3500, + 100, + 0.9934751987457275 + ], + "20": [ + "vertical_and_slash", + 100, + 750, + 0.996086597442627 + ], + "21": [ + "vertical_and_slash", + 3500, + 100, + 0.9795346260070801 + ], + "22": [ + "vertical_and_slash", + 3500, + 100, + 0.9099371433258057 + ], + "23": [ + "vertical_and_slash", + 100, + 750, + 0.9606084823608398 + ], + "24": [ + "vertical_and_slash", + 3500, + 100, + 0.9944002032279968 + ], + "25": [ + "vertical_and_slash", + 3500, + 100, + 0.9969326257705688 + ], + "26": [ + "vertical_and_slash", + 3500, + 100, + 0.943459689617157 + ], + "27": [ + "vertical_and_slash", + 3500, + 100, + 0.9907713532447815 + ], + "28": [ + "vertical_and_slash", + 3500, + 100, + 0.9855557084083557 + ], + "29": [ + "vertical_and_slash", + 100, + 750, + 0.911635160446167 + ], + "30": [ + "vertical_and_slash", + 3500, + 100, + 0.9951326847076416 + ], + "31": [ + "vertical_and_slash", + 3500, + 100, + 0.9821126461029053 + ] + }, + { + "0": [ + "vertical_and_slash", + 500, + 700, + 0.9654005765914917 + ], + "1": [ + "vertical_and_slash", + 500, + 700, + 0.999093770980835 + ], + "2": [ + "vertical_and_slash", + 100, + 750, + 0.99101722240448 + ], + "3": [ + "vertical_and_slash", + 500, + 700, + 0.9934068918228149 + ], + "4": [ + "vertical_and_slash", + 3500, + 100, + 0.9982517957687378 + ], + "5": [ + "vertical_and_slash", + 500, + 700, + 0.9996331334114075 + ], + "6": [ + "vertical_and_slash", + 500, + 700, + 0.9952266216278076 + ], + "7": [ + "vertical_and_slash", + 100, + 750, + 
0.880059003829956 + ], + "8": [ + "vertical_and_slash", + 3500, + 100, + 0.9872965216636658 + ], + "9": [ + "vertical_and_slash", + 3500, + 100, + 0.9972522854804993 + ], + "10": [ + "vertical_and_slash", + 100, + 750, + 0.9663414359092712 + ], + "11": [ + "vertical_and_slash", + 500, + 700, + 0.9989503622055054 + ], + "12": [ + "vertical_and_slash", + 500, + 700, + 0.9980217218399048 + ], + "13": [ + "vertical_and_slash", + 3500, + 100, + 0.9978732466697693 + ], + "14": [ + "vertical_and_slash", + 3500, + 100, + 0.990183413028717 + ], + "15": [ + "vertical_and_slash", + 100, + 750, + 0.9975069761276245 + ], + "16": [ + "vertical_and_slash", + 100, + 750, + 0.9249386787414551 + ], + "17": [ + "vertical_and_slash", + 500, + 700, + 0.9966362118721008 + ], + "18": [ + "vertical_and_slash", + 500, + 700, + 0.9951748847961426 + ], + "19": [ + "vertical_and_slash", + 100, + 750, + 0.9986919164657593 + ], + "20": [ + "vertical_and_slash", + 3500, + 100, + 0.9912869930267334 + ], + "21": [ + "vertical_and_slash", + 100, + 750, + 0.9970594644546509 + ], + "22": [ + "vertical_and_slash", + 100, + 750, + 0.998475193977356 + ], + "23": [ + "vertical_and_slash", + 500, + 700, + 0.9993215799331665 + ], + "24": [ + "vertical_and_slash", + 3500, + 100, + 0.9980448484420776 + ], + "25": [ + "vertical_and_slash", + 100, + 750, + 0.9916543364524841 + ], + "26": [ + "vertical_and_slash", + 100, + 750, + 0.980556845664978 + ], + "27": [ + "vertical_and_slash", + 500, + 700, + 0.9921435117721558 + ], + "28": [ + "vertical_and_slash", + 3500, + 100, + 0.9989830255508423 + ], + "29": [ + "vertical_and_slash", + 3500, + 100, + 0.9973907470703125 + ], + "30": [ + "vertical_and_slash", + 500, + 700, + 0.9833565354347229 + ], + "31": [ + "vertical_and_slash", + 100, + 750, + 0.9759599566459656 + ] + }, + { + "0": [ + "vertical_and_slash", + 100, + 750, + 0.9991001486778259 + ], + "1": [ + "vertical_and_slash", + 3500, + 100, + 0.9909220933914185 + ], + "2": [ + "vertical_and_slash", + 100, + 750, + 0.9492173194885254 + ], + "3": [ + "vertical_and_slash", + 500, + 700, + 0.9961900115013123 + ], + "4": [ + "vertical_and_slash", + 100, + 750, + 0.8713532090187073 + ], + "5": [ + "vertical_and_slash", + 3500, + 100, + 0.9998117089271545 + ], + "6": [ + "vertical_and_slash", + 3500, + 100, + 0.998404324054718 + ], + "7": [ + "vertical_and_slash", + 500, + 700, + 0.9998683333396912 + ], + "8": [ + "vertical_and_slash", + 100, + 750, + 0.9000679850578308 + ], + "9": [ + "vertical_and_slash", + 500, + 700, + 0.9757489562034607 + ], + "10": [ + "vertical_and_slash", + 500, + 700, + 0.9937180876731873 + ], + "11": [ + "vertical_and_slash", + 100, + 750, + 0.9938128590583801 + ], + "12": [ + "vertical_and_slash", + 3500, + 100, + 0.9991152882575989 + ], + "13": [ + "vertical_and_slash", + 500, + 700, + 0.9996658563613892 + ], + "14": [ + "vertical_and_slash", + 100, + 750, + 0.9966751933097839 + ], + "15": [ + "vertical_and_slash", + 500, + 700, + 0.999762773513794 + ], + "16": [ + "vertical_and_slash", + 3500, + 100, + 0.9985853433609009 + ], + "17": [ + "vertical_and_slash", + 100, + 750, + 0.7297303080558777 + ], + "18": [ + "vertical_and_slash", + 3500, + 100, + 0.9985169768333435 + ], + "19": [ + "vertical_and_slash", + 3500, + 100, + 0.9998067617416382 + ], + "20": [ + "vertical_and_slash", + 100, + 750, + 0.8747161030769348 + ], + "21": [ + "vertical_and_slash", + 3500, + 100, + 0.9923343658447266 + ], + "22": [ + "vertical_and_slash", + 500, + 700, + 0.9940261840820312 + ], + "23": [ + "vertical_and_slash", + 100, + 
750, + 0.9998047351837158 + ], + "24": [ + "vertical_and_slash", + 3500, + 100, + 0.9748029112815857 + ], + "25": [ + "vertical_and_slash", + 3500, + 100, + 0.9991946816444397 + ], + "26": [ + "vertical_and_slash", + 100, + 750, + 0.8475115299224854 + ], + "27": [ + "vertical_and_slash", + 100, + 750, + 0.9997408390045166 + ], + "28": [ + "vertical_and_slash", + 500, + 700, + 0.9990043044090271 + ], + "29": [ + "vertical_and_slash", + 100, + 750, + 0.8996012806892395 + ], + "30": [ + "vertical_and_slash", + 100, + 750, + 0.9358092546463013 + ], + "31": [ + "vertical_and_slash", + 3500, + 100, + 0.9944736361503601 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.8248796463012695 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.7729880213737488 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.7208629250526428 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.7399256825447083 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.7422590851783752 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.949510395526886 + ], + "6": [ + "vertical_and_slash", + 100, + 750, + 0.8432893753051758 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.9128398299217224 + ], + "8": [ + "vertical_and_slash", + 3500, + 100, + 0.8389910459518433 + ], + "9": [ + "vertical_and_slash", + 1000, + 6096, + 0.8408676981925964 + ], + "10": [ + "vertical_and_slash", + 500, + 700, + 0.8649052381515503 + ], + "11": [ + "vertical_and_slash", + 1000, + 6096, + 0.9523816704750061 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.7945832014083862 + ], + "13": [ + "vertical_and_slash", + 3500, + 100, + 0.8890054225921631 + ], + "14": [ + "vertical_and_slash", + 100, + 750, + 0.755133330821991 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.9407838582992554 + ], + "16": [ + "vertical_and_slash", + 1000, + 6096, + 0.978395938873291 + ], + "17": [ + "vertical_and_slash", + 1000, + 6096, + 0.841075599193573 + ], + "18": [ + "vertical_and_slash", + 1000, + 6096, + 0.7965966463088989 + ], + "19": [ + "vertical_and_slash", + 500, + 700, + 0.7598339319229126 + ], + "20": [ + "vertical_and_slash", + 100, + 750, + 0.7436662316322327 + ], + "21": [ + "vertical_and_slash", + 1000, + 6096, + 0.8721699714660645 + ], + "22": [ + "vertical_and_slash", + 100, + 750, + 0.872313916683197 + ], + "23": [ + "vertical_and_slash", + 100, + 750, + 0.9902216792106628 + ], + "24": [ + "vertical_and_slash", + 1000, + 6096, + 0.7798812985420227 + ], + "25": [ + "vertical_and_slash", + 1000, + 6096, + 0.9631245732307434 + ], + "26": [ + "vertical_and_slash", + 100, + 750, + 0.845567524433136 + ], + "27": [ + "vertical_and_slash", + 1000, + 6096, + 0.8043644428253174 + ], + "28": [ + "vertical_and_slash", + 1000, + 6096, + 0.9540744423866272 + ], + "29": [ + "vertical_and_slash", + 1000, + 6096, + 0.9055390357971191 + ], + "30": [ + "vertical_and_slash", + 1000, + 6096, + 0.9507457613945007 + ], + "31": [ + "vertical_and_slash", + 3500, + 100, + 0.827296793460846 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.9915655255317688 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.9915629029273987 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.9245554804801941 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.9129937291145325 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.9958124160766602 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.806128203868866 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 
0.8036627769470215 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.9017021656036377 + ], + "8": [ + "vertical_and_slash", + 1000, + 6096, + 0.9740359783172607 + ], + "9": [ + "vertical_and_slash", + 1000, + 6096, + 0.968455970287323 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.9159662127494812 + ], + "11": [ + "vertical_and_slash", + 1000, + 6096, + 0.8428217172622681 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.9581736326217651 + ], + "13": [ + "vertical_and_slash", + 1000, + 6096, + 0.970429539680481 + ], + "14": [ + "vertical_and_slash", + 1000, + 6096, + 0.9745091199874878 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.894728422164917 + ], + "16": [ + "vertical_and_slash", + 1000, + 6096, + 0.9496183395385742 + ], + "17": [ + "vertical_and_slash", + 1000, + 6096, + 0.9325363039970398 + ], + "18": [ + "vertical_and_slash", + 1000, + 6096, + 0.9315022826194763 + ], + "19": [ + "vertical_and_slash", + 1000, + 6096, + 0.8017199635505676 + ], + "20": [ + "vertical_and_slash", + 1000, + 6096, + 0.9268004298210144 + ], + "21": [ + "vertical_and_slash", + 1000, + 6096, + 0.8929623365402222 + ], + "22": [ + "vertical_and_slash", + 1000, + 6096, + 0.8346715569496155 + ], + "23": [ + "vertical_and_slash", + 1000, + 6096, + 0.8660512566566467 + ], + "24": [ + "vertical_and_slash", + 1000, + 6096, + 0.9183820486068726 + ], + "25": [ + "vertical_and_slash", + 100, + 750, + 0.8379555344581604 + ], + "26": [ + "vertical_and_slash", + 1000, + 6096, + 0.911184549331665 + ], + "27": [ + "vertical_and_slash", + 1000, + 6096, + 0.8829504251480103 + ], + "28": [ + "vertical_and_slash", + 1000, + 6096, + 0.9138942956924438 + ], + "29": [ + "vertical_and_slash", + 1000, + 6096, + 0.872784435749054 + ], + "30": [ + "vertical_and_slash", + 1000, + 6096, + 0.9097738862037659 + ], + "31": [ + "vertical_and_slash", + 1000, + 6096, + 0.9275451898574829 + ] + } +] \ No newline at end of file diff --git a/minference/configs/Qwen2.5_3B_flex_0.90.json b/minference/configs/Qwen2.5_3B_flex_0.90.json new file mode 100644 index 0000000..36f48de --- /dev/null +++ b/minference/configs/Qwen2.5_3B_flex_0.90.json @@ -0,0 +1,616 @@ +[ + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0] + }, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": 
["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 
4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": 
["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", 
[0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + 
"6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", 
[0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + 
"15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": 
["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 
4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": 
["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", 
[0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.9, 4096], [0.9, 4096], 0.0] + } +] \ No newline at end of file diff --git a/minference/configs/Qwen2.5_3B_flex_0.95.json b/minference/configs/Qwen2.5_3B_flex_0.95.json new file mode 100644 index 0000000..2cacba4 --- /dev/null +++ b/minference/configs/Qwen2.5_3B_flex_0.95.json @@ -0,0 +1,616 @@ +[ + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0] + }, + { + "0": ["flex_vertical_and_slash", 
[0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", 
[0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", 
[0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", 
[0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", 
[0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", 
[0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": 
["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": 
["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": 
["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": 
["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": 
["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0]}, + { + "0": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "1": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "2": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "3": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "4": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "5": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "6": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "7": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "8": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "9": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "10": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "11": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "12": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "13": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "14": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0], + "15": ["flex_vertical_and_slash", [0.95, 4096], [0.95, 4096], 0.0] + } +] \ No newline at end of file diff --git 
a/minference/configs/Qwen2.5_3B_kv_out_v32_fit_o_best_pattern.json b/minference/configs/Qwen2.5_3B_kv_out_v32_fit_o_best_pattern.json new file mode 100644 index 0000000..fdb2922 --- /dev/null +++ b/minference/configs/Qwen2.5_3B_kv_out_v32_fit_o_best_pattern.json @@ -0,0 +1,3530 @@ +[ + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.9872207641601562 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.9784929752349854 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.7595849633216858 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.5381054878234863 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.9863664507865906 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.9912353157997131 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.7160804867744446 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.9073030352592468 + ], + "8": [ + "vertical_and_slash", + 1000, + 6096, + 0.9161261916160583 + ], + "9": [ + "vertical_and_slash", + 1000, + 6096, + 0.9784228205680847 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.9789554476737976 + ], + "11": [ + "vertical_and_slash", + 1000, + 6096, + 0.7867575883865356 + ], + "12": [ + "vertical_and_slash", + 30, + 800, + 1.0000020265579224 + ], + "13": [ + "vertical_and_slash", + 1000, + 6096, + 0.9307636618614197 + ], + "14": [ + "vertical_and_slash", + 1000, + 6096, + 0.6895971298217773 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.7491968870162964 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.9951784610748291 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.9995638132095337 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.9881429076194763 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.9951925277709961 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.928062379360199 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.9995735883712769 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.7747997045516968 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.9893845319747925 + ], + "8": [ + "vertical_and_slash", + 1000, + 6096, + 0.9328423738479614 + ], + "9": [ + "vertical_and_slash", + 1000, + 6096, + 0.7227432131767273 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.6669939160346985 + ], + "11": [ + "vertical_and_slash", + 1000, + 6096, + 0.955822765827179 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.6157850623130798 + ], + "13": [ + "vertical_and_slash", + 1000, + 6096, + 0.8225603103637695 + ], + "14": [ + "vertical_and_slash", + 1000, + 6096, + 0.6094294786453247 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.7056097388267517 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.8943619728088379 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.6963416337966919 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.9629008173942566 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.866447389125824 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.5282332897186279 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.9982369542121887 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.9979943633079529 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.9979172945022583 + ], + "8": [ + "vertical_and_slash", + 1000, + 6096, + 0.942166268825531 + ], + "9": [ + "vertical_and_slash", + 1000, + 6096, + 0.9923297166824341 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.9751147031784058 + ], + "11": [ + 
"vertical_and_slash", + 1000, + 6096, + 0.8978350758552551 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.8243312239646912 + ], + "13": [ + "vertical_and_slash", + 1000, + 6096, + 0.9721394181251526 + ], + "14": [ + "vertical_and_slash", + 1000, + 6096, + 0.93731689453125 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.9794054627418518 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.7129239439964294 + ], + "1": [ + "vertical_and_slash", + 30, + 800, + 0.9804909229278564 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.9835291504859924 + ], + "3": [ + "vertical_and_slash", + 30, + 800, + 0.9893701076507568 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.9409563541412354 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.8059223890304565 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.6498631238937378 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.984248697757721 + ], + "8": [ + "vertical_and_slash", + 1000, + 6096, + 0.7962363362312317 + ], + "9": [ + "vertical_and_slash", + 1000, + 6096, + 0.868658721446991 + ], + "10": [ + "vertical_and_slash", + 3500, + 100, + 0.8754917979240417 + ], + "11": [ + "vertical_and_slash", + 3500, + 100, + 0.8955696821212769 + ], + "12": [ + "vertical_and_slash", + 3500, + 100, + 0.9082641005516052 + ], + "13": [ + "vertical_and_slash", + 1000, + 6096, + 0.8178426623344421 + ], + "14": [ + "vertical_and_slash", + 3500, + 100, + 0.8291682004928589 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.8030994534492493 + ] + }, + { + "0": [ + "vertical_and_slash", + 30, + 800, + 0.9872316122055054 + ], + "1": [ + "vertical_and_slash", + 30, + 800, + 0.9957523345947266 + ], + "2": [ + "vertical_and_slash", + 30, + 800, + 0.9542893171310425 + ], + "3": [ + "vertical_and_slash", + 30, + 800, + 0.9896659255027771 + ], + "4": [ + "vertical_and_slash", + 30, + 800, + 0.9950734376907349 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.8253392577171326 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.8497561812400818 + ], + "7": [ + "vertical_and_slash", + 30, + 800, + 0.9906441569328308 + ], + "8": [ + "vertical_and_slash", + 1000, + 6096, + 0.9813032746315002 + ], + "9": [ + "vertical_and_slash", + 30, + 800, + 0.9744712114334106 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.809856116771698 + ], + "11": [ + "vertical_and_slash", + 1000, + 6096, + 0.9735696911811829 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.8704012036323547 + ], + "13": [ + "vertical_and_slash", + 30, + 800, + 0.965289294719696 + ], + "14": [ + "vertical_and_slash", + 1000, + 6096, + 0.9198070168495178 + ], + "15": [ + "vertical_and_slash", + 30, + 800, + 0.9896268844604492 + ] + }, + { + "0": [ + "vertical_and_slash", + 30, + 800, + 0.9742131233215332 + ], + "1": [ + "vertical_and_slash", + 30, + 800, + 0.9894583821296692 + ], + "2": [ + "vertical_and_slash", + 30, + 800, + 0.9873966574668884 + ], + "3": [ + "vertical_and_slash", + 30, + 800, + 0.9833617210388184 + ], + "4": [ + "vertical_and_slash", + 30, + 800, + 0.9105245471000671 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.734259843826294 + ], + "6": [ + "vertical_and_slash", + 30, + 800, + 0.9877724051475525 + ], + "7": [ + "vertical_and_slash", + 30, + 800, + 0.9896732568740845 + ], + "8": [ + "vertical_and_slash", + 1000, + 6096, + 0.8814374208450317 + ], + "9": [ + "vertical_and_slash", + 1000, + 6096, + 0.9992178678512573 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 
0.9980509877204895 + ], + "11": [ + "vertical_and_slash", + 1000, + 6096, + 0.9092496037483215 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.8476247191429138 + ], + "13": [ + "vertical_and_slash", + 100, + 800, + 1.0 + ], + "14": [ + "vertical_and_slash", + 1000, + 6096, + 0.9594801664352417 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.8697033524513245 + ] + }, + { + "0": [ + "vertical_and_slash", + 30, + 800, + 0.9915944933891296 + ], + "1": [ + "vertical_and_slash", + 100, + 800, + 0.95703125 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.5575676560401917 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.9087390303611755 + ], + "4": [ + "vertical_and_slash", + 100, + 800, + 0.9765625 + ], + "5": [ + "vertical_and_slash", + 30, + 800, + 0.9809911847114563 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.5668226480484009 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.978988766670227 + ], + "8": [ + "vertical_and_slash", + 30, + 800, + 0.9430385828018188 + ], + "9": [ + "vertical_and_slash", + 100, + 800, + 0.7421875 + ], + "10": [ + "vertical_and_slash", + 30, + 800, + 0.9403589963912964 + ], + "11": [ + "vertical_and_slash", + 30, + 800, + 0.9878966808319092 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.9328376650810242 + ], + "13": [ + "vertical_and_slash", + 30, + 800, + 0.8343550562858582 + ], + "14": [ + "vertical_and_slash", + 30, + 800, + 0.959410548210144 + ], + "15": [ + "vertical_and_slash", + 30, + 800, + 0.9758256673812866 + ] + }, + { + "0": [ + "vertical_and_slash", + 30, + 800, + 0.933525562286377 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.6686467528343201 + ], + "2": [ + "vertical_and_slash", + 30, + 800, + 0.9911661744117737 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.5714142322540283 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.8095535635948181 + ], + "5": [ + "vertical_and_slash", + 30, + 800, + 0.9295337796211243 + ], + "6": [ + "vertical_and_slash", + 30, + 800, + 0.965142011642456 + ], + "7": [ + "vertical_and_slash", + 100, + 800, + 0.984375 + ], + "8": [ + "vertical_and_slash", + 1000, + 6096, + 0.6681234240531921 + ], + "9": [ + "vertical_and_slash", + 1000, + 6096, + 0.8945913910865784 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.6786034107208252 + ], + "11": [ + "vertical_and_slash", + 1000, + 6096, + 0.9802227020263672 + ], + "12": [ + "vertical_and_slash", + 100, + 800, + 0.97265625 + ], + "13": [ + "vertical_and_slash", + 1000, + 6096, + 0.9555180072784424 + ], + "14": [ + "vertical_and_slash", + 1000, + 6096, + 0.9749921560287476 + ], + "15": [ + "vertical_and_slash", + 100, + 800, + 0.9921875 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.8633137345314026 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.9670861959457397 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.8114507794380188 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.9675626158714294 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.9122037291526794 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.9735010862350464 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.9719548225402832 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.9510305523872375 + ], + "8": [ + "vertical_and_slash", + 1000, + 6096, + 0.72588711977005 + ], + "9": [ + "vertical_and_slash", + 1000, + 6096, + 0.9162058234214783 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.924541175365448 + ], + "11": [ + 
"vertical_and_slash", + 1000, + 6096, + 0.9764450788497925 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.9914652705192566 + ], + "13": [ + "vertical_and_slash", + 1000, + 6096, + 0.7905170321464539 + ], + "14": [ + "vertical_and_slash", + 1000, + 6096, + 0.9880024790763855 + ], + "15": [ + "vertical_and_slash", + 30, + 800, + 0.9844828844070435 + ] + }, + { + "0": [ + "vertical_and_slash", + 30, + 800, + 0.9420890212059021 + ], + "1": [ + "vertical_and_slash", + 30, + 800, + 0.986860454082489 + ], + "2": [ + "vertical_and_slash", + 30, + 800, + 0.9866741299629211 + ], + "3": [ + "vertical_and_slash", + 30, + 800, + 0.7924719452857971 + ], + "4": [ + "vertical_and_slash", + 100, + 800, + 0.890625 + ], + "5": [ + "vertical_and_slash", + 30, + 800, + 0.9799174666404724 + ], + "6": [ + "vertical_and_slash", + 500, + 700, + 0.9916488528251648 + ], + "7": [ + "vertical_and_slash", + 30, + 800, + 0.993992030620575 + ], + "8": [ + "vertical_and_slash", + 1000, + 6096, + 0.9875913262367249 + ], + "9": [ + "vertical_and_slash", + 1000, + 6096, + 0.980666995048523 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.9952725172042847 + ], + "11": [ + "vertical_and_slash", + 30, + 800, + 0.9930256605148315 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.9982009530067444 + ], + "13": [ + "vertical_and_slash", + 1000, + 6096, + 0.9283435940742493 + ], + "14": [ + "vertical_and_slash", + 1000, + 6096, + 0.9926015138626099 + ], + "15": [ + "vertical_and_slash", + 30, + 800, + 0.9950297474861145 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.989981472492218 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.9963133335113525 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.9934775233268738 + ], + "3": [ + "vertical_and_slash", + 30, + 800, + 0.9720687866210938 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.957918107509613 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.8732808828353882 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.9542519450187683 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.9934373497962952 + ], + "8": [ + "vertical_and_slash", + 1000, + 6096, + 0.9688447713851929 + ], + "9": [ + "vertical_and_slash", + 3500, + 100, + 0.9413594007492065 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.9796752333641052 + ], + "11": [ + "vertical_and_slash", + 1000, + 6096, + 0.9452784657478333 + ], + "12": [ + "vertical_and_slash", + 3500, + 100, + 0.9403716921806335 + ], + "13": [ + "vertical_and_slash", + 1000, + 6096, + 0.9615768790245056 + ], + "14": [ + "vertical_and_slash", + 500, + 700, + 0.760350227355957 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.8651421666145325 + ] + }, + { + "0": [ + "vertical_and_slash", + 30, + 800, + 0.9908298254013062 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.818602979183197 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.9621649980545044 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.7571771144866943 + ], + "4": [ + "vertical_and_slash", + 500, + 700, + 0.9852563738822937 + ], + "5": [ + "vertical_and_slash", + 30, + 800, + 0.9928317070007324 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.7078589797019958 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.9356410503387451 + ], + "8": [ + "vertical_and_slash", + 1000, + 6096, + 0.8330709338188171 + ], + "9": [ + "vertical_and_slash", + 1000, + 6096, + 0.9761743545532227 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 
0.9746711254119873 + ], + "11": [ + "vertical_and_slash", + 3500, + 100, + 0.9730107188224792 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.9346276521682739 + ], + "13": [ + "vertical_and_slash", + 1000, + 6096, + 0.9564436674118042 + ], + "14": [ + "vertical_and_slash", + 3500, + 100, + 0.9630159735679626 + ], + "15": [ + "vertical_and_slash", + 30, + 800, + 0.9073519706726074 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.8938577175140381 + ], + "1": [ + "vertical_and_slash", + 30, + 800, + 0.9779987931251526 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.9370880126953125 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.8135497570037842 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.9516273736953735 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.874685525894165 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.8642981648445129 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.9647448658943176 + ], + "8": [ + "vertical_and_slash", + 3500, + 100, + 0.7395755052566528 + ], + "9": [ + "vertical_and_slash", + 1000, + 6096, + 0.9181990027427673 + ], + "10": [ + "vertical_and_slash", + 3500, + 100, + 0.7171043157577515 + ], + "11": [ + "vertical_and_slash", + 1000, + 6096, + 0.8813474774360657 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.9861239790916443 + ], + "13": [ + "vertical_and_slash", + 1000, + 6096, + 0.9770793318748474 + ], + "14": [ + "vertical_and_slash", + 1000, + 6096, + 0.9610176086425781 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.9338920712471008 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.9416564106941223 + ], + "1": [ + "vertical_and_slash", + 30, + 800, + 0.991210401058197 + ], + "2": [ + "vertical_and_slash", + 30, + 800, + 0.9882159233093262 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.736187219619751 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.8281543254852295 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.9498741030693054 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.9149094820022583 + ], + "7": [ + "vertical_and_slash", + 30, + 800, + 0.9877381920814514 + ], + "8": [ + "vertical_and_slash", + 30, + 800, + 0.963042140007019 + ], + "9": [ + "vertical_and_slash", + 30, + 800, + 0.9886980652809143 + ], + "10": [ + "vertical_and_slash", + 30, + 800, + 0.9890977740287781 + ], + "11": [ + "vertical_and_slash", + 30, + 800, + 0.9927736520767212 + ], + "12": [ + "vertical_and_slash", + 30, + 800, + 0.992214024066925 + ], + "13": [ + "vertical_and_slash", + 30, + 800, + 0.9973360300064087 + ], + "14": [ + "vertical_and_slash", + 500, + 700, + 0.7854554653167725 + ], + "15": [ + "vertical_and_slash", + 30, + 800, + 0.9945513606071472 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.9653095602989197 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.8165479898452759 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.9046997427940369 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.8763824701309204 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.8182893991470337 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.9363685846328735 + ], + "6": [ + "vertical_and_slash", + 100, + 800, + 0.734375 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.9775119423866272 + ], + "8": [ + "vertical_and_slash", + 30, + 800, + 0.9895726442337036 + ], + "9": [ + "vertical_and_slash", + 30, + 800, + 0.9737070202827454 + ], + "10": [ + 
"vertical_and_slash", + 30, + 800, + 0.996068000793457 + ], + "11": [ + "vertical_and_slash", + 30, + 800, + 0.997230589389801 + ], + "12": [ + "vertical_and_slash", + 30, + 800, + 0.9761516451835632 + ], + "13": [ + "vertical_and_slash", + 30, + 800, + 0.9949500560760498 + ], + "14": [ + "vertical_and_slash", + 30, + 800, + 0.9936166405677795 + ], + "15": [ + "vertical_and_slash", + 30, + 800, + 0.9869099855422974 + ] + }, + { + "0": [ + "vertical_and_slash", + 30, + 800, + 0.9887930154800415 + ], + "1": [ + "vertical_and_slash", + 30, + 800, + 0.9879409670829773 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.957123875617981 + ], + "3": [ + "vertical_and_slash", + 100, + 800, + 0.98828125 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.4647713899612427 + ], + "5": [ + "vertical_and_slash", + 30, + 800, + 0.9909580945968628 + ], + "6": [ + "vertical_and_slash", + 30, + 800, + 0.9757564067840576 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.5521421432495117 + ], + "8": [ + "vertical_and_slash", + 1000, + 6096, + 0.7780187129974365 + ], + "9": [ + "vertical_and_slash", + 30, + 800, + 0.9887256026268005 + ], + "10": [ + "vertical_and_slash", + 30, + 800, + 0.9927332401275635 + ], + "11": [ + "vertical_and_slash", + 30, + 800, + 0.9805054664611816 + ], + "12": [ + "vertical_and_slash", + 30, + 800, + 0.9525687098503113 + ], + "13": [ + "vertical_and_slash", + 30, + 800, + 0.9362225532531738 + ], + "14": [ + "vertical_and_slash", + 30, + 800, + 0.9488365054130554 + ], + "15": [ + "vertical_and_slash", + 30, + 800, + 0.9525135159492493 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.9934394955635071 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.9532470703125 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.9188738465309143 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.9849047660827637 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.9228449463844299 + ], + "5": [ + "vertical_and_slash", + 100, + 800, + 0.9765625 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.9707450866699219 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.9929892420768738 + ], + "8": [ + "vertical_and_slash", + 1000, + 6096, + 0.964901864528656 + ], + "9": [ + "vertical_and_slash", + 1000, + 6096, + 0.911367654800415 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.9818339943885803 + ], + "11": [ + "vertical_and_slash", + 1000, + 6096, + 0.9837478399276733 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.9615333080291748 + ], + "13": [ + "vertical_and_slash", + 1000, + 6096, + 0.9666763544082642 + ], + "14": [ + "vertical_and_slash", + 1000, + 6096, + 0.9545288681983948 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.9649417400360107 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.9158878326416016 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.9285635948181152 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.9884502291679382 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.8363761901855469 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.9531059265136719 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.9884499907493591 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.9524633884429932 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.9358732104301453 + ], + "8": [ + "vertical_and_slash", + 1000, + 6096, + 0.8582175374031067 + ], + "9": [ + "vertical_and_slash", + 1000, + 6096, + 0.3922925889492035 + ], + 
"10": [ + "vertical_and_slash", + 1000, + 6096, + 0.5313135385513306 + ], + "11": [ + "vertical_and_slash", + 100, + 800, + 0.953125 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.8960450291633606 + ], + "13": [ + "vertical_and_slash", + 100, + 800, + 0.90234375 + ], + "14": [ + "vertical_and_slash", + 1000, + 6096, + 0.6443539261817932 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.829773485660553 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.5914504528045654 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.8983972668647766 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.930306077003479 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.8418871164321899 + ], + "4": [ + "vertical_and_slash", + 100, + 800, + 0.9140625 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.8820360898971558 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.7236220836639404 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.5575733184814453 + ], + "8": [ + "vertical_and_slash", + 30, + 800, + 0.9963138699531555 + ], + "9": [ + "vertical_and_slash", + 30, + 800, + 0.9883040189743042 + ], + "10": [ + "vertical_and_slash", + 30, + 800, + 0.9783397912979126 + ], + "11": [ + "vertical_and_slash", + 30, + 800, + 0.9933704733848572 + ], + "12": [ + "vertical_and_slash", + 30, + 800, + 0.9880709648132324 + ], + "13": [ + "vertical_and_slash", + 30, + 800, + 0.9847649931907654 + ], + "14": [ + "vertical_and_slash", + 30, + 800, + 0.9817938804626465 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.9038569927215576 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.9423895478248596 + ], + "1": [ + "vertical_and_slash", + 100, + 800, + 0.99609375 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.9898338913917542 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.8854114413261414 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.9820103049278259 + ], + "5": [ + "vertical_and_slash", + 100, + 800, + 0.98828125 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.6622527837753296 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.5836654901504517 + ], + "8": [ + "vertical_and_slash", + 1000, + 6096, + 0.7537979483604431 + ], + "9": [ + "vertical_and_slash", + 1000, + 6096, + 0.9120598435401917 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.9560258388519287 + ], + "11": [ + "vertical_and_slash", + 100, + 800, + 1.0 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.9937338829040527 + ], + "13": [ + "vertical_and_slash", + 1000, + 6096, + 0.9111098051071167 + ], + "14": [ + "vertical_and_slash", + 1000, + 6096, + 0.9436591863632202 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.9429888129234314 + ] + }, + { + "0": [ + "vertical_and_slash", + 100, + 800, + 0.8515625 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.620391845703125 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.6681154370307922 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.9479513764381409 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.5289033651351929 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.43187281489372253 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.8812884092330933 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.79611736536026 + ], + "8": [ + "vertical_and_slash", + 30, + 800, + 0.9973558783531189 + ], + "9": [ + "vertical_and_slash", + 100, + 800, + 0.734375 + ], + "10": [ + "vertical_and_slash", + 
1000, + 6096, + 0.6380698680877686 + ], + "11": [ + "vertical_and_slash", + 100, + 800, + 0.9609375 + ], + "12": [ + "vertical_and_slash", + 100, + 800, + 0.8359375 + ], + "13": [ + "vertical_and_slash", + 1000, + 6096, + 0.8900502324104309 + ], + "14": [ + "vertical_and_slash", + 100, + 800, + 0.8984375 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.7215483784675598 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.8019538521766663 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.8427147269248962 + ], + "2": [ + "vertical_and_slash", + 3500, + 100, + 0.6292986273765564 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.9548527002334595 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.8857505321502686 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.8712131381034851 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.8540765643119812 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.5264020562171936 + ], + "8": [ + "vertical_and_slash", + 1000, + 6096, + 0.8968150615692139 + ], + "9": [ + "vertical_and_slash", + 1000, + 6096, + 0.6485419273376465 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.8069987893104553 + ], + "11": [ + "vertical_and_slash", + 1000, + 6096, + 0.8020429015159607 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.8054234981536865 + ], + "13": [ + "vertical_and_slash", + 1000, + 6096, + 0.725652813911438 + ], + "14": [ + "vertical_and_slash", + 1000, + 6096, + 0.46037647128105164 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.9644275903701782 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.7860593199729919 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.8588574528694153 + ], + "2": [ + "vertical_and_slash", + 3500, + 100, + 0.9157812595367432 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.8626066446304321 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.8797851800918579 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.7836940884590149 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.8759902715682983 + ], + "7": [ + "vertical_and_slash", + 3500, + 100, + 0.8079484701156616 + ], + "8": [ + "vertical_and_slash", + 100, + 800, + 0.7734375 + ], + "9": [ + "vertical_and_slash", + 100, + 800, + 0.80859375 + ], + "10": [ + "vertical_and_slash", + 100, + 800, + 0.91796875 + ], + "11": [ + "vertical_and_slash", + 100, + 800, + 0.93359375 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.5005961656570435 + ], + "13": [ + "vertical_and_slash", + 1000, + 6096, + 0.8044103384017944 + ], + "14": [ + "vertical_and_slash", + 3500, + 100, + 0.6477628946304321 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.5467575192451477 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.9432184100151062 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.9644391536712646 + ], + "2": [ + "vertical_and_slash", + 100, + 800, + 0.83203125 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.9638855457305908 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.9378725290298462 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.8560249209403992 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.9667811989784241 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.9581833481788635 + ], + "8": [ + "vertical_and_slash", + 30, + 800, + 0.9911800622940063 + ], + "9": [ + "vertical_and_slash", + 100, + 800, + 0.7890625 + ], + "10": [ + "vertical_and_slash", + 30, + 800, + 
0.9970740675926208 + ], + "11": [ + "vertical_and_slash", + 100, + 800, + 0.9453125 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.8719268441200256 + ], + "13": [ + "vertical_and_slash", + 30, + 800, + 0.9973757863044739 + ], + "14": [ + "vertical_and_slash", + 30, + 800, + 0.9817723631858826 + ], + "15": [ + "vertical_and_slash", + 30, + 800, + 0.9943931102752686 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.7833037376403809 + ], + "1": [ + "vertical_and_slash", + 30, + 800, + 0.985275149345398 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.8808413147926331 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.6342004537582397 + ], + "4": [ + "vertical_and_slash", + 3500, + 100, + 0.6783415675163269 + ], + "5": [ + "vertical_and_slash", + 3500, + 100, + 0.8690900206565857 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.9161447286605835 + ], + "7": [ + "vertical_and_slash", + 500, + 700, + 0.9136660099029541 + ], + "8": [ + "vertical_and_slash", + 30, + 800, + 0.9813501834869385 + ], + "9": [ + "vertical_and_slash", + 500, + 700, + 0.9898714423179626 + ], + "10": [ + "vertical_and_slash", + 30, + 800, + 0.9806197285652161 + ], + "11": [ + "vertical_and_slash", + 1000, + 6096, + 0.8513939380645752 + ], + "12": [ + "vertical_and_slash", + 30, + 800, + 0.9925402402877808 + ], + "13": [ + "vertical_and_slash", + 100, + 800, + 0.98046875 + ], + "14": [ + "vertical_and_slash", + 3500, + 100, + 0.9726189374923706 + ], + "15": [ + "vertical_and_slash", + 30, + 800, + 0.9964751601219177 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.8871311545372009 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.6332476735115051 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.6504296660423279 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.8371340036392212 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.5707467198371887 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.49299511313438416 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.6483507752418518 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.9506110548973083 + ], + "8": [ + "vertical_and_slash", + 3500, + 100, + 0.8480803966522217 + ], + "9": [ + "vertical_and_slash", + 100, + 800, + 0.82421875 + ], + "10": [ + "vertical_and_slash", + 3500, + 100, + 0.9129166007041931 + ], + "11": [ + "vertical_and_slash", + 1000, + 6096, + 0.5975048542022705 + ], + "12": [ + "vertical_and_slash", + 3500, + 100, + 0.8599441647529602 + ], + "13": [ + "vertical_and_slash", + 100, + 800, + 0.94140625 + ], + "14": [ + "vertical_and_slash", + 3500, + 100, + 0.7465125918388367 + ], + "15": [ + "vertical_and_slash", + 500, + 700, + 0.9268635511398315 + ] + }, + { + "0": [ + "vertical_and_slash", + 30, + 800, + 0.9838484525680542 + ], + "1": [ + "vertical_and_slash", + 100, + 800, + 0.96875 + ], + "2": [ + "vertical_and_slash", + 500, + 700, + 0.9804697632789612 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.9868910908699036 + ], + "4": [ + "vertical_and_slash", + 100, + 800, + 0.953125 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.8383246660232544 + ], + "6": [ + "vertical_and_slash", + 100, + 800, + 0.97265625 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.7157800793647766 + ], + "8": [ + "vertical_and_slash", + 3500, + 100, + 0.9093014001846313 + ], + "9": [ + "vertical_and_slash", + 100, + 800, + 0.82421875 + ], + "10": [ + "vertical_and_slash", + 3500, + 100, + 0.951353132724762 + ], + "11": [ + 
"vertical_and_slash", + 100, + 800, + 0.98046875 + ], + "12": [ + "vertical_and_slash", + 3500, + 100, + 0.5795325040817261 + ], + "13": [ + "vertical_and_slash", + 3500, + 100, + 0.9869730472564697 + ], + "14": [ + "vertical_and_slash", + 100, + 800, + 0.953125 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.6627390384674072 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.9301889538764954 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.9623850584030151 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.925979495048523 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.9278422594070435 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.9768559336662292 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.9585739374160767 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.9396020174026489 + ], + "7": [ + "vertical_and_slash", + 3500, + 100, + 0.8997651934623718 + ], + "8": [ + "vertical_and_slash", + 30, + 800, + 1.000001072883606 + ], + "9": [ + "vertical_and_slash", + 100, + 800, + 0.89453125 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.9783903360366821 + ], + "11": [ + "vertical_and_slash", + 1000, + 6096, + 0.9423955678939819 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.6527382135391235 + ], + "13": [ + "vertical_and_slash", + 1000, + 6096, + 0.9103218913078308 + ], + "14": [ + "vertical_and_slash", + 1000, + 6096, + 0.829939067363739 + ], + "15": [ + "vertical_and_slash", + 3500, + 100, + 0.982806921005249 + ] + }, + { + "0": [ + "vertical_and_slash", + 500, + 700, + 0.9865686893463135 + ], + "1": [ + "vertical_and_slash", + 500, + 700, + 0.9607634544372559 + ], + "2": [ + "vertical_and_slash", + 30, + 800, + 0.9966762065887451 + ], + "3": [ + "vertical_and_slash", + 100, + 800, + 0.75 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.8039277791976929 + ], + "5": [ + "vertical_and_slash", + 500, + 700, + 0.9704546332359314 + ], + "6": [ + "vertical_and_slash", + 3500, + 100, + 0.9315423965454102 + ], + "7": [ + "vertical_and_slash", + 500, + 700, + 0.8952665328979492 + ], + "8": [ + "vertical_and_slash", + 1000, + 6096, + 0.8250367045402527 + ], + "9": [ + "vertical_and_slash", + 1000, + 6096, + 0.953521728515625 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.8773703575134277 + ], + "11": [ + "vertical_and_slash", + 1000, + 6096, + 0.9780517816543579 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.9766848683357239 + ], + "13": [ + "vertical_and_slash", + 1000, + 6096, + 0.9704862236976624 + ], + "14": [ + "vertical_and_slash", + 3500, + 100, + 0.9836187958717346 + ], + "15": [ + "vertical_and_slash", + 1000, + 6096, + 0.9668617844581604 + ] + }, + { + "0": [ + "vertical_and_slash", + 1000, + 6096, + 0.9616305828094482 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.9781763553619385 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.9224084615707397 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.9049281477928162 + ], + "4": [ + "vertical_and_slash", + 1000, + 6096, + 0.9796241521835327 + ], + "5": [ + "vertical_and_slash", + 3500, + 100, + 0.9083701968193054 + ], + "6": [ + "vertical_and_slash", + 3500, + 100, + 0.9923230409622192 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.9486480951309204 + ], + "8": [ + "vertical_and_slash", + 500, + 700, + 0.9276089668273926 + ], + "9": [ + "vertical_and_slash", + 3500, + 100, + 0.9481350183486938 + ], + "10": [ + "vertical_and_slash", + 3500, + 100, + 0.9749733805656433 + ], + 
"11": [ + "vertical_and_slash", + 3500, + 100, + 0.9456705451011658 + ], + "12": [ + "vertical_and_slash", + 3500, + 100, + 0.9752734303474426 + ], + "13": [ + "vertical_and_slash", + 100, + 800, + 0.97265625 + ], + "14": [ + "vertical_and_slash", + 30, + 800, + 0.955278754234314 + ], + "15": [ + "vertical_and_slash", + 3500, + 100, + 0.9254546761512756 + ] + }, + { + "0": [ + "vertical_and_slash", + 500, + 700, + 0.9746971726417542 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.9719505310058594 + ], + "2": [ + "vertical_and_slash", + 1000, + 6096, + 0.8964815735816956 + ], + "3": [ + "vertical_and_slash", + 1000, + 6096, + 0.8442646265029907 + ], + "4": [ + "vertical_and_slash", + 3500, + 100, + 0.9568673968315125 + ], + "5": [ + "vertical_and_slash", + 1000, + 6096, + 0.94114750623703 + ], + "6": [ + "vertical_and_slash", + 100, + 800, + 0.96875 + ], + "7": [ + "vertical_and_slash", + 500, + 700, + 0.9370027780532837 + ], + "8": [ + "vertical_and_slash", + 100, + 800, + 0.91015625 + ], + "9": [ + "vertical_and_slash", + 3500, + 100, + 0.9786306023597717 + ], + "10": [ + "vertical_and_slash", + 3500, + 100, + 0.7740182876586914 + ], + "11": [ + "vertical_and_slash", + 1000, + 6096, + 0.9646586179733276 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.832630455493927 + ], + "13": [ + "vertical_and_slash", + 100, + 800, + 0.94921875 + ], + "14": [ + "vertical_and_slash", + 1000, + 6096, + 0.5829520225524902 + ], + "15": [ + "vertical_and_slash", + 3500, + 100, + 0.9215387105941772 + ] + }, + { + "0": [ + "vertical_and_slash", + 3500, + 100, + 0.9220553040504456 + ], + "1": [ + "vertical_and_slash", + 3500, + 100, + 0.9111120700836182 + ], + "2": [ + "vertical_and_slash", + 3500, + 100, + 0.8448940515518188 + ], + "3": [ + "vertical_and_slash", + 3500, + 100, + 0.9479627013206482 + ], + "4": [ + "vertical_and_slash", + 3500, + 100, + 0.9360203146934509 + ], + "5": [ + "vertical_and_slash", + 3500, + 100, + 0.806208074092865 + ], + "6": [ + "vertical_and_slash", + 3500, + 100, + 0.8499483466148376 + ], + "7": [ + "vertical_and_slash", + 3500, + 100, + 0.9351169466972351 + ], + "8": [ + "vertical_and_slash", + 3500, + 100, + 0.887176513671875 + ], + "9": [ + "vertical_and_slash", + 500, + 700, + 0.9972283244132996 + ], + "10": [ + "vertical_and_slash", + 1000, + 6096, + 0.7344715595245361 + ], + "11": [ + "vertical_and_slash", + 100, + 750, + 0.9290981888771057 + ], + "12": [ + "vertical_and_slash", + 3500, + 100, + 0.9191303849220276 + ], + "13": [ + "vertical_and_slash", + 3500, + 100, + 0.9607530832290649 + ], + "14": [ + "vertical_and_slash", + 1000, + 6096, + 0.6650398373603821 + ], + "15": [ + "vertical_and_slash", + 3500, + 100, + 0.9041045308113098 + ] + }, + { + "0": [ + "vertical_and_slash", + 3500, + 100, + 0.9288094639778137 + ], + "1": [ + "vertical_and_slash", + 3500, + 100, + 0.922483503818512 + ], + "2": [ + "vertical_and_slash", + 3500, + 100, + 0.8560412526130676 + ], + "3": [ + "vertical_and_slash", + 3500, + 100, + 0.9515807628631592 + ], + "4": [ + "vertical_and_slash", + 3500, + 100, + 0.910995364189148 + ], + "5": [ + "vertical_and_slash", + 3500, + 100, + 0.9611515402793884 + ], + "6": [ + "vertical_and_slash", + 3500, + 100, + 0.9130395650863647 + ], + "7": [ + "vertical_and_slash", + 3500, + 100, + 0.8693947196006775 + ], + "8": [ + "vertical_and_slash", + 3500, + 100, + 0.9321247935295105 + ], + "9": [ + "vertical_and_slash", + 3500, + 100, + 0.7624196410179138 + ], + "10": [ + "vertical_and_slash", + 3500, + 100, + 0.9113157391548157 + ], + 
"11": [ + "vertical_and_slash", + 3500, + 100, + 0.8822183012962341 + ], + "12": [ + "vertical_and_slash", + 3500, + 100, + 0.940976083278656 + ], + "13": [ + "vertical_and_slash", + 3500, + 100, + 0.9429124593734741 + ], + "14": [ + "vertical_and_slash", + 3500, + 100, + 0.9557855129241943 + ], + "15": [ + "vertical_and_slash", + 3500, + 100, + 0.7963366508483887 + ] + }, + { + "0": [ + "vertical_and_slash", + 3500, + 100, + 0.9815316796302795 + ], + "1": [ + "vertical_and_slash", + 3500, + 100, + 0.9569784998893738 + ], + "2": [ + "vertical_and_slash", + 3500, + 100, + 0.9643719792366028 + ], + "3": [ + "vertical_and_slash", + 3500, + 100, + 0.9694581627845764 + ], + "4": [ + "vertical_and_slash", + 3500, + 100, + 0.9616712927818298 + ], + "5": [ + "vertical_and_slash", + 3500, + 100, + 0.9564018249511719 + ], + "6": [ + "vertical_and_slash", + 3500, + 100, + 0.9240798354148865 + ], + "7": [ + "vertical_and_slash", + 3500, + 100, + 0.9618653059005737 + ], + "8": [ + "vertical_and_slash", + 3500, + 100, + 0.9393341541290283 + ], + "9": [ + "vertical_and_slash", + 3500, + 100, + 0.9590299129486084 + ], + "10": [ + "vertical_and_slash", + 3500, + 100, + 0.9623062014579773 + ], + "11": [ + "vertical_and_slash", + 3500, + 100, + 0.9482530355453491 + ], + "12": [ + "vertical_and_slash", + 3500, + 100, + 0.9658593535423279 + ], + "13": [ + "vertical_and_slash", + 3500, + 100, + 0.9724211096763611 + ], + "14": [ + "vertical_and_slash", + 3500, + 100, + 0.9585490822792053 + ], + "15": [ + "vertical_and_slash", + 3500, + 100, + 0.9729295969009399 + ] + }, + { + "0": [ + "vertical_and_slash", + 3500, + 100, + 0.9484567046165466 + ], + "1": [ + "vertical_and_slash", + 1000, + 6096, + 0.7411888241767883 + ], + "2": [ + "vertical_and_slash", + 3500, + 100, + 0.8387667536735535 + ], + "3": [ + "vertical_and_slash", + 3500, + 100, + 0.7810403108596802 + ], + "4": [ + "vertical_and_slash", + 3500, + 100, + 0.8578725457191467 + ], + "5": [ + "vertical_and_slash", + 3500, + 100, + 0.8868502974510193 + ], + "6": [ + "vertical_and_slash", + 1000, + 6096, + 0.7327648401260376 + ], + "7": [ + "vertical_and_slash", + 1000, + 6096, + 0.9077197313308716 + ], + "8": [ + "vertical_and_slash", + 3500, + 100, + 0.9110754728317261 + ], + "9": [ + "vertical_and_slash", + 1000, + 6096, + 0.8923226594924927 + ], + "10": [ + "vertical_and_slash", + 3500, + 100, + 0.7206286191940308 + ], + "11": [ + "vertical_and_slash", + 3500, + 100, + 0.8617163300514221 + ], + "12": [ + "vertical_and_slash", + 3500, + 100, + 0.8827745914459229 + ], + "13": [ + "vertical_and_slash", + 3500, + 100, + 0.7372896075248718 + ], + "14": [ + "vertical_and_slash", + 3500, + 100, + 0.9529178142547607 + ], + "15": [ + "vertical_and_slash", + 3500, + 100, + 0.8600607514381409 + ] + }, + { + "0": [ + "vertical_and_slash", + 3500, + 100, + 0.9317379593849182 + ], + "1": [ + "vertical_and_slash", + 3500, + 100, + 0.9660685658454895 + ], + "2": [ + "vertical_and_slash", + 3500, + 100, + 0.932490885257721 + ], + "3": [ + "vertical_and_slash", + 3500, + 100, + 0.9033127427101135 + ], + "4": [ + "vertical_and_slash", + 3500, + 100, + 0.9372511506080627 + ], + "5": [ + "vertical_and_slash", + 3500, + 100, + 0.9416565299034119 + ], + "6": [ + "vertical_and_slash", + 3500, + 100, + 0.9347256422042847 + ], + "7": [ + "vertical_and_slash", + 3500, + 100, + 0.9515694379806519 + ], + "8": [ + "vertical_and_slash", + 1000, + 6096, + 0.9580468535423279 + ], + "9": [ + "vertical_and_slash", + 500, + 700, + 0.9769356846809387 + ], + "10": [ + "vertical_and_slash", + 
1000, + 6096, + 0.6504788398742676 + ], + "11": [ + "vertical_and_slash", + 3500, + 100, + 0.9479199051856995 + ], + "12": [ + "vertical_and_slash", + 1000, + 6096, + 0.9721158146858215 + ], + "13": [ + "vertical_and_slash", + 3500, + 100, + 0.9033265113830566 + ], + "14": [ + "vertical_and_slash", + 3500, + 100, + 0.9400415420532227 + ], + "15": [ + "vertical_and_slash", + 100, + 750, + 0.8886585831642151 + ] + } +] \ No newline at end of file diff --git a/minference/dist_ops/__init__.py b/minference/dist_ops/__init__.py new file mode 100644 index 0000000..95e6cd5 --- /dev/null +++ b/minference/dist_ops/__init__.py @@ -0,0 +1,5 @@ +from .minfer_striped import minfer_stripe_func +from .minfer_zigzag import minfer_zigzag_func +from .minfer_dr_striped import minfer_dr_stripe_func +from .moba_zigzag import moba_zigzag_func +from .xattn_zigzag import xattn_zigzag_func diff --git a/minference/dist_ops/dr_striped_attention.py b/minference/dist_ops/dr_striped_attention.py new file mode 100644 index 0000000..242f506 --- /dev/null +++ b/minference/dist_ops/dr_striped_attention.py @@ -0,0 +1,584 @@ +import os +import torch +import torch.distributed as dist + +from typing import List +from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward +from .utils import ( + RingComm, update_out_and_lse, get_default_args, + shuffle_striped_input, recover_striped_output +) + +def get_inner_group(): + rank = dist.get_rank() + local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE")) + inner_group = [i for i in range(local_world_size)] + + rank_offset = (rank // local_world_size) * local_world_size + inner_group = [rank_offset + i for i in inner_group] + + return inner_group + +def get_outer_group(): + local_rank = int(os.environ["LOCAL_RANK"]) + local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE")) + world_size = dist.get_world_size() + + outer_group = [] + i = local_rank + while i < world_size: + outer_group.append(i) + i += local_world_size + + return outer_group + +def stripe_fwd_inner( + process_group, + outer_step: int, + outer_rank: int, + inner_ring_list: List[int], + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + softmax_scale, + granularity=1, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + inner_comm = RingComm(process_group, False, inner_ring_list) + inner_rank = int(os.environ["LOCAL_RANK"]) + num_inner_steps = len(inner_ring_list) + + out, lse = None, None + next_k, next_v = None, None + + for inner_step in range(num_inner_steps): + if inner_step + 1 != num_inner_steps: + next_k, next_v = inner_comm.send_recv_kv(k, v) + + def forward(q_, k_, v_, dropout_p_, softmax_scale_, causal_, alibi_slopes_): + params = get_default_args(_flash_attn_forward).copy() + params.update( + { + "q": q_, + "k": k_, + "v": v_, + "dropout_p": dropout_p_, + "softmax_scale": softmax_scale_, + "causal": causal_, + "alibi_slopes": alibi_slopes_, + "return_softmax": True and dropout_p_ > 0, + } + ) + + if "window_size" in params: + params.update({"window_size": window_size}) + else: + params.update( + { + "window_size_left": window_size[0], + "window_size_right": window_size[1], + } + ) + + outputs = _flash_attn_forward(**params) + if len(outputs) == 8: + block_out, _, _, _, _, block_lse, _, _ = outputs + else: + assert len(outputs) == 4 + block_out, block_lse, _, _ = outputs + return block_out, block_lse + + if outer_step == 0 and inner_step > inner_rank: + block_out, block_lse = forward( + q[:, granularity:], k[:, 
:-granularity], v[:, :-granularity], + dropout_p, softmax_scale, causal, alibi_slopes, + ) + out, lse = update_out_and_lse( + out, lse, block_out, block_lse, slice_=(slice(None), slice(granularity, None)) + ) + else: + block_out, block_lse = forward( + q, k, v, + dropout_p, softmax_scale, causal, alibi_slopes, + ) + out, lse = update_out_and_lse(out, lse, block_out, block_lse) + + if inner_step + 1 != num_inner_steps: + inner_comm.wait() + k, v = next_k, next_v + + out = out.to(q.dtype) + lse = lse.squeeze(dim=-1).transpose(1, 2) + return out, lse + +def stripe_fwd_outer( + process_group, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + softmax_scale, + granularity=1, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + inner_ring_list: List[int]=None, + outer_ring_list: List[int]=None, +): + assert causal, "stripe flash attn only supports causal attention, if not causal, use ring flash attn instead" + outer_comm = RingComm(process_group, False, outer_ring_list) + + global_rank = dist.get_rank() + outer_rank = outer_ring_list.index(global_rank) + num_outer_steps = len(outer_ring_list) + + out = None + lse = None + + next_k, next_v = None, None + for outer_step in range(num_outer_steps): + if outer_step + 1 != num_outer_steps: + next_k, next_v = outer_comm.send_recv_kv(k, v) + + if outer_step <= outer_rank: + block_out, block_lse = stripe_fwd_inner( + process_group, outer_step, outer_rank, inner_ring_list, + q, k, v, + softmax_scale, + granularity, + dropout_p, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + ) + out, lse = update_out_and_lse(out, lse, block_out, block_lse) + else: + # Before the step index goes beyond the current rank, the received KV indices are not greater than those of the Q in the current rank + # After the step index goes beyond the current rank, only the KV indices before the last granularity are no greater than those of the Q after the first granularity + # this conclusion holds after the step index goes beyond the current rank (not just step index == current rank) + block_out, block_lse = stripe_fwd_inner( + process_group, outer_step, outer_rank, inner_ring_list, + q[:, granularity:], k[:, :-granularity], v[:, :-granularity], + softmax_scale, + granularity, + dropout_p, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + ) + out, lse = update_out_and_lse( + out, lse, block_out, block_lse, slice_=(slice(None), slice(granularity, None)) + ) + + if outer_step + 1 != num_outer_steps: + outer_comm.wait() + k, v = next_k, next_v + + out = out.to(q.dtype) + lse = lse.squeeze(dim=-1).transpose(1, 2) + return out, lse + + +def stripe_backward_inner( + process_group, + outer_step, inner_ring_list, + dout, + q, + k, + v, + out, + softmax_lse, + softmax_scale, + granularity=1, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + assert ( + causal + ), "stripe flash attn only supports causal attention, if not causal, ring flash attn instead" + kv_comm = RingComm(process_group, False, inner_ring_list) + d_kv_comm = RingComm(process_group, False, inner_ring_list) + + inner_rank = int(os.environ["LOCAL_RANK"]) + num_inner_step = len(inner_ring_list) + + dq, dk, dv = None, None, None + next_dk, next_dv = None, None + next_k, next_v = None, None + dk_comm_buffer, dv_comm_buffer = None, None + + block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) + 
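# The block_d{q,k,v}_buffer tensors are scratch space allocated once and reused on
# every inner-ring step: _flash_attn_backward writes the partial gradients for the
# K/V block currently held by this rank into them, and the loop body accumulates
# them into dq/dk/dv in float32 while K/V (and dK/dV) rotate around the node-local
# ring. Accumulation sketch, using the names from the loop below:
#     dq += block_dq_buffer                                      # regular step
#     dq[:, granularity:] += block_dq_buffer[:, granularity:]    # shift_causal step
# The shift_causal branch skips the first `granularity` query tokens and the last
# `granularity` key/value tokens, mirroring the slicing used in the forward pass.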
block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) + block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) + for inner_step in range(num_inner_step): + if inner_step + 1 != num_inner_step: + next_k, next_v = kv_comm.send_recv_kv(k, v) + + shift_causal = outer_step == 0 and inner_step > inner_rank + softmax_lse_1 = None + + def backward( + dout_, + q_, k_, v_, out_, softmax_lse_, + block_dq_buffer_, block_dk_buffer_, block_dv_buffer_, + ): + params = get_default_args(_flash_attn_backward).copy() + params.update( + { + "dout": dout_, + "q": q_, + "k": k_, + "v": v_, + "out": out_, + "softmax_lse": softmax_lse_, + "dq": block_dq_buffer_, + "dk": block_dk_buffer_, + "dv": block_dv_buffer_, + "dropout_p": dropout_p, + "softmax_scale": softmax_scale, + "causal": causal, + "alibi_slopes": alibi_slopes, + "deterministic": deterministic, + } + ) + if "window_size" in params: + params.update({"window_size": window_size}) + else: + params.update( + { + "window_size_left": window_size[0], + "window_size_right": window_size[1], + } + ) + _flash_attn_backward(**params) + + if not shift_causal: + backward( + dout, q, k, v, out, softmax_lse, + block_dq_buffer, block_dk_buffer, block_dv_buffer + ) + else: + if softmax_lse_1 is None: + # lazy init, since the last rank does not need softmax_lse_1 + softmax_lse_1 = softmax_lse[:, :, granularity:].contiguous() + backward( + dout[:, granularity:], + q[:, granularity:], k[:, :-granularity], v[:, :-granularity], + out[:, granularity:], softmax_lse_1, + block_dq_buffer[:, granularity:], block_dk_buffer[:, :-granularity], block_dv_buffer[:, :-granularity] + ) + + if dq is None: + dq = block_dq_buffer.to(torch.float32) + dk = block_dk_buffer.to(torch.float32) + dv = block_dv_buffer.to(torch.float32) + else: + if not shift_causal: + dq += block_dq_buffer + else: + dq[:, granularity:] += block_dq_buffer[:, granularity:] + + d_kv_comm.wait() + dk_comm_buffer, dv_comm_buffer = dk, dv + dk, dv = next_dk, next_dv + + if not shift_causal: + dk = block_dk_buffer + dk + dv = block_dv_buffer + dv + else: + dk[:, :-granularity] += block_dk_buffer[:, :-granularity] + dv[:, :-granularity] += block_dv_buffer[:, :-granularity] + + if inner_step + 1 != num_inner_step: + kv_comm.wait() + k, v = next_k, next_v + + next_dk, next_dv = d_kv_comm.send_recv_kv( + dk, dv, dk_comm_buffer, dv_comm_buffer + ) + + d_kv_comm.wait() + return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) + +def stripe_backward_outer( + process_group, + dout, + q, + k, + v, + out, + softmax_lse, + softmax_scale, + inner_ring_list: List[int], + outer_ring_list: List[int], + granularity=1, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + assert ( + causal + ), "stripe flash attn only supports causal attention, if not causal, ring flash attn instead" + + outer_kv_comm = RingComm(process_group, False, outer_ring_list) + outer_dkv_comm = RingComm(process_group, False, outer_ring_list) + + global_rank = dist.get_rank() + outer_rank = outer_ring_list.index(global_rank) + num_outer_steps = len(outer_ring_list) + + dq, dk, dv = None, None, None + next_dk, next_dv = None, None + next_k, next_v = None, None + dk_comm_buffer, dv_comm_buffer = None, None + + block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) + block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) + block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) + for outer_step in range(num_outer_steps): + if 
outer_step + 1 != num_outer_steps: + next_k, next_v = outer_kv_comm.send_recv_kv(k, v) + + softmax_lse_1 = None + outer_shift = outer_step > outer_rank + + if not outer_shift: + block_dq_buffer, block_dk_buffer, block_dv_buffer = stripe_backward_inner( + process_group, outer_step, inner_ring_list, + dout, q, k, v, out, + softmax_lse, softmax_scale, granularity, dropout_p, + causal, window_size, alibi_slopes, deterministic, + ) + else: + if softmax_lse_1 is None: + # lazy init, since the last rank does not need softmax_lse_1 + softmax_lse_1 = softmax_lse[:, :, granularity:].contiguous() + block_dq_buffer, block_dk_buffer, block_dv_buffer = stripe_backward_inner( + process_group, outer_step, inner_ring_list, + dout[:, granularity:], + q[:, granularity:], k[:, :-granularity], v[:, :-granularity], out[:, granularity:], + softmax_lse_1, softmax_scale, granularity, dropout_p, + causal, window_size, alibi_slopes, deterministic, + ) + + if dq is None: + dq = block_dq_buffer.to(torch.float32) + dk = block_dk_buffer.to(torch.float32) + dv = block_dv_buffer.to(torch.float32) + else: + if not outer_shift: + dq += block_dq_buffer + else: + dq[:, granularity:] += block_dq_buffer + + outer_dkv_comm.wait() + + dk_comm_buffer, dv_comm_buffer = dk, dv + dk, dv = next_dk, next_dv + + if not outer_shift: + dk = block_dk_buffer + dk + dv = block_dv_buffer + dv + else: + dk[:, :-granularity] += block_dk_buffer + dv[:, :-granularity] += block_dv_buffer + + if outer_step + 1 != num_outer_steps: + outer_kv_comm.wait() + k, v = next_k, next_v + + next_dk, next_dv = outer_dkv_comm.send_recv_kv( + dk, dv, dk_comm_buffer, dv_comm_buffer + ) + + outer_dkv_comm.wait() + return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) + +class DRStripeFlashAttnFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + dropout_p, + softmax_scale, + granularity, + causal, + window_size, + alibi_slopes, + deterministic, + return_softmax, + group + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + assert alibi_slopes is None + inner_ring_list = get_inner_group() # ranks in the current node, length = num of cards within this node + outer_ring_list = get_outer_group() # corresponding ranks in other nodes, length = num of nodes + + q = shuffle_striped_input(to_send=q, dim=1, granularity=granularity, process_group=group) + k = shuffle_striped_input(to_send=k, dim=1, granularity=granularity, process_group=group) + v = shuffle_striped_input(to_send=v, dim=1, granularity=granularity, process_group=group) + + out, softmax_lse = stripe_fwd_outer( + group, + q, + k, + v, + softmax_scale=softmax_scale, + granularity=granularity, + dropout_p=dropout_p, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=False, + inner_ring_list=inner_ring_list, + outer_ring_list=outer_ring_list, + ) + # this should be out_padded + ctx.save_for_backward(q, k, v, out, softmax_lse) + ctx.inner_ring_list = inner_ring_list + ctx.outer_ring_list = outer_ring_list + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.granularity = granularity + ctx.causal = causal + ctx.window_size = window_size + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + ctx.group = group + out = recover_striped_output(out, dim=1, granularity=granularity, process_group=group) + if return_softmax: + softmax_lse = recover_striped_output(softmax_lse, dim=2, granularity=granularity, process_group=group) + return (out, softmax_lse, None) + return out + + @staticmethod 
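# backward() mirrors forward(): dout is first re-dealt into the striped layout, the
# two-level ring backward (node-local inner ring, cross-node outer ring) produces
# striped dq/dk/dv, and the gradients are un-shuffled back to the caller's layout.
# The trailing None returns line up with the non-tensor forward inputs (dropout_p,
# softmax_scale, granularity, causal, window_size, alibi_slopes, deterministic,
# return_softmax, group).
# Usage sketch for the public wrapper defined further below (argument values are
# only an example; the kernel requires causal attention and alibi_slopes=None):
#     out = dr_stripe_flash_attn_func(q, k, v, causal=True, granularity=16, group=pg)  # pg: an initialized process group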
+ def backward(ctx, dout, *args): + dout = shuffle_striped_input(to_send=dout, dim=1, granularity=ctx.granularity, process_group=ctx.group) + q, k, v, out, softmax_lse = ctx.saved_tensors + inner_ring_list, outer_ring_list = ctx.inner_ring_list, ctx.outer_ring_list + dq, dk, dv = stripe_backward_outer( + ctx.group, + dout, + q, + k, + v, + out, + softmax_lse, + ctx.softmax_scale, + inner_ring_list=inner_ring_list, + outer_ring_list=outer_ring_list, + granularity=ctx.granularity, + dropout_p=ctx.dropout_p, + causal=ctx.causal, + window_size=ctx.window_size, + alibi_slopes=ctx.alibi_slopes, + deterministic=ctx.deterministic, + ) + + dq = recover_striped_output(dq, dim=1, granularity=ctx.granularity, process_group=ctx.group) + dk = recover_striped_output(dk, dim=1, granularity=ctx.granularity, process_group=ctx.group) + dv = recover_striped_output(dv, dim=1, granularity=ctx.granularity, process_group=ctx.group) + return dq, dk, dv, None, None, None, None, None, None, None, None, None + + +def dr_stripe_flash_attn_qkvpacked_func( + qkv, # [B, N, 3, H, D] + dropout_p=0.0, + softmax_scale=None, + granularity=1, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return DRStripeFlashAttnFunc.apply( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + dropout_p, + softmax_scale, + granularity, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def dr_stripe_flash_attn_kvpacked_func( + q, + kv, + dropout_p=0.0, + softmax_scale=None, + granularity=1, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return DRStripeFlashAttnFunc.apply( + q, + kv[:, :, 0], + kv[:, :, 1], + dropout_p, + softmax_scale, + granularity, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def dr_stripe_flash_attn_func( + q, + k, + v, + dropout_p=0.0, + softmax_scale=None, + granularity=1, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return DRStripeFlashAttnFunc.apply( + q, + k, + v, + dropout_p, + softmax_scale, + granularity, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) diff --git a/minference/dist_ops/minfer_dr_striped.py b/minference/dist_ops/minfer_dr_striped.py new file mode 100644 index 0000000..d5fb346 --- /dev/null +++ b/minference/dist_ops/minfer_dr_striped.py @@ -0,0 +1,727 @@ +import os +import torch +import torch.distributed as dist + +from typing import List, Tuple, Dict +from .utils import ( + RingComm, + shuffle_striped_input, recover_striped_output, + get_inner_ring, get_outer_ring +) +from minference.ops.utils import use_triton +from minference.ops.op_utils.vertical_slash_utils import build_index, extract_kv, merge_kv, convert_blockmask +from minference.ops.pit_sparse_flash_attention_v3 import block_bar_attn_fwd, block_attn_bwd, bar_attn_bwd, block_bar_attn_bwd + +# ------------------------------------------------------------------------ +# CUDA-based Implementation (Block-Sparse-Attention version) +def minfer_dr_stripe_forward_inner( + process_group: dist.ProcessGroup, + outer_step: int, + outer_offset: int, + outer_rank: int, + inner_ring: List[int], + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: 
torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + out: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + lse: torch.Tensor, # [batch_size, num_qo_heads, num_tokens] + layer_idx: int, + softmax_scale: float, + block_mask: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks, num_blocks] + bar_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, max_v_size] + bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] + v_idx: torch.Tensor, # [batch_size, num_qo_heads, max_v_size] + v_cnt: torch.Tensor, # [batch_size, num_qo_heads, world_size + 1] + bar_k: torch.Tensor, # [batch_size, max_v_size, num_qo_heads, head_dim] + bar_v: torch.Tensor, # [batch_size, max_v_size, num_qo_heads, head_dim] + granularity: int = 128, +): + inner_comm = RingComm(process_group, False, inner_ring) + inner_rank = inner_ring.index(inner_comm.rank) + num_inner_steps = len(inner_ring) + + next_k, next_v = None, None + block_idx, block_cnt = convert_blockmask(block_mask, block_size_M=granularity, block_size_N=64) + + for inner_step in range(num_inner_steps): + if inner_step + 1 != num_inner_steps: + next_k, next_v = inner_comm.send_recv_kv(k, v) + + block_causal = (outer_step == 0) and (inner_step == 0) + inner_offset = (inner_rank - inner_step) % num_inner_steps + offset = outer_offset * num_inner_steps + inner_offset + + out, lse = block_bar_attn_fwd( + q, k, v, out, lse, softmax_scale, + bar_idx, bar_cnt, block_idx[inner_offset], block_cnt[inner_offset], + granularity=granularity, + step=offset, + causal=block_causal, + ) + extract_kv(k, v, bar_k, bar_v, v_idx, v_cnt, step=offset) + + if inner_step + 1 != num_inner_steps: + inner_comm.wait() + k, v = next_k, next_v + + return out, lse + + +def minfer_dr_stripe_forward_outer( + process_group: dist.ProcessGroup, + outer_ring: List[int], + inner_ring: List[int], + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + layer_idx: int, + softmax_scale: float, + block_mask: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks, num_blocks] + bar_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, max_v_size] + bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] + v_idx: torch.Tensor, # [batch_size, num_qo_heads, max_v_size] + v_cnt: torch.Tensor, # [batch_size, num_qo_heads, world_size + 1] + granularity: int = 128, +): + outer_comm = RingComm(process_group, False, outer_ring) + outer_rank = outer_ring.index(outer_comm.rank) + num_outer_steps = len(outer_ring) + + out, lse = None, None + next_k, next_v = None, None + inner_block_masks = block_mask.chunk(num_outer_steps, dim=0) + + batch_size, _, num_qo_heads, head_dim = q.shape + max_v_size = v_idx.shape[-1] + bar_k = torch.empty((batch_size, max_v_size, num_qo_heads, head_dim), dtype=q.dtype, device=q.device) + bar_v = torch.empty((batch_size, max_v_size, num_qo_heads, head_dim), dtype=q.dtype, device=q.device) + + for outer_step in range(num_outer_steps): + if outer_step + 1 != num_outer_steps: + next_k, next_v = outer_comm.send_recv_kv(k, v) + outer_offset = (outer_rank - outer_step) % num_outer_steps + + out, lse = minfer_dr_stripe_forward_inner( + process_group, + outer_step, outer_offset, outer_rank, inner_ring, + q, k, v, out, lse, layer_idx, softmax_scale, + 
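# block_mask carries a leading world_size dimension and was chunked above into
# num_outer_steps groups; the chunk handed to the inner loop (index outer_offset)
# is the one for the node that originally owned the K/V shard currently held here.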
inner_block_masks[outer_offset], bar_idx, bar_cnt, + v_idx, v_cnt, bar_k, bar_v, + granularity, + ) + + if outer_step + 1 != num_outer_steps: + outer_comm.wait() + k, v = next_k, next_v + + out = out.to(q.dtype) + return out, lse, bar_k, bar_v + + +def minfer_dr_stripe_backward_inner( + process_group: dist.ProcessGroup, + outer_step: int, + outer_offset: int, + outer_rank: int, + inner_ring: List[int], + dout: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + out: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + softmax_lse: torch.Tensor, # [batch_size, num_qo_heads, num_tokens] + layer_idx: int, + softmax_scale: float, + block_mask: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks, num_blocks] + v_idx: torch.Tensor, # [batch_size, num_qo_heads, max_v_size] + v_cnt: torch.Tensor, # [batch_size, num_qo_heads, world_size + 1] + dq: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + bar_dk: torch.Tensor, # [batch_size, max_v_size, num_qo_heads, head_dim] + bar_dv: torch.Tensor, # [batch_size, max_v_size, num_qo_heads, head_dim] + granularity: int = 128, +): + inner_kv_comm = RingComm(process_group, False, inner_ring) + inner_d_kv_comm = RingComm(process_group, False, inner_ring) + inner_rank = inner_ring.index(inner_kv_comm.rank) + num_inner_steps = len(inner_ring) + + dk, dv = None, None + next_dk, next_dv = None, None + next_k, next_v = None, None + dk_comm_buffer, dv_comm_buffer = None, None + + for inner_step in range(num_inner_steps): + if inner_step + 1 != num_inner_steps: + next_k, next_v = inner_kv_comm.send_recv_kv(k, v) + block_causal = (outer_step == 0) and (inner_step == 0) + offset = outer_offset * num_inner_steps + (inner_rank - inner_step) % num_inner_steps + + # Block Mask + step_dq, step_dk, step_dv = block_attn_bwd( + dout, q, k, v, out, + softmax_lse, softmax_scale, + block_mask[offset], + granularity=granularity, + deterministic=False, + causal=block_causal, + ) + + # Update dQ, dK, dV + if inner_step == 0: + # TODO: check if float32 is necessary + dk = step_dk.to(torch.float32) + dv = step_dv.to(torch.float32) + else: + inner_d_kv_comm.wait() + dk_comm_buffer, dv_comm_buffer = dk, dv + dk, dv = next_dk, next_dv + + dk += step_dk + dv += step_dv + dq += step_dq + merge_kv(dk, dv, bar_dk, bar_dv, v_idx, v_cnt, step=offset) + if inner_step + 1 != num_inner_steps: + inner_kv_comm.wait() + k, v = next_k, next_v + + next_dk, next_dv = inner_d_kv_comm.send_recv_kv( + dk, dv, dk_comm_buffer, dv_comm_buffer + ) + + inner_d_kv_comm.wait() + return dq, next_dk, next_dv + + +def minfer_dr_stripe_backward_outer( + process_group: dist.ProcessGroup, + outer_ring: List[int], + inner_ring: List[int], + dout: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + out: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + softmax_lse: torch.Tensor, # [batch_size, num_qo_heads, num_tokens] + layer_idx: int, + softmax_scale: float, + block_mask: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks, num_blocks] + bar_pos: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, 
max_v_size] + bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] + v_idx: torch.Tensor, # [batch_size, num_qo_heads, max_v_size] + v_cnt: torch.Tensor, # [batch_size, num_qo_heads, world_size + 1] + bar_k: torch.Tensor, # [batch_size, max_v_size, num_qo_heads, head_dim] + bar_v: torch.Tensor, # [batch_size, max_v_size, num_qo_heads, head_dim] + granularity: int = 128, +): + outer_kv_comm = RingComm(process_group, False, outer_ring) + outer_d_kv_comm = RingComm(process_group, False, outer_ring) + outer_rank = outer_ring.index(outer_kv_comm.rank) + num_outer_steps = len(outer_ring) + + dq, dk, dv = None, None, None + next_k, next_v = None, None + next_dk, next_dv = None, None + dk_comm_buffer, dv_comm_buffer = None, None + + # Bar Mask + full_bar_cnt = torch.stack([bar_cnt[..., 0], bar_cnt[..., -1]], dim=-1) + dq, bar_dk, bar_dv = bar_attn_bwd( + dout, q, bar_k, bar_v, out, None, None, None, + softmax_lse, softmax_scale, + bar_pos, full_bar_cnt, + granularity=granularity, + deterministic=False, + step=0, + ) + + for outer_step in range(num_outer_steps): + if outer_step + 1 != num_outer_steps: + next_k, next_v = outer_kv_comm.send_recv_kv(k, v) + outer_offset = (outer_rank - outer_step) % num_outer_steps + + dq, step_dk, step_dv = minfer_dr_stripe_backward_inner( + process_group, outer_step, outer_offset, outer_rank, inner_ring, + dout, q, k, v, out, softmax_lse, + layer_idx, softmax_scale, + block_mask, v_idx, v_cnt, + dq, bar_dk, bar_dv, + granularity, + ) + + if outer_step == 0: + # TODO: check if float32 is necessary + dk, dv = step_dk, step_dv + else: + outer_d_kv_comm.wait() + dk_comm_buffer, dv_comm_buffer = dk, dv + dk, dv = next_dk, next_dv + + dk += step_dk + dv += step_dv + + if outer_step + 1 != num_outer_steps: + outer_kv_comm.wait() + k, v = next_k, next_v + + next_dk, next_dv = outer_d_kv_comm.send_recv_kv( + dk, dv, dk_comm_buffer, dv_comm_buffer + ) + + outer_d_kv_comm.wait() + return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) + + + +# --------------------------------------------------------------------- +# Purely Triton-based Implementation +def minfer_dr_stripe_triton_forward_inner( + process_group: dist.ProcessGroup, + outer_step: int, + outer_offset: int, + inner_ring: List[int], + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + out: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + lse: torch.Tensor, # [batch_size, num_qo_heads, num_tokens] + softmax_scale: float, + block_idx: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks, num_blocks] + block_cnt: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks] + bar_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, max_v_size] + bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] + granularity: int = 128, +): + inner_comm = RingComm(process_group, False, inner_ring) + inner_rank = inner_ring.index(inner_comm.rank) + num_inner_steps = len(inner_ring) + + next_k, next_v = None, None + + for inner_step in range(num_inner_steps): + if inner_step + 1 != num_inner_steps: + next_k, next_v = inner_comm.send_recv_kv(k, v) + + block_causal = (outer_step == 0) and (inner_step == 0) + offset = outer_offset * num_inner_steps + (inner_rank - inner_step) % num_inner_steps + + out, lse = block_bar_attn_fwd( + q, k, v, out, lse, softmax_scale, + bar_idx, 
bar_cnt, block_idx[offset], block_cnt[offset], + granularity=granularity, + step=offset, + causal=block_causal, + ) + + if inner_step + 1 != num_inner_steps: + inner_comm.wait() + k, v = next_k, next_v + + return out, lse + + +def minfer_dr_stripe_triton_forward_outer( + process_group: dist.ProcessGroup, + outer_ring: List[int], + inner_ring: List[int], + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + softmax_scale: float, + block_idx: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks, num_blocks] + block_cnt: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks] + bar_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, max_v_size] + bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] + granularity: int = 128, +): + outer_comm = RingComm(process_group, False, outer_ring) + outer_rank = outer_ring.index(outer_comm.rank) + num_outer_steps = len(outer_ring) + + out = None + lse = None + + next_k, next_v = None, None + for outer_step in range(num_outer_steps): + if outer_step + 1 != num_outer_steps: + next_k, next_v = outer_comm.send_recv_kv(k, v) + + outer_offset = (outer_rank - outer_step) % num_outer_steps + out, lse = minfer_dr_stripe_triton_forward_inner( + process_group, outer_step, outer_offset, inner_ring, + q, k, v, out, lse, softmax_scale, + block_idx, block_cnt, bar_idx, bar_cnt, + granularity, + ) + + if outer_step + 1 != num_outer_steps: + outer_comm.wait() + k, v = next_k, next_v + + # out = out.to(q.dtype) + # lse = lse.squeeze(dim=-1).transpose(1, 2) + return out, lse + + +def minfer_dr_stripe_triton_backward_inner( + process_group: dist.ProcessGroup, + outer_step: int, + outer_offset: int, + inner_ring: List[int], + dout: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + out: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + softmax_lse: torch.Tensor, # [batch_size, num_qo_heads, num_tokens] + softmax_scale: float, + block_idx: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks, num_blocks] + block_cnt: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks] + bar_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, max_v_size] + bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] + granularity: int = 128, +): + inner_kv_comm = RingComm(process_group, False, inner_ring) + inner_d_kv_comm = RingComm(process_group, False, inner_ring) + inner_rank = inner_ring.index(inner_kv_comm.rank) + num_inner_steps = len(inner_ring) + + dq, dk, dv = None, None, None + next_dk, next_dv = None, None + next_k, next_v = None, None + dk_comm_buffer, dv_comm_buffer = None, None + + for inner_step in range(num_inner_steps): + if inner_step + 1 != num_inner_steps: + next_k, next_v = inner_kv_comm.send_recv_kv(k, v) + + block_causal = (outer_step == 0) and (inner_step == 0) + offset = outer_offset * num_inner_steps + (inner_rank - inner_step) % num_inner_steps + + dq, step_dk, step_dv = block_bar_attn_bwd( + dout, q, k, v, out, dq, None, None, + softmax_lse, softmax_scale, + bar_idx, bar_cnt, block_idx[offset], block_cnt[offset], + granularity=granularity, + 
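# step=offset tells the kernel which source shard of the pre-built sparse index the
# current K/V block corresponds to, and causal=block_causal is only true on the very
# first (outer_step == 0, inner_step == 0) step, i.e. when this rank attends to its
# own diagonal K/V block.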
deterministic=False, + step=offset, + causal=block_causal, + ) + + # Update dQ, dK, dV + if inner_step == 0: + # TODO: check if float32 is necessary + dk = step_dk.to(torch.float32) + dv = step_dv.to(torch.float32) + else: + inner_d_kv_comm.wait() + dk_comm_buffer, dv_comm_buffer = dk, dv + dk, dv = next_dk, next_dv + dk += step_dk + dv += step_dv + + if inner_step + 1 != num_inner_steps: + inner_kv_comm.wait() + k, v = next_k, next_v + + next_dk, next_dv = inner_d_kv_comm.send_recv_kv( + dk, dv, dk_comm_buffer, dv_comm_buffer + ) + + inner_d_kv_comm.wait() + return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) + + +def minfer_dr_stripe_triton_backward_outer( + process_group: dist.ProcessGroup, + outer_ring: List[int], + inner_ring: List[int], + dout: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + out: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + softmax_lse: torch.Tensor, # [batch_size, num_qo_heads, num_tokens] + softmax_scale: float, + block_idx: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks, num_blocks] + block_cnt: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks] + bar_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, max_v_size] + bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] + granularity: int = 128, +): + outer_kv_comm = RingComm(process_group, False, outer_ring) + outer_d_kv_comm = RingComm(process_group, False, outer_ring) + outer_rank = outer_ring.index(outer_kv_comm.rank) + num_outer_steps = len(outer_ring) + + dq, dk, dv = None, None, None + next_dk, next_dv = None, None + next_k, next_v = None, None + dk_comm_buffer, dv_comm_buffer = None, None + + for outer_step in range(num_outer_steps): + if outer_step + 1 != num_outer_steps: + next_k, next_v = outer_kv_comm.send_recv_kv(k, v) + + outer_offset = (outer_rank - outer_step) % num_outer_steps + step_dq, step_dk, step_dv = minfer_dr_stripe_triton_backward_inner( + process_group, outer_step, outer_offset, inner_ring, + dout, q, k, v, out, softmax_lse, softmax_scale, + block_idx, block_cnt, bar_idx, bar_cnt, granularity, + ) + + if outer_step == 0: + # TODO: check if float32 is necessary + dq = step_dq.to(torch.float32) + dk = step_dk.to(torch.float32) + dv = step_dv.to(torch.float32) + else: + dq += step_dq + outer_d_kv_comm.wait() + dk_comm_buffer, dv_comm_buffer = dk, dv + dk, dv = next_dk, next_dv + dk += step_dk + dv += step_dv + + if outer_step + 1 != num_outer_steps: + outer_kv_comm.wait() + k, v = next_k, next_v + + next_dk, next_dv = outer_d_kv_comm.send_recv_kv( + dk, dv, dk_comm_buffer, dv_comm_buffer + ) + + outer_d_kv_comm.wait() + return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) + +# ----------------------------------------------------------- +# Attention Classes +class MInferDRStripeFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + v_size, + s_size, + layer_idx, + softmax_scale, + granularity, + return_softmax, + group, + ): + batch_size, num_tokens_local, num_qo_heads, head_dim = q.shape + if softmax_scale is None: softmax_scale = head_dim ** (-0.5) + inner_ring = get_inner_ring(group) + outer_ring = get_outer_ring(group) + + # ---------------------------------------------- + # Index Build + block_mask, bar_idx, bar_cnt, 
bar_pos, v_idx, v_cnt = build_index( + q, k, v_size, s_size, num_tokens_local, + granularity=granularity, group=group + ) + + # ---------------------------------------------- + # Shuffle + q = shuffle_striped_input(to_send=q, dim=1, granularity=granularity, process_group=group) + k = shuffle_striped_input(to_send=k, dim=1, granularity=granularity, process_group=group) + v = shuffle_striped_input(to_send=v, dim=1, granularity=granularity, process_group=group) + + # ---------------------------------------------- + # Compute + out, softmax_lse, bar_k, bar_v = minfer_dr_stripe_forward_outer( + group, outer_ring, inner_ring, + q, k, v, layer_idx, + softmax_scale, + block_mask, bar_idx, bar_cnt, v_idx, v_cnt, + granularity, + ) + + # ---------------------------------------------- + # Recover + recovered_out = recover_striped_output(out, dim=1, granularity=granularity, process_group=group) + if return_softmax: + recovered_softmax_lse = recover_striped_output(softmax_lse, dim=2, granularity=granularity, process_group=group) + + # ---------------------------------------------- + # Saving tensors for backward + ctx.save_for_backward(q, k, v, out, softmax_lse, block_mask, bar_pos, bar_cnt, v_idx, v_cnt, bar_k, bar_v) + ctx.softmax_scale = softmax_scale + ctx.granularity = granularity + ctx.group = group + ctx.inner_ring = inner_ring + ctx.outer_ring = outer_ring + ctx.layer_idx = layer_idx + + # ---------------------------------------------- + # Output and Return + if return_softmax: + return (recovered_out, recovered_softmax_lse, None) + return recovered_out + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse, block_mask, bar_pos, bar_cnt, v_idx, v_cnt, bar_k, bar_v = ctx.saved_tensors + inner_ring = ctx.inner_ring + layer_idx = ctx.layer_idx + group = ctx.group + + # ---------------------------------------------- + # Shuffle + dout = shuffle_striped_input(to_send=dout, dim=1, granularity=ctx.granularity, process_group=ctx.group) + + # ---------------------------------------------- + # Compute + dq, dk, dv = minfer_dr_stripe_backward_outer( + ctx.group, ctx.outer_ring, ctx.inner_ring, + dout, q, k, v, out, softmax_lse, + layer_idx, ctx.softmax_scale, + block_mask, bar_pos, bar_cnt, v_idx, v_cnt, bar_k, bar_v, + granularity=ctx.granularity, + ) + + # ---------------------------------------------- + # Recover + dq = recover_striped_output(dq, dim=1, granularity=ctx.granularity, process_group=ctx.group) + dk = recover_striped_output(dk, dim=1, granularity=ctx.granularity, process_group=ctx.group) + dv = recover_striped_output(dv, dim=1, granularity=ctx.granularity, process_group=ctx.group) + + return dq, dk, dv, None, None, None, None, None, None, None + +class MInferDRStripeTritonFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + v_size, + s_size, + layer_idx, + softmax_scale, + granularity, + return_softmax, + group, + ): + batch_size, num_tokens_local, num_qo_heads, head_dim = q.shape + if softmax_scale is None: + softmax_scale = head_dim ** (-0.5) + + # build index TODO: move convert_indices() into the first step + block_mask, bar_idx, bar_cnt, _, _, _ = build_index(q, k, v_size, s_size, num_tokens_local, granularity=granularity, group=group) + block_idx, block_cnt = convert_blockmask(block_mask, block_size_M=granularity, block_size_N=64) + + # TODO: remove shuffle + q = shuffle_striped_input(to_send=q, dim=1, granularity=granularity, process_group=group) + k = shuffle_striped_input(to_send=k, dim=1, granularity=granularity, 
process_group=group) + v = shuffle_striped_input(to_send=v, dim=1, granularity=granularity, process_group=group) + + inner_ring = get_inner_ring(group) + outer_ring = get_outer_ring(group) + out, softmax_lse = minfer_dr_stripe_triton_forward_outer( + group, outer_ring, inner_ring, + q, k, v, softmax_scale, + block_idx, block_cnt, bar_idx, bar_cnt, granularity, + ) + + # this should be out_padded + ctx.save_for_backward(q, k, v, out, softmax_lse, block_idx, block_cnt, bar_idx, bar_cnt) + ctx.softmax_scale = softmax_scale + ctx.granularity = granularity + ctx.group = group + ctx.inner_ring = inner_ring + ctx.outer_ring = outer_ring + ctx.layer_idx = layer_idx + + out = recover_striped_output(out, dim=1, granularity=granularity, process_group=group) + if return_softmax: + softmax_lse = recover_striped_output(softmax_lse, dim=2, granularity=granularity, process_group=group) + return (out, softmax_lse, None) + return out + + @staticmethod + def backward(ctx, dout, *args): + dout = shuffle_striped_input(to_send=dout, dim=1, granularity=ctx.granularity, process_group=ctx.group) + q, k, v, out, softmax_lse, block_idx, block_cnt, bar_idx, bar_cnt = ctx.saved_tensors + + dq, dk, dv = minfer_dr_stripe_triton_backward_outer( + ctx.group, ctx.outer_ring, ctx.inner_ring, + dout, q, k, v, out, softmax_lse, ctx.softmax_scale, + block_idx, block_cnt, bar_idx, bar_cnt, ctx.granularity, + ) + dq = recover_striped_output(dq, dim=1, granularity=ctx.granularity, process_group=ctx.group) + dk = recover_striped_output(dk, dim=1, granularity=ctx.granularity, process_group=ctx.group) + dv = recover_striped_output(dv, dim=1, granularity=ctx.granularity, process_group=ctx.group) + return dq, dk, dv, None, None, None, None, None, None, None + +# --------------------------------------------------------------------- +# Wrapped Attention Functions +def minfer_dr_stripe_func( + q: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] + v_size: List[int], # [num_heads] + s_size: List[int], # [num_heads] + layer_idx: int = 0, + dropout_p: float = 0.0, + softmax_scale: float = None, + granularity: int = 128, + causal: bool = True, + window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window + alibi_slopes: Tuple[int, int] = None, + deterministic: bool = False, + return_attn_probs: bool = False, + group: dist.ProcessGroup = None, +): + assert causal + assert dropout_p == 0 + assert window_size == (-1, -1) + assert alibi_slopes is None + assert not deterministic + + if not use_triton(): + return MInferDRStripeFunc.apply( + q, k, v, + v_size, s_size, + layer_idx, + softmax_scale, + granularity, + return_attn_probs, + group, + ) + else: + print(f"Rank {dist.get_rank()} | minfer_dr_stripe_func | Using Triton implementation for MTraining") + return MInferDRStripeTritonFunc.apply( + q, k, v, + v_size, s_size, + layer_idx, + softmax_scale, + granularity, + return_attn_probs, + group, + ) + + +def minfer_dr_stripe_qkvpacked_func( + qkv: torch.Tensor, # [batch_size, num_tokens, 3, num_heads, head_dim] + *args, **kwargs, +): + return minfer_dr_stripe_func( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + *args, **kwargs + ) + +def minfer_dr_stripe_kvpacked_func( + q: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] + kv: torch.Tensor, # [batch_size, num_tokens, 2, num_heads, head_dim] + *args, **kwargs, +): + return minfer_dr_stripe_func( + q, + kv[:, :, 0], + kv[:, 
:, 1], + *args, **kwargs + ) \ No newline at end of file diff --git a/minference/dist_ops/minfer_striped.py b/minference/dist_ops/minfer_striped.py new file mode 100644 index 0000000..9225d24 --- /dev/null +++ b/minference/dist_ops/minfer_striped.py @@ -0,0 +1,486 @@ +import os +import sys +import torch +import triton +import torch.distributed as dist +from typing import List, Tuple, Dict + +from minference.ops.utils import use_triton +from minference.dist_ops.utils import ( + RingComm, shuffle_striped_input, recover_striped_output, +) +from minference.ops.pit_sparse_flash_attention_v3 import ( + block_bar_attn_fwd, block_attn_bwd, bar_attn_bwd, block_bar_attn_bwd +) +from minference.ops.op_utils.vertical_slash_utils import build_index, convert_blockmask, extract_kv, merge_kv + + + +if torch.version.hip is None: + original_flags = sys.getdlopenflags() + try: + sys.setdlopenflags(os.RTLD_LAZY | os.RTLD_GLOBAL) + import block_sparse_attn_cuda # type: ignore + from block_sparse_attn.block_sparse_attn_interface import convert_blockmask_row_reverse, convert_blockmask_col_reverse # type: ignore + # NOTE: Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_blockmask.h: add head_idx to blockmask_ptr + except ModuleNotFoundError as e: + print(f"[Warning] Failed to import block_sparse_attn_cuda: {e}") + finally: + # Restore original flags for future imports + sys.setdlopenflags(original_flags) + # NOTE: Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_blockmask.h: add head_idx to blockmask_ptr + +# ------------------------------------------------------------------ +# CUDA-based Implementation +def minfer_stripe_forward( + process_group: dist.ProcessGroup, + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + layer_idx: int, + softmax_scale: float, + block_mask: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks, num_blocks] + bar_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, max_v_size] + bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] + v_idx: torch.Tensor, # [batch_size, num_qo_heads, max_v_size] + v_cnt: torch.Tensor, # [batch_size, num_qo_heads, world_size + 1] + granularity: int = 128, +): + comm = RingComm(process_group, zigzag=False) + + out, lse = None, None + block_idx, block_cnt = convert_blockmask(block_mask, block_size_M=granularity, block_size_N=64) + + batch_size, _, num_qo_heads, head_dim = q.shape + max_v_size = v_idx.shape[-1] + bar_k = torch.empty((batch_size, max_v_size, num_qo_heads, head_dim), dtype=q.dtype, device=q.device) + bar_v = torch.empty((batch_size, max_v_size, num_qo_heads, head_dim), dtype=q.dtype, device=q.device) + + for step in range(comm.world_size): + if step + 1 != comm.world_size: + next_k, next_v = comm.send_recv_kv(k, v) + block_causal = step == 0 + offset = (comm.rank - step) % comm.world_size + + out, lse = block_bar_attn_fwd( + q, k, v, out, lse, softmax_scale, + bar_idx, bar_cnt, block_idx[offset], block_cnt[offset], + granularity=granularity, + step=offset, + causal=block_causal, + ) + extract_kv(k, v, bar_k, bar_v, v_idx, v_cnt, step=offset) + + if step + 1 != comm.world_size: + comm.wait() + k, v = next_k, next_v + + out = out.to(q.dtype) + # lse = lse.squeeze(dim=-1).transpose(1, 2) + return out, lse, bar_k, bar_v + +def minfer_stripe_backward( + process_group: dist.ProcessGroup, + dout: torch.Tensor, # [batch_size, 
num_tokens, num_qo_heads, head_dim] + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + out: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + softmax_lse: torch.Tensor, # [batch_size, num_qo_heads, num_tokens] + layer_idx: int, + softmax_scale: float, + block_mask: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks, num_blocks] + bar_pos: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, max_v_size] + bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] + v_idx: torch.Tensor, # [batch_size, num_qo_heads, max_v_size] + v_cnt: torch.Tensor, # [batch_size, num_qo_heads, world_size + 1] + bar_k: torch.Tensor, # [batch_size, max_v_size, num_qo_heads, head_dim] + bar_v: torch.Tensor, # [batch_size, max_v_size, num_qo_heads, head_dim] + granularity: int = 128, +): + kv_comm = RingComm(process_group, zigzag=False) + d_kv_comm = RingComm(process_group, zigzag=False) + + dq, dk, dv = None, None, None + next_k, next_v = None, None + next_dk, next_dv = None, None + dk_comm_buffer, dv_comm_buffer = None, None + block_mask = convert_blockmask_col_reverse(block_mask, causal=True) + + # Bar Mask + full_bar_cnt = torch.stack([bar_cnt[..., 0], bar_cnt[..., -1]], dim=-1) + dq, bar_dk, bar_dv = bar_attn_bwd( + dout, q, bar_k, bar_v, out, None, None, None, + softmax_lse, softmax_scale, + bar_pos, full_bar_cnt, + granularity=granularity, + deterministic=False, + step=0, + ) + + for step in range(kv_comm.world_size): + if step + 1 != kv_comm.world_size: + next_k, next_v = kv_comm.send_recv_kv(k, v) + block_causal = step == 0 + offset = (kv_comm.rank - step) % kv_comm.world_size + + # Block Mask + step_dq, step_dk, step_dv = block_attn_bwd( + dout, q, k, v, out, + softmax_lse, softmax_scale, + block_mask[offset], + granularity=granularity, + deterministic=False, + causal=block_causal, + converted=True, + ) + + # Update dQ, dK, dV + if step == 0: + # TODO: check if float32 is necessary + dk = step_dk.to(torch.float32) + dv = step_dv.to(torch.float32) + else: + d_kv_comm.wait() + dk_comm_buffer, dv_comm_buffer = dk, dv + dk, dv = next_dk, next_dv + dk += step_dk + dv += step_dv + dq += step_dq + merge_kv(dk, dv, bar_dk, bar_dv, v_idx, v_cnt, step=offset) + + if step + 1 != kv_comm.world_size: + kv_comm.wait() + k, v = next_k, next_v + next_dk, next_dv = d_kv_comm.send_recv_kv( + dk, dv, dk_comm_buffer, dv_comm_buffer + ) + + d_kv_comm.wait() + return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) + +# ------------------------------------------------------------------ +# Triton-based Implementation +def minfer_stripe_triton_forward( + process_group: dist.ProcessGroup, + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + layer_idx: int, + softmax_scale: float, + block_idx: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks, num_blocks] + block_cnt: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks] + bar_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, max_v_size] + bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] + granularity: int = 128, +): + comm = RingComm(process_group) + out, lse = None, None + next_k, next_v = None, None + + for step in 
range(comm.world_size): + if step + 1 != comm.world_size: + next_k, next_v = comm.send_recv_kv(k, v) + block_causal = step == 0 + offset = (comm.rank - step) % comm.world_size + + + out, lse = block_bar_attn_fwd( + q, k, v, out, lse, softmax_scale, + bar_idx, bar_cnt, block_idx[offset], block_cnt[offset], + granularity=granularity, + step=offset, + causal=block_causal, + ) + + if step + 1 != comm.world_size: + comm.wait() + k, v = next_k, next_v + + return out, lse + + +def minfer_stripe_triton_backward( + process_group: dist.ProcessGroup, + dout: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + out: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + softmax_lse: torch.Tensor, # [batch_size, num_qo_heads, num_tokens] + layer_idx: int, + softmax_scale: float, + block_idx: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks, num_blocks] + block_cnt: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks] + bar_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, max_v_size] + bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] + granularity: int = 128, +): + kv_comm = RingComm(process_group) + d_kv_comm = RingComm(process_group) + + dq, dk, dv = None, None, None + next_dk, next_dv = None, None + next_k, next_v = None, None + dk_comm_buffer, dv_comm_buffer = None, None + + for step in range(kv_comm.world_size): + if step + 1 != kv_comm.world_size: + next_k, next_v = kv_comm.send_recv_kv(k, v) + block_causal = step == 0 + offset = (kv_comm.rank - step) % kv_comm.world_size + + dq, step_dk, step_dv = block_bar_attn_bwd( + dout, q, k, v, out, dq, None, None, + softmax_lse, softmax_scale, + bar_idx, bar_cnt, block_idx[offset], block_cnt[offset], + granularity=granularity, + deterministic=False, + step=offset, + causal=block_causal, + ) + + # Update dQ, dK, dV + if step == 0: + dk = step_dk + dv = step_dv + else: + d_kv_comm.wait() + + dk_comm_buffer, dv_comm_buffer = dk, dv + dk, dv = next_dk, next_dv + dk += step_dk + dv += step_dv + + + if step + 1 != kv_comm.world_size: + kv_comm.wait() + k, v = next_k, next_v + next_dk, next_dv = d_kv_comm.send_recv_kv( + dk, dv, dk_comm_buffer, dv_comm_buffer + ) + + d_kv_comm.wait() + return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) + +# ------------------------------------------------------------------ +# Attention Classes +class MInferStripeFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + v_size, + s_size, + layer_idx, + softmax_scale, + granularity, + return_softmax, + group, + ): + if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) + batch_size, num_tokens_local, num_qo_heads, head_dim = q.shape + + # Indexing + block_mask, bar_idx, bar_cnt, bar_pos, v_idx, v_cnt = build_index( + q, k, v_size, s_size, num_tokens_local, + granularity=granularity, group=group + ) + + # Shuffle + q = shuffle_striped_input(to_send=q, dim=1, granularity=granularity, process_group=group) + k = shuffle_striped_input(to_send=k, dim=1, granularity=granularity, process_group=group) + v = shuffle_striped_input(to_send=v, dim=1, granularity=granularity, process_group=group) + + # Compute + out, softmax_lse, bar_k, bar_v = minfer_stripe_forward( + group, q, k, v, + layer_idx, softmax_scale, + block_mask, bar_idx, 
bar_cnt, v_idx, v_cnt, + granularity=granularity, + ) + + # Saving tensors for backward + ctx.save_for_backward(q, k, v, out, softmax_lse, block_mask, bar_pos, bar_cnt, v_idx, v_cnt, bar_k, bar_v) + ctx.softmax_scale = softmax_scale + ctx.granularity = granularity + ctx.group = group + ctx.layer_idx = layer_idx + + # Recover outputs + out = recover_striped_output(out, dim=1, granularity=granularity, process_group=group) + if return_softmax: + softmax_lse = recover_striped_output(softmax_lse, dim=2, granularity=granularity, process_group=group) + + # Output and Return + if return_softmax: + return (out, softmax_lse, None) + return out + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse, block_mask, bar_pos, bar_cnt, v_idx, v_cnt, bar_k, bar_v = ctx.saved_tensors + softmax_scale = ctx.softmax_scale + granularity = ctx.granularity + layer_idx = ctx.layer_idx + group = ctx.group + + # Shuffle + dout = shuffle_striped_input(to_send=dout, dim=1, granularity=granularity, process_group=group) + + # Compute + dq, dk, dv = minfer_stripe_backward( + group, dout, q, k, v, out, softmax_lse, + layer_idx, softmax_scale, + block_mask, bar_pos, bar_cnt, v_idx, v_cnt, bar_k, bar_v, + granularity=granularity, + ) + + # Recover + dq = recover_striped_output(dq, dim=1, granularity=granularity, process_group=group) + dk = recover_striped_output(dk, dim=1, granularity=granularity, process_group=group) + dv = recover_striped_output(dv, dim=1, granularity=granularity, process_group=group) + return dq, dk, dv, None, None, None, None, None, None, None + +class MInferStripeTritonFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + v_size, + s_size, + layer_idx, + softmax_scale, + granularity, + return_softmax, + group, + ): + batch_size, num_tokens_local, num_qo_heads, head_dim = q.shape + if softmax_scale is None: + softmax_scale = head_dim ** (-0.5) + + # built block_idx: [world_size, batch_size, num_qo_heads, num_blocks_local, num_blocks_local] + block_mask, bar_idx, bar_cnt, _, _, _ = build_index(q, k, v_size, s_size, num_tokens_local, granularity=granularity, group=group) + block_idx, block_cnt = convert_blockmask(block_mask, block_size_M=granularity, block_size_N=64) + + q = shuffle_striped_input(to_send=q, dim=1, granularity=granularity, process_group=group) + k = shuffle_striped_input(to_send=k, dim=1, granularity=granularity, process_group=group) + v = shuffle_striped_input(to_send=v, dim=1, granularity=granularity, process_group=group) + + # slash attn + out, softmax_lse = minfer_stripe_triton_forward( + group, q, k, v, + layer_idx, softmax_scale, + block_idx, block_cnt, bar_idx, bar_cnt, + granularity=granularity, + ) + + # this should be out_padded + ctx.save_for_backward(q, k, v, out, softmax_lse, block_idx, block_cnt, bar_idx, bar_cnt) + ctx.softmax_scale = softmax_scale + ctx.granularity = granularity + ctx.group = group + ctx.layer_idx = layer_idx + + out = recover_striped_output(out, dim=1, granularity=granularity, process_group=group) + if return_softmax: + softmax_lse = recover_striped_output(softmax_lse, dim=2, granularity=granularity, process_group=group) + return (out, softmax_lse, None) + return out + + @staticmethod + def backward(ctx, dout, *args): + layer_idx = ctx.layer_idx + dout = shuffle_striped_input(to_send=dout, dim=1, granularity=ctx.granularity, process_group=ctx.group) + q, k, v, out, softmax_lse, block_idx, block_cnt, bar_idx, bar_cnt = ctx.saved_tensors + + dq, dk, dv = minfer_stripe_triton_backward( + ctx.group, dout, 
q, k, v, out, softmax_lse, + layer_idx, ctx.softmax_scale, + block_idx, block_cnt, bar_idx, bar_cnt, + granularity=ctx.granularity, + ) + + dq = recover_striped_output(dq, dim=1, granularity=ctx.granularity, process_group=ctx.group) + dk = recover_striped_output(dk, dim=1, granularity=ctx.granularity, process_group=ctx.group) + dv = recover_striped_output(dv, dim=1, granularity=ctx.granularity, process_group=ctx.group) + + return dq, dk, dv, None, None, None, None, None, None, None + + +# ------------------------------------------------------------------ +# Wrapped Attention Functions +# ------------------ +# CUDA-based +def minfer_stripe_func( # the one used for nnscaler training + q: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] + v_size: List[int], # [num_heads] + s_size: List[int], # [num_heads] + layer_idx: int = 0, + dropout_p: float = 0.0, + softmax_scale: float = None, + granularity: int = 128, + causal: bool = True, + window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window + alibi_slopes: Tuple[int, int] = None, + deterministic: bool = False, + return_attn_probs: bool = False, + group: dist.ProcessGroup = None, +) -> torch.Tensor: + assert causal + assert dropout_p == 0 + assert window_size == (-1, -1) + assert alibi_slopes is None + assert not deterministic + + if not use_triton(): + return MInferStripeFunc.apply( + q, k, v, + v_size, s_size, + layer_idx, + softmax_scale, + granularity, + return_attn_probs, + group, + ) + else: + print(f"Rank {dist.get_rank()} | minfer_stripe_func | using Triton implementation for MTraining w. Striped Ring Attention") + return MInferStripeTritonFunc.apply( + q, k, v, + v_size, s_size, + layer_idx, + softmax_scale, + granularity, + return_attn_probs, + group, + ) + +def minfer_stripe_qkvpacked_func( + qkv: torch.Tensor, # [batch_size, num_tokens, 3, num_heads, head_dim] + *args, **kwargs +): + return minfer_stripe_func( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + *args, **kwargs + ) + + +def minfer_stripe_kvpacked_func( + q: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] + kv: torch.Tensor, # [batch_size, num_tokens, 2, num_heads, head_dim] + *args, **kwargs +): + return minfer_stripe_func( + q, + kv[:, :, 0], + kv[:, :, 1], + *args, **kwargs + ) diff --git a/minference/dist_ops/minfer_zigzag.py b/minference/dist_ops/minfer_zigzag.py new file mode 100644 index 0000000..e8b1d79 --- /dev/null +++ b/minference/dist_ops/minfer_zigzag.py @@ -0,0 +1,329 @@ +import os +import torch +import triton +import torch.distributed as dist +from typing import List, Tuple, Dict + +from .utils import ( + RingComm, shuffle_zigzag_input, recover_zigzag_output, +) +from minference.ops.op_utils.vertical_slash_utils import build_index, convert_blockmask +from minference.ops.pit_sparse_flash_attention_v3 import block_bar_attn_fwd, block_attn_bwd, bar_attn_bwd + +def minfer_zigzag_forward( + process_group: dist.ProcessGroup, + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + layer_idx: int, + softmax_scale: float, + block_mask: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks, num_blocks] + bar_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, max_v_size] + bar_cnt: torch.Tensor, # [batch_size, 
num_qo_heads, num_blocks, world_size + 1] + granularity: int = 128, +): + comm = RingComm(process_group, zigzag=True) + ring_list = comm.ring_list + ring_index = ring_list.index(comm.rank) + + out, lse = None, None + block_idx, block_cnt = convert_blockmask(block_mask, block_size_M=granularity, block_size_N=64) + + for step in range(comm.world_size): + if step + 1 != comm.world_size: + next_k, next_v = comm.send_recv_kv(k, v) + block_causal = step == 0 + offset = (ring_index - step) % comm.world_size + + # ---------------------------------------------- + out, lse = block_bar_attn_fwd( + q, k, v, out, lse, softmax_scale, + bar_idx, bar_cnt, block_idx[offset], block_cnt[offset], + granularity=granularity, + step=offset, + causal=block_causal, + ) + + if step + 1 != comm.world_size: + comm.wait() + k, v = next_k, next_v + + out = out.to(q.dtype) + return out, lse + + +def minfer_zigzag_backward( + process_group: dist.ProcessGroup, + dout: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + out: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + softmax_lse: torch.Tensor, # [batch_size, num_qo_heads, num_tokens] + layer_idx: int, + softmax_scale: float, + block_mask: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks, num_blocks] + bar_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, max_v_size] + bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] + granularity: int = 128, +): + kv_comm = RingComm(process_group, zigzag=True) + d_kv_comm = RingComm(process_group, zigzag=True) + ring_list = kv_comm.ring_list + ring_index = ring_list.index(kv_comm.rank) + + dq, dk, dv = None, None, None + next_dk, next_dv = None, None + next_k, next_v = None, None + dk_comm_buffer, dv_comm_buffer = None, None + + for step in range(kv_comm.world_size): + if step + 1 != kv_comm.world_size: + next_k, next_v = kv_comm.send_recv_kv(k, v) + block_causal = step == 0 + offset = (ring_index - step) % kv_comm.world_size + + # ---------------------------------------------- + # Block Mask + step_dq, step_dk, step_dv = block_attn_bwd( + dout, q, k, v, out, + softmax_lse, softmax_scale, + block_mask[offset], + granularity=granularity, + deterministic=False, + causal=block_causal, + ) + + # ---------------------------------------------- + # Bar Mask + step_dq, step_dk, step_dv = bar_attn_bwd( + dout, q, k, v, out, step_dq, step_dk, step_dv, + softmax_lse, softmax_scale, + bar_idx, bar_cnt, + granularity=granularity, + deterministic=False, + step=offset, + ) + + # ---------------------------------------------- + # Update dQ, dK, dV + if step == 0: + # TODO: check if float32 is necessary + dq = step_dq.to(torch.float32) + dk = step_dk.to(torch.float32) + dv = step_dv.to(torch.float32) + else: + d_kv_comm.wait() + dk_comm_buffer, dv_comm_buffer = dk, dv + dk, dv = next_dk, next_dv + + dq += step_dq + dk += step_dk + dv += step_dv + + if step + 1 != kv_comm.world_size: + kv_comm.wait() + k, v = next_k, next_v + next_dk, next_dv = d_kv_comm.send_recv_kv( + dk, dv, dk_comm_buffer, dv_comm_buffer + ) + + d_kv_comm.wait() + return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) + + +class MInferZigzagAttnFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + v_size, + s_size, + layer_idx, + softmax_scale, 
+ granularity, + return_softmax, + group, + ): + if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) + batch_size, num_tokens_local, num_qo_heads, head_dim = q.shape + + # ------------------------------------------------------------------ + # Index Build + block_mask, bar_idx, bar_cnt, bar_pos, v_idx, v_cnt = build_index( + q, k, v_size, s_size, num_tokens_local, + stripe_transform=False, + zigzag_transform=True, + granularity=granularity, group=group + ) + + # ---------------------------------------------- + # Shuffle + q = shuffle_zigzag_input(to_send=q, dim=1, process_group=group) + k = shuffle_zigzag_input(to_send=k, dim=1, process_group=group) + v = shuffle_zigzag_input(to_send=v, dim=1, process_group=group) + + # ---------------------------------------------- + # Compute + out, softmax_lse = minfer_zigzag_forward( + group, q, k, v, + layer_idx, softmax_scale, + block_mask, bar_idx, bar_cnt, + granularity=granularity, + ) + + # ---------------------------------------------- + # Recover outputs + recovered_out = recover_zigzag_output(out, dim=1, process_group=group) + if return_softmax: + recovered_softmax_lse = recover_zigzag_output(softmax_lse, dim=2, process_group=group) + + # ---------------------------------------------- + # Saving tensors for backward + ctx.save_for_backward(q, k, v, out, softmax_lse, block_mask, bar_idx, bar_cnt) + ctx.softmax_scale = softmax_scale + ctx.granularity = granularity + ctx.group = group + ctx.layer_idx = layer_idx + + # Output and Return + if return_softmax: + return (recovered_out, recovered_softmax_lse, None) + return recovered_out + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse, block_mask, bar_idx, bar_cnt = ctx.saved_tensors + softmax_scale = ctx.softmax_scale + granularity = ctx.granularity + layer_idx = ctx.layer_idx + group = ctx.group + + # ---------------------------------------------- + # Shuffle + dout = shuffle_zigzag_input(to_send=dout, dim=1, process_group=group) + + # ---------------------------------------------- + # Compute + dq, dk, dv = minfer_zigzag_backward( + group, dout, q, k, v, out, softmax_lse, + layer_idx, softmax_scale, + block_mask, bar_idx, bar_cnt, + granularity=granularity, + ) + + # ---------------------------------------------- + # Recover + dq = recover_zigzag_output(dq, dim=1, process_group=group) + dk = recover_zigzag_output(dk, dim=1, process_group=group) + dv = recover_zigzag_output(dv, dim=1, process_group=group) + + return dq, dk, dv, None, None, None, None, None, None, None + + +def minfer_zigzag_qkvpacked_func( + qkv: torch.Tensor, # [batch_size, num_tokens, 3, num_heads, head_dim] + v_size: List[int], # [num_heads] + s_size: List[int], # [num_heads] + layer_idx: int = 0, + dropout_p: float = 0.0, + softmax_scale: float = None, + granularity: int = 128, + causal: bool = True, + window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window + alibi_slopes: Tuple[int, int] = None, + deterministic: bool = False, + return_attn_probs: bool = False, + group: dist.ProcessGroup = None, +): + assert causal + assert dropout_p == 0 + assert window_size == (-1, -1) + assert alibi_slopes is None + assert not deterministic + return MInferZigzagAttnFunc.apply( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + v_size, + s_size, + layer_idx, + softmax_scale, + granularity, + return_attn_probs, + group, + ) + + +def minfer_zigzag_kvpacked_func( + q: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] + kv: torch.Tensor, # [batch_size, num_tokens, 2, 
num_heads, head_dim] + v_size: List[int], # [num_heads] + s_size: List[int], # [num_heads] + layer_idx: int = 0, + dropout_p: float = 0.0, + softmax_scale: float = None, + granularity: int = 128, + causal: bool = True, + window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window + alibi_slopes: Tuple[int, int] = None, + deterministic: bool = False, + return_attn_probs: bool = False, + group: dist.ProcessGroup = None, +): + assert causal + assert dropout_p == 0 + assert window_size == (-1, -1) + assert alibi_slopes is None + assert not deterministic + return MInferZigzagAttnFunc.apply( + q, + kv[:, :, 0], + kv[:, :, 1], + v_size, + s_size, + layer_idx, + softmax_scale, + granularity, + return_attn_probs, + group, + ) + + +def minfer_zigzag_func( # the one used for nnscaler training + q: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] + v_size: List[int], # [num_heads] + s_size: List[int], # [num_heads] + layer_idx: int = 0, + dropout_p: float = 0.0, + softmax_scale: float = None, + granularity: int = 128, + causal: bool = True, + window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window + alibi_slopes: Tuple[int, int] = None, + deterministic: bool = False, + return_attn_probs: bool = False, + group: dist.ProcessGroup = None, +) -> torch.Tensor: + assert causal + assert dropout_p == 0 + assert window_size == (-1, -1) + assert alibi_slopes is None + assert not deterministic + + return MInferZigzagAttnFunc.apply( + q, + k, + v, + v_size, + s_size, + layer_idx, + softmax_scale, + granularity, + return_attn_probs, + group, + ) diff --git a/minference/dist_ops/moba_zigzag.py b/minference/dist_ops/moba_zigzag.py new file mode 100644 index 0000000..3f607e0 --- /dev/null +++ b/minference/dist_ops/moba_zigzag.py @@ -0,0 +1,1114 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +# Credits: This logger implementation is inspired by project https://github.com/zhuzilin/ring-flash-attention +import os +import torch +import torch.distributed as dist + +from einops import rearrange +from typing import List, Tuple, Dict + +from flash_attn.flash_attn_interface import ( + _flash_attn_varlen_forward, + _flash_attn_varlen_backward, +) + +from .utils import ( + RingComm, update_out_and_lse, + recover_zigzag_output, get_default_args, +) +from minference.ops.op_utils.moba_utils import ( + shuffle_input_all, shuffle_input_only, compute_moba_gate, + tensor_4d_to_3d +) + + +def moba_zigzag_attn_fwd_step( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, # [S, H, D] + step: int, + causal: bool, + # q_seq_offsets: torch.Tensor, + num_q_blocks: int, + k_seq_offsets: torch.Tensor, + + gate_mask: torch.Tensor, # [num_filtered_chunk, num_head, seq_len] + cu_chunk: torch.Tensor, + filtered_chunk_indices: torch.Tensor, + num_filtered_chunk: int, + chunk_to_batch: torch.Tensor, + moba_chunk_size: int, + moba_topk: int, + + softmax_scale, + dropout_p=0, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +) -> Tuple[torch.Tensor, torch.Tensor, int]: + _, _, seq_len = gate_mask.shape + q_block_seq_len, num_head, head_dim = q.shape + k_block_seq_len, k_num_head, _ = k.shape + if num_head > k_num_head: + k = torch.repeat_interleave(k, num_head // k_num_head, dim=1) + v = torch.repeat_interleave(v, num_head // k_num_head, dim=1) + + block_seq_len = q_block_seq_len // num_q_blocks + + # assumption: block_seq_len is divisible by moba_chunk_size + assert (block_seq_len % moba_chunk_size == 0), "block_seq_len should be divisible by moba_chunk_size" + + kv = torch.stack((k, v), dim=1) + k_seq_offset_list = [k_seq_offsets[i].detach().cpu().item() for i in range(len(k_seq_offsets))] + filtered_kv_indices = torch.arange( + 0, min(k_seq_offset_list[0] + block_seq_len, num_filtered_chunk * moba_chunk_size) - k_seq_offset_list[0], + device=k.device, dtype=torch.int32 + ) + kv_chunk_indices = torch.arange( + k_seq_offset_list[0], min(k_seq_offset_list[0] + block_seq_len, num_filtered_chunk * moba_chunk_size), + moba_chunk_size, device=k.device, dtype=torch.int32 + ) + if len(k_seq_offset_list) > 1: + filtered_kv_indices = torch.cat([ + filtered_kv_indices, + torch.arange( + block_seq_len, + min(k_seq_offset_list[1] + block_seq_len, num_filtered_chunk * moba_chunk_size) - k_seq_offset_list[1] + block_seq_len, + device=k.device, dtype=torch.int32 + ) + ]) + kv_chunk_indices = torch.cat([ + kv_chunk_indices, + torch.arange( + k_seq_offset_list[1], + min(k_seq_offset_list[1] + block_seq_len, num_filtered_chunk * moba_chunk_size), + moba_chunk_size, + device=k.device, dtype=torch.int32 + ) + ]) + filtered_kv = kv.index_select(0, filtered_kv_indices) + kv_chunk_indices = kv_chunk_indices // moba_chunk_size + num_filtered_kv_chunks = len(kv_chunk_indices) + + q_indices = torch.arange( + 0 if num_q_blocks == 2 else block_seq_len, 2 * block_seq_len, + device=q.device, dtype=torch.int32 + ) + + # varlen trick: combining all q index that needs moba attn + # the result will be like [ C0H0 ][ C0H1 ][ C0H2 ][ ... 
][ CnHm ] + gate_mask_q = gate_mask.index_select(0, kv_chunk_indices) + gate_mask_q = gate_mask_q.index_select(2, q_indices) # we need to know which part(s) of the two query blocks should be activated + + moba_q_indices = gate_mask_q.reshape(gate_mask_q.shape[0], -1).nonzero(as_tuple=True)[-1] + moba_seqlen_q = gate_mask_q.sum(dim=-1).flatten() + + # ----------------------------------------------------------- + # select all q that needs moba attn based on the moba_q_indices + moba_q = rearrange(q, "s h d -> ( h s ) d") + moba_q = moba_q.index_select(0, moba_q_indices) # [ selected_HS, D ] + moba_q = moba_q.unsqueeze(1) + + # moba_q_sh_indices represents the position in the origin q tensor of each q token inside moba_q + moba_q_sh_indices = moba_q_indices % q_block_seq_len * num_head + moba_q_indices // q_block_seq_len + + """ prepare moba kv """ + # Since moba_q is organized as HS * N, we need to reorganize kv to adapt to q + + # cut off zero experts + q_zero_mask = moba_seqlen_q == 0 + valid_expert_mask = ~q_zero_mask + zero_expert_count = q_zero_mask.sum() + # only keep the kv that has q select > 0 + if zero_expert_count > 0: + moba_seqlen_q = moba_seqlen_q[valid_expert_mask] + + # moba cu_seqlen for flash attn + moba_cu_seqlen_q = torch.cat( + ( + torch.tensor([0], device=q.device, dtype=moba_seqlen_q.dtype), + moba_seqlen_q.cumsum(dim=0), + ), + dim=0, + ).to(torch.int32) + + # ----------------------------------------------------------------------------------- + # here `x` only stands for a dimension (stack dimension for KV) + moba_kv = rearrange(filtered_kv, "s x h d -> h s x d") # [H, K_S, 2, D ] + + moba_kv = moba_kv.split(moba_chunk_size, dim=1) # tuple of (num_selected_chunks) elements with shape [H, chunk_size, 2, D] + moba_kv = torch.cat(moba_kv, dim=0) # [H x num_selected_chunks, chunk_size, 2, D ] after split + + # The transformation is aimed for masking out by valid_expert_mask where the mask selects elements along (H x num_selected_chunks) dimension + if zero_expert_count > 0: + assert valid_expert_mask.sum() == moba_kv.shape[0] - zero_expert_count + moba_kv = moba_kv[ + valid_expert_mask + ] # cut off zero Q expert from kv , or the grad may be nan + + moba_kv = moba_kv.flatten(start_dim=0, end_dim=1).unsqueeze(2) # [H x num_selected_chunks x chunk_size, 2, 1, D] + moba_cu_seqlen_kv = ( + torch.arange( + 0, num_filtered_kv_chunks * num_head + 1 - zero_expert_count, + dtype=torch.int32, device=q.device, + ) * moba_chunk_size + ) + + # Shape check + assert ( + moba_cu_seqlen_kv.shape == moba_cu_seqlen_q.shape + ), f"moba_cu_seqlen_kv.shape != moba_cu_seqlen_q.shape {moba_cu_seqlen_kv.shape} != {moba_cu_seqlen_q.shape}" + + softmax_scale = softmax_scale = head_dim ** (-0.5) + + self_attn_cu_seqlen = [0] + [moba_chunk_size] * (q_block_seq_len // moba_chunk_size) + if q_block_seq_len % moba_chunk_size != 0: + self_attn_cu_seqlen.append(q_block_seq_len % moba_chunk_size) + self_attn_cu_seqlen = torch.tensor(self_attn_cu_seqlen, device=q.device, dtype=torch.int32) + self_attn_cu_seqlen = self_attn_cu_seqlen.cumsum(dim=0, dtype=torch.int32) + + # ----------------------------------------------------------------------------------- + # self attn + if causal: + # out, softmax_lse, S_dmask, rng_state + self_attn_out_sh, self_attn_lse_hs, _, _ = ( + _flash_attn_varlen_forward( + q=q, k=k, v=v, + cu_seqlens_q=self_attn_cu_seqlen, + cu_seqlens_k=self_attn_cu_seqlen, + max_seqlen_q=q_block_seq_len, + max_seqlen_k=k_block_seq_len, + softmax_scale=softmax_scale, + causal=True, + dropout_p=0.0, 
+ ) + ) + else: + # self_attn_out_sh, self_attn_lse_hs = None, None + self_attn_out_sh = torch.zeros( + (q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32 + ) + self_attn_lse_hs = torch.zeros((num_head, q_block_seq_len), device=q.device, dtype=torch.float32) + (-float('inf')) + + + # moba attn + # moba_attn_lse_hs - [1, num_nonzero_elems] + if moba_q.shape[0] > 0: + # out, softmax_lse, S_dmask, rng_state + moba_attn_out, moba_attn_lse_hs, _, _ = _flash_attn_varlen_forward( + q=moba_q, + k=moba_kv[:, 0], + v=moba_kv[:, 1], + cu_seqlens_q=moba_cu_seqlen_q, + cu_seqlens_k=moba_cu_seqlen_kv, + max_seqlen_q=q_block_seq_len, + max_seqlen_k=moba_chunk_size, + softmax_scale=softmax_scale, + causal=False, + dropout_p=0.0, + ) + else: + moba_attn_lse_hs = torch.zeros((1, moba_q.shape[0]), device=q.device, dtype=torch.float32) + (-float('inf')) + + # ----------------------------------------------------------------------------------- + # If no queries need to be computed with the current KV chunk and no causal attention is needed, return None to skip the output update + if not causal and moba_q.shape[0] == 0: + return None, None + + # ----------------------------------------------------------------------------------- + # Processing output and lse + # output buffer [S, H, D], same shape as q + output = torch.zeros( + (q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32 + ) + # flatten vS & H for index ops + output_2d = output.view(-1, q.shape[2]) + + # -------------------------------------------------- + moba_attn_lse: torch.Tensor = moba_attn_lse_hs.t().contiguous() # [ num_nonzero_elems, 1 ] + self_attn_lse_sh = self_attn_lse_hs.t().contiguous() # [q_S, H] + + # calc mixed_lse + # minus max lse to avoid exp explosion + max_lse_1d = self_attn_lse_sh.view(-1) # [ vS ] + max_lse_1d = max_lse_1d.index_reduce( + 0, moba_q_sh_indices, moba_attn_lse.view(-1), "amax" + ) + self_attn_lse_sh = self_attn_lse_sh - max_lse_1d.view_as(self_attn_lse_sh) + moba_attn_lse = ( + moba_attn_lse.view(-1) + .sub(max_lse_1d.index_select(0, moba_q_sh_indices)) + .reshape_as(moba_attn_lse) + ) + + # -------------------------------------------------- + # Build mixed attn lse + mixed_attn_se_sh = self_attn_lse_sh.exp() if causal else torch.zeros_like(self_attn_lse_sh) + moba_attn_se = moba_attn_lse.exp() if moba_q.shape[0] > 0 else torch.zeros_like(moba_attn_lse) + + # index_add_: converting elements from 1D tensor (num_nonzero_elems) to matrices (HS) + # Now, mixed_attn_se_sh is the sum of LSE of self attn and LSE of moba attn (including multiple LSEs corresponding to the same q token but in different HS positions) + mixed_attn_se_sh.view(-1).index_add_( + 0, moba_q_sh_indices, moba_attn_se.view(-1) + ) + mixed_attn_lse_sh = mixed_attn_se_sh.log() + + # ---------------------------------------------------- + # Compute factor of self-attention and add to output + if causal: + factor = (self_attn_lse_sh - mixed_attn_lse_sh).exp() # [ vS, H ] + self_attn_out_sh = self_attn_out_sh * factor.unsqueeze(-1) + output_2d += self_attn_out_sh.reshape_as(output_2d) + + # add moba output + # ---------------------------------------------------- + # Compute factor of moba-attention and add to output + if moba_q.shape[0] > 0: + mixed_attn_lse = ( + mixed_attn_lse_sh.view(-1) + .index_select(0, moba_q_sh_indices) + .view_as(moba_attn_lse) + ) + factor = (moba_attn_lse - mixed_attn_lse).exp() # [ vS, H ] + moba_attn_out = moba_attn_out * factor.unsqueeze(-1) + + raw_attn_out = moba_attn_out.view(-1, 
moba_attn_out.shape[-1]) + output_2d.index_add_(0, moba_q_sh_indices, raw_attn_out.to(output_2d.dtype)) + + output = output.to(q.dtype) + + # add back max lse + mixed_attn_lse_sh = mixed_attn_lse_sh + max_lse_1d.view_as(mixed_attn_se_sh) + return output, mixed_attn_lse_sh.t() + +def moba_zigzag_attn_fwd( + process_group, + q: torch.Tensor, # [S, H, D] + k: torch.Tensor, + v: torch.Tensor, + seq_offsets: torch.Tensor, # sequence offsets for Q + layer_idx: int, + + gate_mask, + cu_chunk, + filtered_chunk_indices, + num_filtered_chunk, + chunk_to_batch, + moba_chunk_size, + moba_topk, + + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + assert causal == True, "zigzag ring is meaningless for causal=False" + comm = RingComm(process_group, zigzag=True) + + block_seq_len = q.shape[0] // 2 + seq_len, num_q_heads, head_dim = q.shape + + out, lse = None, None + next_k, next_v = None, None + + kv_seq_offsets = torch.clone(seq_offsets) + next_kv_seq_offsets = None + + def fwd_step( + q_, k_, v_, step_, causal_, + # q_seq_offsets, + num_q_blocks, + k_seq_offsets + ): + return moba_zigzag_attn_fwd_step( + q_, k_, v_, + step_, + causal_, + num_q_blocks, + k_seq_offsets, + + gate_mask, + cu_chunk, + filtered_chunk_indices, + num_filtered_chunk, + chunk_to_batch, + moba_chunk_size, + moba_topk, + + softmax_scale, + dropout_p, + window_size, + alibi_slopes, + deterministic, + ) + + for step in range(comm.world_size): + if step + 1 != comm.world_size: + # when step < N-1, do the ring-communication to get KV to be used in the next round + next_k, next_v, next_kv_seq_offsets = comm.send_recv_kv_offsets(k, v, kv_seq_offsets) + + if step == 0: + # Do softmax(QK^T / sqrt(d_k))V on the currently hold K and V + # and record the output and the LSE + block_out, block_lse = fwd_step( + q, k, v, step, causal_=True, + num_q_blocks=2, + k_seq_offsets=kv_seq_offsets, + ) + + out, lse = update_out_and_lse( + out, lse, block_out, block_lse, + use_triton_kernel=False, + ) + elif step <= comm.revert_rank: + k0 = k[:block_seq_len] + v0 = v[:block_seq_len] + block_out, block_lse = fwd_step( + q, k0, v0, step, causal_=False, + num_q_blocks=2, + k_seq_offsets=kv_seq_offsets[0:1], + ) + + if block_out is not None: + out, lse = update_out_and_lse( + out, lse, block_out, block_lse, + use_triton_kernel=False, + ) + else: + q1 = q[block_seq_len:] + block_out, block_lse = fwd_step( + q1, k, v, step, causal_=False, + num_q_blocks=1, + k_seq_offsets=kv_seq_offsets, + ) + + if block_out is not None: + out, lse = update_out_and_lse( + out, lse, + block_out, + block_lse, + slice_=(slice(block_seq_len, None)), + use_triton_kernel=False, + ) + + if step + 1 != comm.world_size: + comm.wait() + k, v, kv_seq_offsets = next_k, next_v, next_kv_seq_offsets + + out = out.to(q.dtype) # [S, H, D] + lse = lse.squeeze(dim=-1).transpose(0, 1) # [H, S] + return out, lse + +def moba_zigzag_attn_bwd_step( + step: int, + + dout, # [blk_S, H, D] + out, # [blk_S, H, D] + causal: bool, + + q: torch.Tensor, # [blk_S, H, D] + k: torch.Tensor, # [blk_S, H, D] + v: torch.Tensor, # [blk_S, H, D] + + softmax_lse: torch.Tensor, # [H, blk_S] + num_q_blocks: int, + k_seq_offsets: torch.Tensor, + layer_idx: int, + + gate_mask, + cu_chunk, + filtered_chunk_indices, + num_filtered_chunk, + chunk_to_batch: torch.Tensor, + moba_chunk_size: int, + moba_topk: int, + + softmax_scale, + dropout_p=0, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +) -> Tuple[torch.Tensor, torch.Tensor, 
int]: + _, _, seq_len = gate_mask.shape + q_block_seq_len, num_head, head_dim = q.shape + k_block_seq_len, k_num_head, _ = k.shape + if num_head > k_num_head: + k = torch.repeat_interleave(k, num_head // k_num_head, dim=1) + v = torch.repeat_interleave(v, num_head // k_num_head, dim=1) + + block_seq_len = q_block_seq_len // num_q_blocks + + # assumption: block_seq_len is divisible by moba_chunk_size + assert (block_seq_len % moba_chunk_size == 0), "block_seq_len should be divisible by moba_chunk_size" + + # ----------------------------------------------------------------------------------- + dq = torch.zeros_like(q, dtype=q.dtype) + dk = torch.zeros_like(k, dtype=k.dtype) + dv = torch.zeros_like(v, dtype=v.dtype) + + kv = torch.stack((k, v), dim=1) + dkv = torch.stack((dk, dv), dim=1) + # ----------------------------------------------------------------------------------- + + + k_seq_offset_list = [k_seq_offsets[i].detach().cpu().item() for i in range(len(k_seq_offsets))] + filtered_kv_indices = torch.arange( + 0, min(k_seq_offset_list[0] + block_seq_len, num_filtered_chunk * moba_chunk_size) - k_seq_offset_list[0], + device=k.device, dtype=torch.int32 + ) + kv_chunk_indices = torch.arange( + k_seq_offset_list[0], min(k_seq_offset_list[0] + block_seq_len, num_filtered_chunk * moba_chunk_size), + moba_chunk_size, device=k.device, dtype=torch.int32 + ) + if len(k_seq_offset_list) > 1: + filtered_kv_indices = torch.cat([ + filtered_kv_indices, + torch.arange( + block_seq_len, + min(k_seq_offset_list[1] + block_seq_len, num_filtered_chunk * moba_chunk_size) - k_seq_offset_list[1] + block_seq_len, + device=k.device, dtype=torch.int32 + ) + ]) + kv_chunk_indices = torch.cat([ + kv_chunk_indices, + torch.arange( + k_seq_offset_list[1], + min(k_seq_offset_list[1] + block_seq_len, num_filtered_chunk * moba_chunk_size), + moba_chunk_size, + device=k.device, dtype=torch.int32 + ) + ]) + filtered_kv = kv.index_select(0, filtered_kv_indices) + filtered_dkv = dkv.index_select(0, filtered_kv_indices) + + kv_chunk_indices = kv_chunk_indices // moba_chunk_size + num_filtered_kv_chunks = len(kv_chunk_indices) + + q_indices = torch.arange( + 0 if num_q_blocks == 2 else block_seq_len, 2 * block_seq_len, + device=q.device, dtype=torch.int32 + ) + + # varlen trick: combining all q index that needs moba attn + # the result will be like [ C0H0 ][ C0H1 ][ C0H2 ][ ... ][ CnHm ] + gate_mask_q = gate_mask.index_select(0, kv_chunk_indices) + gate_mask_q = gate_mask_q.index_select(2, q_indices) # we need to know which part(s) of the two query blocks should be activated + + # equivalent to einops.rearrange(q, "n h s -> n (h s)"). ([s] [s] ... [s] for h times) + # [HS indices] * N (total size: all non-zero elements in HS dimension, potentially repeat) + # [num_selected_chunks, HS indices of non-zero elements] + # gate_mask has been filtered by q_indices. 
If we still need use gate_mask_q for indexing, it should be offset by block_seq_len if num_q_blocks == 1 + # + (0 if num_q_blocks == 2 else block_seq_len) + moba_q_indices = gate_mask_q.reshape(gate_mask_q.shape[0], -1).nonzero(as_tuple=True)[-1] + + # moba_seqlen_q indicates that how many q chunks are selected for each kv chunk - head + # moba_seqlen_q has shape (num_selecte_chunks * num_heads, ) => varlen_forward computes attention by (num_selecte_chunks * num_heads) times + moba_seqlen_q = gate_mask_q.sum(dim=-1).flatten() + + # ----------------------------------------------------------- + # select all q that needs moba attn based on the moba_q_indices + moba_q = rearrange(q, "s h d -> ( h s ) d") + moba_q = moba_q.index_select(0, moba_q_indices) # [ selected_HS, D ] + + # [ selected_S, 1, D ] (pseudo head dim for flash attn) + moba_q = moba_q.unsqueeze(1) + + # moba_q_sh_indices represents the position in the origin q tensor of each q token inside moba_q + # note that original q has shape (S, H, D) while moba_q_indices is based on (H S) + # Ignoring D, q has the flattend form like [H] [H] ... [H] for S times + # => moba_q_sh_indices is the index of each token in the original q tensor + moba_q_sh_indices = moba_q_indices % q_block_seq_len * num_head + moba_q_indices // q_block_seq_len + + + """ prepare moba kv """ + # Since moba_q is organized as HS * N, we need to reorganize kv to adapt to q + # cut off zero experts + q_zero_mask = moba_seqlen_q == 0 + valid_expert_mask = ~q_zero_mask + zero_expert_count = q_zero_mask.sum() + # only keep the kv that has q select > 0 + if zero_expert_count > 0: + moba_seqlen_q = moba_seqlen_q[valid_expert_mask] + + # moba cu_seqlen for flash attn + moba_cu_seqlen_q = torch.cat( + ( + torch.tensor([0], device=q.device, dtype=moba_seqlen_q.dtype), + moba_seqlen_q.cumsum(dim=0), + ), + dim=0, + ).to(torch.int32) + + + # ------------------------------ + # Select dout and output + d_moba_out = ( + # [num_non-zero_elements, D] + dout.view(-1, head_dim).index_select(0, moba_q_sh_indices).unsqueeze(1) + ) + moba_out = ( + # [num_non-zero_elements, D] + out.view(-1, head_dim).index_select(0, moba_q_sh_indices).unsqueeze(1) + ) + + + + # ----------------------------------------------------------------------------------- + # here `x` only stands for a dimension (stack dimension for KV) + moba_kv = rearrange(filtered_kv, "s x h d -> h s x d") # [H, K_S, 2, D ] + moba_dkv = rearrange(filtered_dkv, "s x h d -> h s x d") # [H, K_S, 2, D ] + + moba_kv = moba_kv.split(moba_chunk_size, dim=1) # tuple of (num_selected_chunks) elements with shape [H, chunk_size, 2, D] + moba_kv = torch.cat(moba_kv, dim=0) # [H x num_selected_chunks, chunk_size, 2, D ] after split + moba_dkv = torch.cat(moba_dkv.split(moba_chunk_size, dim=1), dim=0) # [H x num_selected_chunks, chunk_size, 2, D ] after split + + # The transformation is aimed for masking out by valid_expert_mask where the mask selects elements along (H x num_selected_chunks) dimension + if zero_expert_count > 0: + assert valid_expert_mask.sum() == moba_kv.shape[0] - zero_expert_count + + # cut off zero Q expert from kv , or the grad may be nan + moba_kv = moba_kv[valid_expert_mask] + moba_dkv = moba_dkv[valid_expert_mask] + + moba_kv = moba_kv.flatten(start_dim=0, end_dim=1).unsqueeze(2) # [H x num_selected_chunks x chunk_size, 2, 1, D] + moba_dkv = moba_dkv.flatten(start_dim=0, end_dim=1).unsqueeze(2) # [H x num_selected_chunks x chunk_size, 2, 1, D] + + moba_cu_seqlen_kv = ( + torch.arange( + 0, num_filtered_kv_chunks * 
num_head + 1 - zero_expert_count, + dtype=torch.int32, device=q.device, + ) * moba_chunk_size + ) + + # Shape check + assert ( + moba_cu_seqlen_kv.shape == moba_cu_seqlen_q.shape + ), f"moba_cu_seqlen_kv.shape != moba_cu_seqlen_q.shape {moba_cu_seqlen_kv.shape} != {moba_cu_seqlen_q.shape}" + + + self_attn_cu_seqlen = [0] + [moba_chunk_size] * (q_block_seq_len // moba_chunk_size) + if q_block_seq_len % moba_chunk_size != 0: + self_attn_cu_seqlen.append(q_block_seq_len % moba_chunk_size) + self_attn_cu_seqlen = torch.tensor(self_attn_cu_seqlen, device=q.device, dtype=torch.int32) + self_attn_cu_seqlen = self_attn_cu_seqlen.cumsum(dim=0, dtype=torch.int32) + + # ----------------------------------------------------------------------------------- + # self attn + if causal: + dq_, dk_, dv_ = torch.empty_like(dq), torch.empty_like(dkv[:, 0]), torch.empty_like(dkv[:, 1]) + _flash_attn_varlen_backward( + dout=dout, out=out, + q=q, k=k, v=v, + dq=dq_, dk=dk_, dv=dv_, + softmax_lse=softmax_lse.contiguous(), + + cu_seqlens_q=self_attn_cu_seqlen, + cu_seqlens_k=self_attn_cu_seqlen, + max_seqlen_q=q_block_seq_len, + max_seqlen_k=k_block_seq_len, + + softmax_scale=softmax_scale, + causal=True, + dropout_p=0.0, + window_size_left=window_size[0], + window_size_right=window_size[1], + alibi_slopes=alibi_slopes, + deterministic=deterministic, + softcap=0.0, + ) + dq, dkv[:, 0], dkv[:, 1] = dq + dq_, dk_ + dkv[:, 0], dv_ + dkv[:, 1] + + if moba_q.shape[0] > 0: + softmax_lse_sh = rearrange(softmax_lse.contiguous(), "h s -> (s h)") + moba_attn_lse = ( + # [1, num_non-zero_elements] + softmax_lse_sh.index_select(0, moba_q_sh_indices).view(1, -1) + ) + + moba_dq_, moba_dk_, moba_dv_ = torch.empty_like(moba_q), torch.empty_like(moba_kv[:, 0]), torch.empty_like(moba_kv[:, 1]) + _flash_attn_varlen_backward( + dout=d_moba_out, out=moba_out, + q=moba_q, k=moba_kv[:, 0], v=moba_kv[:, 1], + dq=moba_dq_, dk=moba_dk_, dv=moba_dv_, + softmax_lse=moba_attn_lse, + + cu_seqlens_q=moba_cu_seqlen_q, + cu_seqlens_k=moba_cu_seqlen_kv, + + max_seqlen_q=q_block_seq_len, + max_seqlen_k=moba_chunk_size, + softmax_scale=softmax_scale, + + causal=False, + dropout_p=0.0, + window_size_left=window_size[0], + window_size_right=window_size[1], + alibi_slopes=alibi_slopes, + deterministic=deterministic, + softcap=0.0, + ) + + dq.view(-1, q.shape[-1]).index_add_( + 0, moba_q_sh_indices, + moba_dq_.view(-1, head_dim).to(dq.dtype) + ) + moba_dkv[:, 0] = moba_dkv[:, 0] + moba_dk_ + moba_dkv[:, 1] = moba_dkv[:, 1] + moba_dv_ + + # ------------------------------------------------------------------------------------ + # Backpropagate moba_dkv to dk and dv + moba_dkv = moba_dkv.squeeze(2) # [H x num_selected_chunks x chunk_size, 2, D] + moba_dkv = moba_dkv.unflatten(0, (-1, moba_chunk_size)) # [H x num_selected_chunks, chunk_size, 2, D] + + if zero_expert_count > 0: + full_moba_dkv = torch.zeros( + (moba_dkv.shape[0] + zero_expert_count, moba_chunk_size, 2, head_dim), + dtype=moba_dkv.dtype, device=moba_dkv.device + ) + full_moba_dkv[valid_expert_mask] = moba_dkv + moba_dkv = full_moba_dkv # [H x num_selected_chunks, chunk_size, 2, D] + moba_dkv = moba_dkv.split(num_head, dim=0) # [H, num_selected_chunks, chunk_size, 2, D] + moba_dkv = torch.cat(moba_dkv, dim=1) # [H, num_selected_chunks x chunk_size, 2, D] + + filtered_dkv = rearrange(moba_dkv, "h s x d -> s x h d") + dkv.index_add_( + 0, filtered_kv_indices, filtered_dkv # [K_S, 2, H, D] + ) + + if num_head > k_num_head: + num_kv_replicas = num_head // k_num_head + dkv_reshaped = dkv.view(-1, 2, 
k_num_head, num_kv_replicas, head_dim) + dkv = dkv_reshaped.sum(dim=3) + + return dq, dkv[:, 0], dkv[:, 1] + +def moba_zigzag_attn_bwd( + process_group, + dout, # [blk_S, H, D] + q: torch.Tensor, # [blk_S, H, D] + k: torch.Tensor, # [blk_S, H, D] + v: torch.Tensor, # [blk_S, H, D] + out, # [blk_S, H, D] + softmax_lse, # [H, blk_S] + + seq_offsets: torch.Tensor, # sequence offsets for Q + layer_idx: int, + + gate_mask, + cu_chunk, + filtered_chunk_indices, + num_filtered_chunk, + chunk_to_batch, + moba_chunk_size, + moba_topk, + + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + assert causal == True, "zigzag ring is meaningless for causal=False" + + kv_comm = RingComm(process_group, zigzag=True) + d_kv_comm = RingComm(process_group, zigzag=True) + + kv_seq_offsets = torch.clone(seq_offsets) + seq_len, num_q_heads, head_dim = q.shape + dq, dk, dv = None, None, None + next_dk, next_dv = None, None + next_k, next_v = None, None + dk_comm_buffer, dv_comm_buffer = None, None + + dout1 = dout.chunk(2, dim=0)[1] + q1 = q.chunk(2, dim=0)[1] + out1 = out.chunk(2, dim=0)[1] + softmax_lse1 = softmax_lse.chunk(2, dim=1)[1].contiguous() + block_seq_len = q.shape[0] // 2 + + def backward( + step, + dout_, q_, k_, v_, out_, + k_seq_offsets, + softmax_lse_, + causal + ): + seqlen_q = q_.shape[0] + seqlen_kv = k_.shape[0] + + params = get_default_args(moba_zigzag_attn_bwd_step).copy() + params.update( + { + "step": step, + "causal": causal, + "dout": dout_, + "out": out_, + + "q": q_, + "k": k_, + "v": v_, + "softmax_lse": softmax_lse_, + + "num_q_blocks": 1 if seqlen_q == block_seq_len else 2, + "k_seq_offsets": k_seq_offsets, + "layer_idx": layer_idx, + + "gate_mask": gate_mask, + "cu_chunk": cu_chunk, + "filtered_chunk_indices": filtered_chunk_indices, + "num_filtered_chunk": num_filtered_chunk, + "chunk_to_batch": chunk_to_batch, + "moba_chunk_size": moba_chunk_size, + "moba_topk": moba_topk, + + "dropout_p": dropout_p, + "softmax_scale": softmax_scale, + "alibi_slopes": alibi_slopes, + "deterministic": deterministic, + } + ) + if "window_size" in params: + params.update({"window_size": window_size}) + else: + params.update( + { + "window_size_left": window_size[0], + "window_size_right": window_size[1], + } + ) + return moba_zigzag_attn_bwd_step(**params) + + for step in range(kv_comm.world_size): + if step + 1 != kv_comm.world_size: + # next_k, next_v = kv_comm.send_recv_kv(k, v) + next_k, next_v, next_kv_seq_offsets = kv_comm.send_recv_kv_offsets(k, v, kv_seq_offsets) + + if step == 0: + dq_buffer, dk_buffer, dv_buffer = backward( + step, + dout, q, k, v, out, + kv_seq_offsets, softmax_lse, causal=True + ) + dq = dq_buffer.to(torch.float32) + dk = dk_buffer.to(torch.float32) + dv = dv_buffer.to(torch.float32) + else: + if step <= kv_comm.revert_rank: + k0 = k[:block_seq_len] + v0 = v[:block_seq_len] + dq_buffer, dk_buffer, dv_buffer = backward( + step, + dout, q, k0, v0, out, + kv_seq_offsets[0:1], softmax_lse, causal=False + ) + dq += dq_buffer + else: + dq_buffer, dk_buffer, dv_buffer = backward( + step, + dout1, q1, k, v, out1, + kv_seq_offsets, softmax_lse1, causal=False + ) + + # use the first half in dq_buffer. 
+ dq[block_seq_len:] += dq_buffer + + d_kv_comm.wait() + dk_comm_buffer, dv_comm_buffer = dk, dv + dk, dv = next_dk, next_dv + + if step <= kv_comm.revert_rank: + dk[:block_seq_len] += dk_buffer + dv[:block_seq_len] += dv_buffer + else: + dk += dk_buffer + dv += dv_buffer + + if step + 1 != kv_comm.world_size: + kv_comm.wait() + k, v, kv_seq_offsets = next_k, next_v, next_kv_seq_offsets + + # the finally received dk and dv will be the same as the first dk and dv (corresponding to local Q) + next_dk, next_dv = d_kv_comm.send_recv_kv( + dk, dv, dk_comm_buffer, dv_comm_buffer + ) + d_kv_comm.wait() + return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) + + +''' +In nnscaler, sequence are stored in the initial order, e.g., [0 1 2 3 4 5 6 7]. +However, zigzag ring flash attention requires the sequence to be in the order of [0 7 2 5 3 4 1 6]. +As a result: +- in forward, we need to shuffle q, k, v and recover the out +- in backward, we need to shuffle dout and recover the dq, dk, dv +''' +class MoBAZigzagRingFlashAttnFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q: torch.Tensor, # [batch * seq_block_len, n_heads, head_dim] + k: torch.Tensor, + v: torch.Tensor, + seq_offset: torch.Tensor, + layer_idx, + dropout_p, + softmax_scale, + cu_seqlens, + moba_chunk_size, + moba_topk, + causal, + window_size, + alibi_slopes, + deterministic, + return_softmax, + group, + ): + # Note seq_len here refers to the total sequence length + seq_lens = cu_seqlens[1:] - cu_seqlens[:-1] + assert seq_lens.min() == seq_lens.max(), "Current implementation of MoBA Zigzag Ring Attention does not support variable sequence lengths within a batch" + seq_len = seq_lens.detach().cpu()[0].item() # all sequences in the batch have the same length + + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + assert alibi_slopes is None + + # --------------------------- + # Compute gate values before shuffling + ( + gate_mask, cu_chunk, filtered_chunk_indices, + num_filtered_chunk, chunk_to_batch + ) = compute_moba_gate( + q, k, v, + seq_offset, + cu_seqlens, + moba_chunk_size, + moba_topk, + ) + + q, seq_offsets, gate_mask = shuffle_input_all( + to_send=q, gate_mask=gate_mask, seq_offset=seq_offset, + process_group=group + ) + k = shuffle_input_only(to_send=k, process_group=group) + v = shuffle_input_only(to_send=v, process_group=group) + k = k.contiguous() + v = v.contiguous() + + q_3d, k_3d, v_3d = \ + tensor_4d_to_3d(q), tensor_4d_to_3d(k), tensor_4d_to_3d(v) + + out_3d, softmax_lse = moba_zigzag_attn_fwd( + group, + q_3d, k_3d, v_3d, + seq_offsets, # sequence offsets for Q + layer_idx, + + gate_mask, cu_chunk, + filtered_chunk_indices, + num_filtered_chunk, + chunk_to_batch, + moba_chunk_size, + moba_topk, + + softmax_scale=softmax_scale, + dropout_p=dropout_p, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=False, + ) + + out = out_3d.reshape(*q.shape) + ctx.save_for_backward( + q, k, v, out, softmax_lse, seq_offsets, + gate_mask, cu_chunk, filtered_chunk_indices, + chunk_to_batch + ) + ctx.num_filtered_chunk = num_filtered_chunk + ctx.moba_chunk_size = moba_chunk_size + ctx.moba_topk = moba_topk + + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + ctx.group = group + ctx.layer_idx = layer_idx + ctx.seq_len = seq_len + + out = recover_zigzag_output(out, dim=1, process_group=group) + return out if not 
return_softmax else (out, softmax_lse, None) + + @staticmethod + def backward(ctx, dout, *args): + ( + q, k, v, out, + softmax_lse, # [n_heads, seq_block_len] + seq_offsets, + gate_mask, cu_chunk, filtered_chunk_indices, + chunk_to_batch + ) = ctx.saved_tensors + + q_3d, k_3d, v_3d, out_3d = \ + tensor_4d_to_3d(q), tensor_4d_to_3d(k), tensor_4d_to_3d(v), \ + tensor_4d_to_3d(out) + + dout = shuffle_input_only(to_send=dout, process_group=ctx.group) + dout_3d = tensor_4d_to_3d(dout) + + num_filtered_chunk = ctx.num_filtered_chunk + moba_chunk_size = ctx.moba_chunk_size + moba_topk = ctx.moba_topk + + dq_3d, dk_3d, dv_3d = moba_zigzag_attn_bwd( + ctx.group, + dout_3d, + q_3d, k_3d, v_3d, + out_3d, softmax_lse, + + seq_offsets, + ctx.layer_idx, + gate_mask, + cu_chunk, + filtered_chunk_indices, + num_filtered_chunk, + chunk_to_batch, + moba_chunk_size, + moba_topk, + + softmax_scale=ctx.softmax_scale, + dropout_p=ctx.dropout_p, + causal=ctx.causal, + window_size=ctx.window_size, + alibi_slopes=ctx.alibi_slopes, + deterministic=ctx.deterministic, + ) + + dq, dk, dv = \ + dq_3d.reshape(*q.shape), dk_3d.reshape(*k.shape), dv_3d.reshape(*v.shape) + + dq = recover_zigzag_output(dq, dim=1, process_group=ctx.group) + dk = recover_zigzag_output(dk, dim=1, process_group=ctx.group) + dv = recover_zigzag_output(dv, dim=1, process_group=ctx.group) + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None + +def moba_zigzag_qkvpacked_func( + qkv, + seq_offset: torch.Tensor, + layer_idx: int, + cu_seqlens, + moba_chunk_size, + moba_topk, + dropout_p=0.0, + softmax_scale=None, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return MoBAZigzagRingFlashAttnFunc.apply( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + seq_offset, + layer_idx, + dropout_p, + softmax_scale, + cu_seqlens, + moba_chunk_size, + moba_topk, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + +def moba_zigzag_kvpacked_func( + q, + kv, + seq_offset: torch.Tensor, + layer_idx: int, + cu_seqlens, + moba_chunk_size, + moba_topk, + dropout_p=0.0, + softmax_scale=None, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return MoBAZigzagRingFlashAttnFunc.apply( + q, + kv[:, :, 0], + kv[:, :, 1], + seq_offset, + layer_idx, + dropout_p, + softmax_scale, + cu_seqlens, + moba_chunk_size, + moba_topk, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def moba_zigzag_func( + q, k, v, # [batch_size, seq_block_len, n_heads, head_dim] + layer_idx: int, + global_seq_len: int, + moba_chunk_size, + moba_topk, + + dropout_p=0.0, + softmax_scale=None, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + batch_size = q.shape[0] + cu_seqlens = torch.cumsum( + torch.tensor([0] + [global_seq_len] * batch_size, device=q.device), + dim=0, + dtype=torch.int32, + ) + + rank = dist.get_rank() + world_size = dist.get_world_size() + seq_offset = torch.arange(0, global_seq_len, global_seq_len // world_size)[rank:rank+1] + + return MoBAZigzagRingFlashAttnFunc.apply( + q, k, v, + seq_offset, + layer_idx, + dropout_p, + softmax_scale, + cu_seqlens, + moba_chunk_size, + moba_topk, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) diff --git 
a/minference/dist_ops/ring_attention.py b/minference/dist_ops/ring_attention.py new file mode 100644 index 0000000..7da68f9 --- /dev/null +++ b/minference/dist_ops/ring_attention.py @@ -0,0 +1,267 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# TODO: replace with zhuzilin's implementation + +import torch +import torch.distributed as dist +from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward + +from .utils import shuffle_zigzag_input, recover_zigzag_output, GlobalMemoryBuffer + + +_GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer() +def ring_flash_attn_forward( + process_group, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + block_len = q.size(1) // 2 + curr_rank = dist.get_rank(process_group) + world_size = dist.get_world_size(process_group) + keep_idx = 2 * curr_rank + dual_rank = world_size - curr_rank - 1 + dual_send_idx = 2 * dual_rank + 1 + up_rank = min(keep_idx, dual_send_idx) + down_rank = max(keep_idx, dual_send_idx) + + up_q = q[:, :block_len] + if causal: + up_k = k[:, :(up_rank + 1) * block_len] + up_v = v[:, :(up_rank + 1) * block_len] + else: + up_k, up_v = k, v + up_out, _, _, _, _, up_lse, _, _ = _flash_attn_forward( + up_q, + up_k, + up_v, + dropout_p, + softmax_scale, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + return_softmax=True and dropout_p > 0, + ) + + down_q = q[:, block_len:] + if causal: + down_k = k[:, :(down_rank + 1) * block_len] + down_v = v[:, :(down_rank + 1) * block_len] + else: + down_k, down_v = k, v + down_out, _, _, _, _, down_lse, _, _ = _flash_attn_forward( + down_q, + down_k, + down_v, + dropout_p, + softmax_scale, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + return_softmax=True and dropout_p > 0, + ) + + out = torch.cat([up_out, down_out], dim=1) + return out, up_lse, down_lse + + +def ring_flash_attn_backward( + process_group, + dout, + q, + k, + v, + out, + up_lse, + down_lse, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + block_len = q.size(1) // 2 + curr_rank = dist.get_rank(process_group) + world_size = dist.get_world_size(process_group) + keep_idx = 2 * curr_rank + dual_rank = world_size - curr_rank - 1 + dual_send_idx = 2 * dual_rank + 1 + up_rank = min(keep_idx, dual_send_idx) + down_rank = max(keep_idx, dual_send_idx) + + dq = torch.zeros_like(q) + dk_buffer = _GLOBAL_MEMORY_BUFFER.get_tensor(k.size(), k.dtype, "bwd_dk") + dk_buffer.zero_() + dv_buffer = _GLOBAL_MEMORY_BUFFER.get_tensor(v.size(), v.dtype, "bwd_dv") + dv_buffer.zero_() + + up_q = q[:, :block_len] + up_out = out[:, :block_len] + up_dout = dout[:, :block_len] + if causal: + up_k = k[:, :(up_rank + 1) * block_len] + up_v = v[:, :(up_rank + 1) * block_len] + else: + up_k, up_v = k, v + _flash_attn_backward( + up_dout, + up_q, + up_k, + up_v, + up_out, + up_lse, + dq[:, :block_len], + dk_buffer[:, :(up_rank + 1) * block_len], + dv_buffer[:, :(up_rank + 1) * block_len], + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + rng_state=None, + ) + + down_q = q[:, block_len:] + down_out = out[:, block_len:] + down_dout = dout[:, block_len:] + # TODO: optimize the buffer allocation + down_dk_buffer = _GLOBAL_MEMORY_BUFFER.get_tensor(k.size(), k.dtype, "bwd_down_dk") + down_dk_buffer.zero_() + down_dv_buffer = 
_GLOBAL_MEMORY_BUFFER.get_tensor(v.size(), v.dtype, "bwd_down_dv") + down_dv_buffer.zero_() + if causal: + down_k = k[:, :(down_rank + 1) * block_len] + down_v = v[:, :(down_rank + 1) * block_len] + else: + down_k, down_v = k, v + _flash_attn_backward( + down_dout, + down_q, + down_k, + down_v, + down_out, + down_lse, + dq[:, block_len:], + down_dk_buffer[:, :(down_rank + 1) * block_len], + down_dv_buffer[:, :(down_rank + 1) * block_len], + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + rng_state=None, + ) + dk_buffer.add_(down_dk_buffer) + dv_buffer.add_(down_dv_buffer) + + dim_size = list(k.size()) + dim_size[1] = dim_size[1] // world_size + dk = torch.empty(dim_size, dtype=k.dtype, device=k.device) + dv = torch.empty(dim_size, dtype=v.dtype, device=v.device) + dist._reduce_scatter_base(dk, dk_buffer, group=process_group) + dist._reduce_scatter_base(dv, dv_buffer, group=process_group) + + return dq, dk, dv + + +''' +In nnscaler, sequence are stored in the initial order, e.g., [0 1 2 3 4 5 6 7]. +However, ring flash attention requires the sequence to be in the order of [0 7 2 5 3 4 1 6]. +As a result: +- in forward, we need to shuffle q, all gather k, v and recover the out +- in backward, we need to shuffle dout and recover the dq, reduce scatter dk, dv +''' +class RingFlashAttnFunc(torch.autograd.Function): + + @staticmethod + def forward( + ctx, + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_softmax, + group, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + assert alibi_slopes is None + + q = shuffle_zigzag_input(to_send=q, process_group=group) + world_size = dist.get_world_size(group) + dim_size = list(k.size()) + dim_size[1] = dim_size[1] * world_size + k_buffer = _GLOBAL_MEMORY_BUFFER.get_tensor(dim_size, k.dtype, "fwd_k") + v_buffer = _GLOBAL_MEMORY_BUFFER.get_tensor(dim_size, v.dtype, "fwd_v") + torch.distributed._all_gather_base(k_buffer, k, group=group) + torch.distributed._all_gather_base(v_buffer, v, group=group) + + out, up_lse, down_lse = ring_flash_attn_forward( + group, + q, + k_buffer, + v_buffer, + softmax_scale=softmax_scale, + dropout_p=dropout_p, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=False, + ) + # this should be out_padded + ctx.save_for_backward(q, k, v, out, up_lse, down_lse) + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + ctx.group = group + out = recover_zigzag_output(out, process_group=group) + return out if not return_softmax else (out, softmax_lse, None) + + @staticmethod + def backward(ctx, dout, *args): + dout = shuffle_zigzag_input(to_send=dout, process_group=ctx.group) + q, k, v, out, up_lse, down_lse = ctx.saved_tensors + world_size = dist.get_world_size(ctx.group) + dim_size = list(k.size()) + dim_size[1] = dim_size[1] * world_size + k_buffer = _GLOBAL_MEMORY_BUFFER.get_tensor(dim_size, k.dtype, "fwd_k") + v_buffer = _GLOBAL_MEMORY_BUFFER.get_tensor(dim_size, v.dtype, "fwd_v") + torch.distributed._all_gather_base(k_buffer, k, group=ctx.group) + torch.distributed._all_gather_base(v_buffer, v, group=ctx.group) + + dq, dk, dv = ring_flash_attn_backward( + ctx.group, + dout, + q, + k_buffer, + v_buffer, + out, + up_lse, + down_lse, + softmax_scale=ctx.softmax_scale, + dropout_p=ctx.dropout_p, + causal=ctx.causal, + 
window_size=ctx.window_size, + alibi_slopes=ctx.alibi_slopes, + deterministic=ctx.deterministic, + ) + dq = recover_zigzag_output(dq, ctx.group) + return dq, dk, dv, None, None, None, None, None, None, None, None diff --git a/minference/dist_ops/striped_attention.py b/minference/dist_ops/striped_attention.py new file mode 100644 index 0000000..bc0915f --- /dev/null +++ b/minference/dist_ops/striped_attention.py @@ -0,0 +1,404 @@ +import torch +import torch.distributed as dist +from typing import List, Tuple, Dict +from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward + +from .utils import ( + RingComm, + update_out_and_lse, get_default_args, + shuffle_striped_input, recover_striped_output +) + +def stripe_flash_attn_forward( + process_group, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer_idx: int, + softmax_scale, + granularity=1, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + assert causal, "stripe flash attn only supports causal attention, if not causal, use ring flash attn instead" + comm = RingComm(process_group) + bsz, seq_len, num_heads, head_dim = q.shape + + out, lse = None, None + next_k, next_v = None, None + + def forward(q_, k_, v_, causal_): + params = get_default_args(_flash_attn_forward).copy() + params.update( + { + "q": q_, + "k": k_, + "v": v_, + "dropout_p": dropout_p, + "softmax_scale": softmax_scale, + "causal": causal_, + "alibi_slopes": alibi_slopes, + "return_softmax": True and dropout_p > 0, + } + ) + if "window_size" in params: + params.update({"window_size": window_size}) + else: + params.update( + { + "window_size_left": window_size[0], + "window_size_right": window_size[1], + } + ) + outputs = _flash_attn_forward(**params) + if len(outputs) == 8: + block_out, _, _, _, _, block_lse, _, _ = outputs + else: + assert len(outputs) == 4 + block_out, block_lse, _, _ = outputs + return block_out, block_lse + + + for step in range(comm.world_size): + if step + 1 != comm.world_size: + next_k, next_v = comm.send_recv_kv(k, v) + + shift = 1 if step > comm.rank else 0 + if shift == 0: + step_out, step_lse = forward(q, k, v, causal) + out, lse = update_out_and_lse(out, lse, step_out, step_lse) + else: + # Before the step index goes beyond the current rank, the received KV indices are not greater than those of the Q in the current rank + # After the step index goes beyond the current rank, only the KV indices before the last granularity are no greater than those of the Q after the first granularity + # this conclusion holds after the step index goes beyond the current rank (not just step index == current rank) + step_out, step_lse = forward( + q[:, granularity:], k[:, :-granularity], v[:, :-granularity], causal + ) + out, lse = update_out_and_lse( + out, lse, step_out, step_lse, slice_=(slice(None), slice(granularity, None)) + ) + + if step + 1 != comm.world_size: + comm.wait() + k, v = next_k, next_v + + out = out.to(q.dtype) + lse = lse.squeeze(dim=-1).transpose(1, 2) + return out, lse + + +def stripe_flash_attn_backward( + process_group, + dout, + q, + k, + v, + out, + softmax_lse, + layer_idx: int, + softmax_scale, + granularity=1, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + assert ( + causal + ), "stripe flash attn only supports causal attention, if not causal, ring flash attn instead" + bsz, seq_len, num_heads, head_dim = q.shape + + kv_comm = RingComm(process_group) + d_kv_comm = RingComm(process_group) + 
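+ # Two ring communicators are used: kv_comm circulates the K/V blocks around the ring
+ # while the local gradient contributions are computed, and d_kv_comm ships the
+ # partially accumulated dK/dV along with them, so that after world_size steps the
+ # last received next_dk/next_dv are the gradients belonging to this rank's own K/V.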
dq, dk, dv = None, None, None + next_dk, next_dv = None, None + next_k, next_v = None, None + dk_comm_buffer, dv_comm_buffer = None, None + + block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) + block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) + block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) + + + def backward( + granularity_, + ): + if granularity_ == 0: + k_, v_ = k, v + dk_, dv_ = block_dk_buffer, block_dv_buffer + else: + k_, v_ = k[:, :-granularity_], v[:, :-granularity_] + dk_, dv_ = block_dk_buffer[:, :-granularity_], block_dv_buffer[:, :-granularity_] + params = get_default_args(_flash_attn_backward).copy() + params.update( + { + "dout": dout[:, granularity_:], + "q": q[:, granularity_:], + "k": k_, + "v": v_, + "out": out[:, granularity_:], + "softmax_lse": softmax_lse[:, :, granularity_:].contiguous(), + "dq": block_dq_buffer[:, granularity_:], + "dk": dk_, + "dv": dv_, + "dropout_p": dropout_p, + "softmax_scale": softmax_scale, + "causal": causal, + "alibi_slopes": alibi_slopes, + "deterministic": deterministic, + } + ) + if "window_size" in params: + params.update({"window_size": window_size}) + else: + params.update( + { + "window_size_left": window_size[0], + "window_size_right": window_size[1], + } + ) + params.update({"rng_state": torch.zeros((2, ), dtype=torch.int64, device=q.device)}) + _flash_attn_backward(**params) + + for step in range(kv_comm.world_size): + if step + 1 != kv_comm.world_size: + next_k, next_v = kv_comm.send_recv_kv(k, v) + + shift_causal = 1 if step > kv_comm.rank else 0 + if shift_causal == 0: + backward(granularity_=0) + else: + backward(granularity_=granularity) + + if dq is None: + dq = block_dq_buffer.to(torch.float32) + dk = block_dk_buffer.to(torch.float32) + dv = block_dv_buffer.to(torch.float32) + else: + d_kv_comm.wait() + dk_comm_buffer, dv_comm_buffer = dk, dv + dk, dv = next_dk, next_dv + + if shift_causal == 0: + dq += block_dq_buffer + dk = block_dk_buffer + dk + dv = block_dv_buffer + dv + else: + dq[:, granularity:] += block_dq_buffer[:, granularity:] + dk[:, :-granularity] += block_dk_buffer[:, :-granularity] + dv[:, :-granularity] += block_dv_buffer[:, :-granularity] + + if step + 1 != kv_comm.world_size: + kv_comm.wait() + k, v = next_k, next_v + + next_dk, next_dv = d_kv_comm.send_recv_kv( + dk, dv, dk_comm_buffer, dv_comm_buffer + ) + + d_kv_comm.wait() + return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) + + +class StripeFlashAttnFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, k, v, + layer_idx, + dropout_p, + softmax_scale, + granularity, + causal, + window_size, + alibi_slopes, + deterministic, + return_softmax, + group, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + assert alibi_slopes is None + + # ----------------------------------------- + # Shuffle + q = shuffle_striped_input(to_send=q, dim=1, granularity=granularity, process_group=group) + k = shuffle_striped_input(to_send=k, dim=1, granularity=granularity, process_group=group) + v = shuffle_striped_input(to_send=v, dim=1, granularity=granularity, process_group=group) + k, v = k.contiguous(), v.contiguous() + + # ---------------------------------------------- + # Compute + out, softmax_lse = stripe_flash_attn_forward( + group, + q, k, v, + layer_idx, + softmax_scale=softmax_scale, + granularity=granularity, + dropout_p=dropout_p, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=False, + ) + + # 
---------------------------------------------- + # Recover outputs + recovered_out = recover_striped_output(out, dim=1, granularity=granularity, process_group=group) + if return_softmax: + recovered_softmax_lse = recover_striped_output(softmax_lse, dim=2, granularity=granularity, process_group=group) + + # ---------------------------------------------- + # this should be out_padded + ctx.save_for_backward(q, k, v, out, softmax_lse) + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.granularity = granularity + ctx.causal = causal + ctx.window_size = window_size + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + ctx.layer_idx = layer_idx + ctx.return_softmax = return_softmax + ctx.group = group + + # ---------------------------------------------- + # Output and return + if return_softmax: + return (recovered_out, recovered_softmax_lse, None) + return recovered_out + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse = ctx.saved_tensors + layer_idx = ctx.layer_idx + dropout_p, softmax_scale, causal, window_size, alibi_slopes, deterministic, return_softmax, group = ( + ctx.dropout_p, ctx.softmax_scale, ctx.causal, ctx.window_size, + ctx.alibi_slopes, ctx.deterministic, ctx.return_softmax, + ctx.group + ) + + + # ---------------------------------------------- + # Shuffle + dout = shuffle_striped_input( + to_send=dout, dim=1, granularity=ctx.granularity, + process_group=ctx.group + ) + + # ---------------------------------------------- + # Compute + dq, dk, dv = stripe_flash_attn_backward( + ctx.group, + dout, + q, k, v, out, softmax_lse, + layer_idx=layer_idx, + softmax_scale=ctx.softmax_scale, + granularity=ctx.granularity, + dropout_p=ctx.dropout_p, + causal=ctx.causal, + window_size=ctx.window_size, + alibi_slopes=ctx.alibi_slopes, + deterministic=ctx.deterministic, + ) + + + # ---------------------------------------------- + # Recover + dq = recover_striped_output(dq, dim=1, granularity=ctx.granularity, process_group=ctx.group) + dk = recover_striped_output(dk, dim=1, granularity=ctx.granularity, process_group=ctx.group) + dv = recover_striped_output(dv, dim=1, granularity=ctx.granularity, process_group=ctx.group) + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None + + +def stripe_flash_attn_qkvpacked_func( + qkv, # [B, N, 3, H, D] + dropout_p=0.0, + softmax_scale=None, + granularity=1, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return StripeFlashAttnFunc.apply( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + dropout_p, + softmax_scale, + granularity, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def stripe_flash_attn_kvpacked_func( + q, + kv, + dropout_p=0.0, + softmax_scale=None, + granularity=1, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return StripeFlashAttnFunc.apply( + q, + kv[:, :, 0], + kv[:, :, 1], + dropout_p, + softmax_scale, + granularity, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def stripe_flash_attn_func( + q, + k, + v, + layer_idx: int, + dropout_p=0.0, + softmax_scale=None, + granularity=1, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + alibi_slopes=None, + deterministic=False, + 
return_attn_probs=False, + group=None, +): + return StripeFlashAttnFunc.apply( + q, + k, + v, + layer_idx, + dropout_p, + softmax_scale, + granularity, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) diff --git a/minference/dist_ops/test/minfer_ring_test.py b/minference/dist_ops/test/minfer_ring_test.py new file mode 100644 index 0000000..6eccbce --- /dev/null +++ b/minference/dist_ops/test/minfer_ring_test.py @@ -0,0 +1,214 @@ +from __future__ import annotations + +import os +import pytest +import random +from typing import Callable +from types import SimpleNamespace + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +from minference.ops.utils import set_seed, check_by_correct_rate +from minference.dist_ops.minfer_zigzag import minfer_zigzag_func +from minference.dist_ops.minfer_striped import minfer_stripe_func +from minference.dist_ops.minfer_dr_striped import minfer_dr_stripe_func +from minference.ops.pit_sparse_flash_attention_v3 import minference_flash_attn_func + +# ------------- constants ------------------------------------------------------ +_ATOL = 1e-2 +_RTOL = 1e-2 +_WORLD_SIZE = 4 + +_ATTENTION_IMPLS: dict[str, Callable] = { + "minfer_zigzag": minfer_zigzag_func, + "minfer_stripe": minfer_stripe_func, + "minfer_dr_stripe": minfer_dr_stripe_func, +} + +# ------------- helpers -------------------------------------------------------- +def _init_process_group(rank: int, world_size: int, port: str) -> None: + """Initialise NCCL backend for the current worker.""" + os.environ.update( + { + "MASTER_ADDR": "127.0.0.1", + "MASTER_PORT": port, + "RANK": str(rank), + "WORLD_SIZE": str(world_size), + "LOCAL_RANK": str(rank % min(world_size, torch.cuda.device_count())), + "LOCAL_WORLD_SIZE": str(min(world_size, torch.cuda.device_count())), + } + ) + dist.init_process_group("nccl", rank=rank, world_size=world_size) + +def _run_worker( + rank: int, + world_size: int, + port: str, + cfg: SimpleNamespace, + attn_op_name: str, +) -> None: + """Worker function executed in every spawned GPU process.""" + _init_process_group(rank, world_size, port) + + # Short-hand variables + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + dtype = torch.bfloat16 + set_seed(2025 + rank) + + attn_op: Callable = _ATTENTION_IMPLS[attn_op_name] + + # ----------------- generate identical tensors on every rank -------------- + if rank == 0: + q = torch.randn( + (cfg.batch_size, cfg.seq_len, cfg.num_qo_heads, cfg.head_dim), + dtype=dtype, + device=device, + ) + k = torch.randn( + (cfg.batch_size, cfg.seq_len, cfg.num_kv_heads, cfg.head_dim), + dtype=dtype, + device=device, + ) + v = torch.randn( + (cfg.batch_size, cfg.seq_len, cfg.num_kv_heads, cfg.head_dim), + dtype=dtype, + device=device, + ) + dout = torch.randn( + (cfg.batch_size, cfg.seq_len, cfg.num_qo_heads, cfg.head_dim), + dtype=dtype, + device=device, + ) + else: + # placeholders that will be overwritten by broadcast + shape_q = (cfg.batch_size, cfg.seq_len, cfg.num_qo_heads, cfg.head_dim) + shape_kv = (cfg.batch_size, cfg.seq_len, cfg.num_kv_heads, cfg.head_dim) + q = torch.empty(shape_q, device=device, dtype=dtype) + k = torch.empty(shape_kv, device=device, dtype=dtype) + v = torch.empty(shape_kv, device=device, dtype=dtype) + dout = torch.empty(shape_q, device=device, dtype=dtype) + + # Make every rank see the same data + dist.broadcast(q, src=0) + dist.broadcast(k, src=0) + dist.broadcast(v, src=0) + dist.broadcast(dout, src=0) + + # ----------------- 
slice local context ----------------------------------- + local_ctx = cfg.seq_len // world_size + sl = slice(rank * local_ctx, (rank + 1) * local_ctx) + + q_local = q[:, sl].clone().detach().requires_grad_() + k_local = k[:, sl].clone().detach().requires_grad_() + v_local = v[:, sl].clone().detach().requires_grad_() + dout_local = dout[:, sl].clone() + + # ----------------- forward / backward on the candidate kernel ------------ + out_local = attn_op( + q_local, + k_local, + v_local, + cfg.v_size, + cfg.s_size, + layer_idx=0, + ) + torch.autograd.backward(out_local, dout_local) + + # ----------------- gather outputs & grads for reference comparison ------- + out_gather = [torch.empty_like(out_local) for _ in range(world_size)] + dist.all_gather(out_gather, out_local) + final_out = torch.cat(out_gather, dim=1) + + grads = [] + for g in (q_local.grad, k_local.grad, v_local.grad): + tmp = [torch.empty_like(g) for _ in range(world_size)] + dist.all_gather(tmp, g) + grads.append(torch.cat(tmp, dim=1)) + + # ----------------- reference: dense Flash-Attention ---------------------- + if rank == 0: + q_ref = q.detach().clone().requires_grad_() + k_ref = k.detach().clone().requires_grad_() + v_ref = v.detach().clone().requires_grad_() + + out_ref = minference_flash_attn_func( + q_ref, + k_ref, + v_ref, + cfg.v_size, + cfg.s_size, + causal=True, + ) + torch.autograd.backward(out_ref, dout) + ref_grads = (q_ref.grad, k_ref.grad, v_ref.grad) + + # ----------------- assertions ---------------------------------------- + assert check_by_correct_rate(final_out, out_ref, ATOL=_ATOL, RTOL=_RTOL),\ + "forward output mismatch" + + for got, ref, name in zip( + grads, + ref_grads, + ("Q-grad", "K-grad", "V-grad"), + ): + assert check_by_correct_rate(got, ref, ATOL=_ATOL, RTOL=_RTOL),\ + f"{name} mismatch" + dist.destroy_process_group() + +# ------------- pytest entry-point -------------------------------------------- +@pytest.mark.skipif(torch.cuda.device_count() < _WORLD_SIZE, reason="Not enough GPUs") +@pytest.mark.parametrize("seq_len", [131072, 262144, 524288]) +@pytest.mark.parametrize("batch_sz", [1]) +@pytest.mark.parametrize("head_dim", [64, 128]) +@pytest.mark.parametrize("sparsity", [0.9, 0.95]) +@pytest.mark.parametrize("num_qkv_head_pair", [(4, 1), (4, 4)]) +@pytest.mark.parametrize("use_triton", [True, False]) +@pytest.mark.parametrize("attn_op_name", + ["minfer_zigzag", "minfer_stripe", "minfer_dr_stripe"] +) +def test_sparse_attention_kernels( + seq_len: int, + batch_sz: int, + head_dim: int, + sparsity: float, + num_qkv_head_pair: tuple[int, int], + use_triton: bool, + attn_op_name: str, +): + """ + Compare every sparse kernel against the dense Flash-Attention reference on + both forward pass and input-gradient w.r.t Q/K/V. 
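+ The test spawns one NCCL worker per GPU, broadcasts identical Q/K/V/dout tensors
+ from rank 0, runs the candidate kernel on each rank's contiguous slice of the
+ sequence, all-gathers the outputs and input gradients, and checks them on rank 0
+ against the single-GPU minference_flash_attn_func reference within _ATOL/_RTOL.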
+ """ + port = str(random.randint(12000, 20000)) + if attn_op_name == "minfer_zigzag" and use_triton: + pytest.skip("minfer_zigzag is not implemented with the Triton path") + + if use_triton: + os.environ['FORCE_TRITON'] = "1" + + cfg = SimpleNamespace( + batch_size=batch_sz, + seq_len=seq_len, + head_dim=head_dim, + sparsity=sparsity, + num_qo_heads=num_qkv_head_pair[0], + num_kv_heads=num_qkv_head_pair[1], + ) + # derived sizes used by both candidate and reference kernels + cfg.v_size = [int((1 - cfg.sparsity) * 0.1 * cfg.seq_len)] * cfg.num_qo_heads + cfg.s_size = [int((1 - cfg.sparsity) * 0.2 * cfg.seq_len)] * cfg.num_qo_heads + + print(f"=" * 80) + print(f"Testing {attn_op_name} with configuration:\n{cfg}") + print(f"=" * 80) + mp.spawn( + _run_worker, + args=(_WORLD_SIZE, port, cfg, attn_op_name), + nprocs=_WORLD_SIZE, + join=True, + ) + + os.environ['FORCE_TRITON'] = "0" diff --git a/minference/dist_ops/test/minfer_ring_test_raw.py b/minference/dist_ops/test/minfer_ring_test_raw.py new file mode 100644 index 0000000..181d86d --- /dev/null +++ b/minference/dist_ops/test/minfer_ring_test_raw.py @@ -0,0 +1,230 @@ +# tests/test_minference_sparse_attention.py +""" +Distributed correctness tests for Minference sparse-attention kernels. + +Run with: + pytest -q -s tests/test_minference_sparse_attention.py +or manually choose GPUs, e.g. + CUDA_VISIBLE_DEVICES=0,1 pytest -q -s … + +The test spawns one process per GPU with torch.multiprocessing, so it does +**not** require `pytest-xdist`. It will be skipped automatically if you have +fewer than two visible CUDA devices. +""" +from __future__ import annotations + +import os +import random +from types import SimpleNamespace +from typing import Callable + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +from minference.ops.utils import set_seed, check_correctness_by_row, check_by_correct_rate +from minference.dist_ops.minfer_zigzag import minfer_zigzag_func +from minference.dist_ops.minfer_striped import minfer_stripe_func +from minference.dist_ops.minfer_dr_striped import minfer_dr_stripe_func +from minference.ops.pit_sparse_flash_attention_v3 import minference_flash_attn_func + +# ------------- constants ------------------------------------------------------ +_ATOL = 1e-2 +_RTOL = 1e-2 +_WORLD_SIZE = 4 + +_ATTENTION_IMPLS: dict[str, Callable] = { + "minfer_zigzag": minfer_zigzag_func, + "minfer_stripe": minfer_stripe_func, + "minfer_dr_stripe": minfer_dr_stripe_func, +} + +# ------------- helpers -------------------------------------------------------- +def _init_process_group(rank: int, world_size: int, port: str) -> None: + """Initialise NCCL backend for the current worker.""" + os.environ.update( + { + "MASTER_ADDR": "127.0.0.1", + "MASTER_PORT": port, + "RANK": str(rank), + "WORLD_SIZE": str(world_size), + "LOCAL_RANK": str(rank % min(world_size, torch.cuda.device_count())), + "LOCAL_WORLD_SIZE": str(min(world_size, torch.cuda.device_count())), + } + ) + dist.init_process_group("nccl", rank=rank, world_size=world_size) + + + +def _run_worker( + rank: int, + world_size: int, + port: str, + cfg: SimpleNamespace, + attn_op_name: str, +) -> None: + """Worker function executed in every spawned GPU process.""" + _init_process_group(rank, world_size, port) + + # Short-hand variables + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + dtype = torch.bfloat16 + set_seed(2025 + rank) + + attn_op: Callable = _ATTENTION_IMPLS[attn_op_name] + + # ----------------- generate identical tensors on 
every rank -------------- + if rank == 0: + rand_or_one = ( + torch.randn if not cfg.ones else lambda s, **k: torch.ones(*s, **k) + ) + q = rand_or_one( + (cfg.batch_size, cfg.seq_len, cfg.num_qo_heads, cfg.head_dim), + dtype=dtype, + device=device, + ) + k = rand_or_one( + (cfg.batch_size, cfg.seq_len, cfg.num_kv_heads, cfg.head_dim), + dtype=dtype, + device=device, + ) + v = rand_or_one( + (cfg.batch_size, cfg.seq_len, cfg.num_kv_heads, cfg.head_dim), + dtype=dtype, + device=device, + ) + dout = rand_or_one( + (cfg.batch_size, cfg.seq_len, cfg.num_qo_heads, cfg.head_dim), + dtype=dtype, + device=device, + ) + else: + # placeholders that will be overwritten by broadcast + shape_q = (cfg.batch_size, cfg.seq_len, cfg.num_qo_heads, cfg.head_dim) + shape_kv = (cfg.batch_size, cfg.seq_len, cfg.num_kv_heads, cfg.head_dim) + q = torch.empty(shape_q, device=device, dtype=dtype) + k = torch.empty(shape_kv, device=device, dtype=dtype) + v = torch.empty(shape_kv, device=device, dtype=dtype) + dout = torch.empty(shape_q, device=device, dtype=dtype) + + # Make every rank see the same data + dist.broadcast(q, src=0) + dist.broadcast(k, src=0) + dist.broadcast(v, src=0) + dist.broadcast(dout, src=0) + + # ----------------- slice local context ----------------------------------- + local_ctx = cfg.seq_len // world_size + sl = slice(rank * local_ctx, (rank + 1) * local_ctx) + + q_local = q[:, sl].clone().detach().requires_grad_() + k_local = k[:, sl].clone().detach().requires_grad_() + v_local = v[:, sl].clone().detach().requires_grad_() + dout_local = dout[:, sl].clone() + + # ----------------- forward / backward on the candidate kernel ------------ + out_local = attn_op( + q_local, + k_local, + v_local, + cfg.v_size, + cfg.s_size, + layer_idx=0, + ) + torch.autograd.backward(out_local, dout_local) + + # ----------------- gather outputs & grads for reference comparison ------- + out_gather = [torch.empty_like(out_local) for _ in range(world_size)] + dist.all_gather(out_gather, out_local) + final_out = torch.cat(out_gather, dim=1) + + grads = [] + for g in (q_local.grad, k_local.grad, v_local.grad): + tmp = [torch.empty_like(g) for _ in range(world_size)] + dist.all_gather(tmp, g) + grads.append(torch.cat(tmp, dim=1)) + + # ----------------- reference: dense Flash-Attention ---------------------- + if rank == 0: + q_ref = q.detach().clone().requires_grad_() + k_ref = k.detach().clone().requires_grad_() + v_ref = v.detach().clone().requires_grad_() + + out_ref = minference_flash_attn_func( + q_ref, + k_ref, + v_ref, + cfg.v_size, + cfg.s_size, + causal=True, + ) + torch.autograd.backward(out_ref, dout) + ref_grads = (q_ref.grad, k_ref.grad, v_ref.grad) + + # ----------------- assertions ---------------------------------------- + if check_by_correct_rate(final_out, out_ref, ATOL=_ATOL, RTOL=_RTOL): + for got, ref, name in zip( + grads, + ref_grads, + ("Q-grad", "K-grad", "V-grad"), + ): + check_by_correct_rate(got, ref, ATOL=_ATOL, RTOL=_RTOL) + + dist.destroy_process_group() + + +# ------------- pytest entry-point -------------------------------------------- +def test_sparse_attention_kernels( + seq_len: int, + batch_sz: int, + head_dim: int, + sparsity: float, + ones: bool, + num_qo_heads: int, + num_kv_heads: int, + attn_op_name: str, + use_triton: bool, +): + """ + Compare every sparse kernel against the dense Flash-Attention reference on + both forward pass and input-gradient w.r.t Q/K/V. 
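+ The per-head budgets cfg.v_size and cfg.s_size (vertical and slash counts for the
+ vertical_and_slash sparse pattern) are derived below from the requested sparsity
+ and sequence length and are passed to both the candidate and the reference kernel.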
+ """ + port = str(random.randint(12000, 20000)) + + cfg = SimpleNamespace( + batch_size=batch_sz, + seq_len=seq_len, + head_dim=head_dim, + sparsity=sparsity, + ones=ones, + num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + ) + # derived sizes used by both candidate and reference kernels + cfg.v_size = [int((1 - cfg.sparsity) * 0.1 * cfg.seq_len)] * cfg.num_qo_heads + cfg.s_size = [int((1 - cfg.sparsity) * 0.2 * cfg.seq_len)] * cfg.num_qo_heads + + print(f"=" * 80) + print(f"Testing {attn_op_name} with configuration:\n{cfg}") + print(f"=" * 80) + mp.spawn( + _run_worker, + args=(_WORLD_SIZE, port, cfg, attn_op_name), + nprocs=_WORLD_SIZE, + join=True, + ) + +if __name__ == "__main__": + # Run the test with default parameters + test_sparse_attention_kernels( + seq_len=512 * 1024, + batch_sz=1, + head_dim=128, + sparsity=0.9, + ones=False, + num_qo_heads=4, + num_kv_heads=2, + attn_op_name="minfer_zigzag", + use_triton=True + ) \ No newline at end of file diff --git a/minference/dist_ops/test/moba_ring_test.py b/minference/dist_ops/test/moba_ring_test.py new file mode 100644 index 0000000..9291d79 --- /dev/null +++ b/minference/dist_ops/test/moba_ring_test.py @@ -0,0 +1,210 @@ +from __future__ import annotations + +import os +import pytest +import random +import functools +from types import SimpleNamespace + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +from minference.ops.utils import set_seed, check_correctness_by_row, check_by_correct_rate +from minference.dist_ops.moba_zigzag import moba_zigzag_func +from minference.ops.moba import moba_attn_func + +# ------------- constants ------------------------------------------------------ +_ATOL = 1e-2 +_RTOL = 1e-2 +_WORLD_SIZE = 4 + +# ------------- helpers -------------------------------------------------------- +def skip_if_cuda_oom(test_func): + """Decorator: convert OutOfMemoryError raised inside the test into a skip.""" + @functools.wraps(test_func) + def _wrapper(*args, **kwargs): + try: + return test_func(*args, **kwargs) + except torch.OutOfMemoryError: + torch.cuda.empty_cache() + pytest.skip("skipped because the GPU ran out of memory") + return _wrapper + +def _init_process_group(rank: int, world_size: int, port: str) -> None: + """Initialise NCCL backend for the current worker.""" + os.environ.update( + { + "MASTER_ADDR": "127.0.0.1", + "MASTER_PORT": port, + "RANK": str(rank), + "WORLD_SIZE": str(world_size), + "LOCAL_RANK": str(rank % min(world_size, torch.cuda.device_count())), + "LOCAL_WORLD_SIZE": str(min(world_size, torch.cuda.device_count())), + } + ) + dist.init_process_group("nccl", rank=rank, world_size=world_size) + + +def _run_worker( + rank: int, + world_size: int, + port: str, + cfg: SimpleNamespace, +) -> None: + """Worker function executed in every spawned GPU process.""" + _init_process_group(rank, world_size, port) + + # Short-hand variables + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + dtype = torch.bfloat16 + set_seed(2025 + rank) + + # ----------------- generate identical tensors on every rank -------------- + if rank == 0: + q = torch.randn( + (cfg.batch_size, cfg.seq_len, cfg.num_qo_heads, cfg.head_dim), + dtype=dtype, + device=device, + ) + k = torch.randn( + (cfg.batch_size, cfg.seq_len, cfg.num_kv_heads, cfg.head_dim), + dtype=dtype, + device=device, + ) + v = torch.randn( + (cfg.batch_size, cfg.seq_len, cfg.num_kv_heads, cfg.head_dim), + dtype=dtype, + device=device, + ) + dout = torch.randn( + (cfg.batch_size, cfg.seq_len, 
cfg.num_qo_heads, cfg.head_dim), + dtype=dtype, + device=device, + ) + else: + # placeholders that will be overwritten by broadcast + shape_q = (cfg.batch_size, cfg.seq_len, cfg.num_qo_heads, cfg.head_dim) + shape_kv = (cfg.batch_size, cfg.seq_len, cfg.num_kv_heads, cfg.head_dim) + q = torch.empty(shape_q, device=device, dtype=dtype) + k = torch.empty(shape_kv, device=device, dtype=dtype) + v = torch.empty(shape_kv, device=device, dtype=dtype) + dout = torch.empty(shape_q, device=device, dtype=dtype) + + # Make every rank see the same data + dist.broadcast(q, src=0) + dist.broadcast(k, src=0) + dist.broadcast(v, src=0) + dist.broadcast(dout, src=0) + + # ----------------- slice local context ----------------------------------- + local_ctx = cfg.seq_len // world_size + sl = slice(rank * local_ctx, (rank + 1) * local_ctx) + + q_local = q[:, sl].clone().detach().requires_grad_() + k_local = k[:, sl].clone().detach().requires_grad_() + v_local = v[:, sl].clone().detach().requires_grad_() + dout_local = dout[:, sl].clone() + + # ----------------- forward / backward on the candidate kernel ------------ + out_local = moba_zigzag_func( + q_local, k_local, v_local, + layer_idx=0, + global_seq_len=cfg.seq_len, + moba_chunk_size=cfg.moba_chunk_size, + moba_topk=cfg.moba_topk, + ) + torch.autograd.backward(out_local, dout_local) + + # ----------------- gather outputs & grads for reference comparison ------- + out_gather = [torch.empty_like(out_local) for _ in range(world_size)] + dist.all_gather(out_gather, out_local) + final_out = torch.cat(out_gather, dim=1) + + grads = [] + for g in (q_local.grad, k_local.grad, v_local.grad): + tmp = [torch.empty_like(g) for _ in range(world_size)] + dist.all_gather(tmp, g) + grads.append(torch.cat(tmp, dim=1)) + + # --------------------------------------- + if rank == 0: + q_ref = q.detach().clone().requires_grad_() + k_ref = k.detach().clone().requires_grad_() + v_ref = v.detach().clone().requires_grad_() + + out_ref = moba_attn_func( + q_ref, k_ref, v_ref, + global_seq_len=cfg.seq_len, + moba_chunk_size=cfg.moba_chunk_size, + moba_topk=cfg.moba_topk, + ) + torch.autograd.backward(out_ref, dout) + ref_grads = (q_ref.grad, k_ref.grad, v_ref.grad) + + # ----------------- assertions ---------------------------------------- + assert check_by_correct_rate(final_out, out_ref, ATOL=_ATOL, RTOL=_RTOL),\ + "forward output mismatch" + + for got, ref, name in zip( + grads, + ref_grads, + ("Q-grad", "K-grad", "V-grad"), + ): + assert check_by_correct_rate(got, ref, ATOL=_ATOL, RTOL=_RTOL),\ + f"{name} mismatch" + + # torch.testing.assert_close( + # final_out, out_ref, atol=_ATOL, rtol=_RTOL, msg="forward output mismatch" + # ) + # for got, ref, name in zip( + # grads, + # ref_grads, + # ("Q-grad", "K-grad", "V-grad"), + # ): + # torch.testing.assert_close(got, ref, atol=_ATOL, rtol=_RTOL, msg=f"{name} mismatch") + + dist.destroy_process_group() + + +# ------------- pytest entry-point -------------------------------------------- +@skip_if_cuda_oom +@pytest.mark.skipif(torch.cuda.device_count() < _WORLD_SIZE, reason="Not enough GPUs") +@pytest.mark.parametrize("seq_len", [16384, 32768]) +@pytest.mark.parametrize("head_dim", [64, 128]) +@pytest.mark.parametrize("num_qkv_head_pair", [(4, 1), (4, 4)]) +@pytest.mark.parametrize("moba_chunk_size", [128, 256]) +@pytest.mark.parametrize("moba_topk", [8, 16]) +def test_moba_kernels( + seq_len: int, + head_dim: int, + num_qkv_head_pair: tuple[int, int], + moba_chunk_size: int, + moba_topk: int, +): + """ + Compare every sparse kernel 
against the dense Flash-Attention reference on + both forward pass and input-gradient w.r.t Q/K/V. + """ + + port = str(random.randint(12000, 20000)) + cfg = SimpleNamespace( + batch_size=1, + seq_len=seq_len, + head_dim=head_dim, + num_qo_heads=num_qkv_head_pair[0], + num_kv_heads=num_qkv_head_pair[1], + moba_chunk_size=moba_chunk_size, + moba_topk=moba_topk, + ) + + print(f"=" * 80) + print(f"Testing MoBA (w. Zigzag) with configuration:\n{cfg}") + print(f"=" * 80) + mp.spawn( + _run_worker, + args=(_WORLD_SIZE, port, cfg), + nprocs=_WORLD_SIZE, + join=True, + ) \ No newline at end of file diff --git a/minference/dist_ops/test/moba_ring_test_raw.py b/minference/dist_ops/test/moba_ring_test_raw.py new file mode 100644 index 0000000..4943337 --- /dev/null +++ b/minference/dist_ops/test/moba_ring_test_raw.py @@ -0,0 +1,207 @@ +from __future__ import annotations + +import os +import random +from typing import Callable, Optional +from types import SimpleNamespace + + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from torch import Tensor + +from minference.ops.utils import set_seed, check_correctness_by_row +from minference.ops.moba import moba_attn_varlen, moba_layer, moba_attn_func, moba_attn_varlen_naive +from minference.dist_ops.moba_zigzag import moba_zigzag_func + +# ------------- constants ------------------------------------------------------ +_ATOL = 1e-2 +_RTOL = 1e-2 +_WORLD_SIZE = 4 + +# ------------- helpers -------------------------------------------------------- +def _init_process_group(rank: int, world_size: int, port: str) -> None: + """Initialise NCCL backend for the current worker.""" + os.environ.update( + { + "MASTER_ADDR": "127.0.0.1", + "MASTER_PORT": port, + "RANK": str(rank), + "WORLD_SIZE": str(world_size), + "LOCAL_RANK": str(rank % min(world_size, torch.cuda.device_count())), + "LOCAL_WORLD_SIZE": str(min(world_size, torch.cuda.device_count())), + } + ) + dist.init_process_group("nccl", rank=rank, world_size=world_size) + + +def _run_worker( + rank: int, + world_size: int, + port: str, + cfg: SimpleNamespace, +) -> None: + """Worker function executed in every spawned GPU process.""" + _init_process_group(rank, world_size, port) + + # Short-hand variables + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + dtype = torch.bfloat16 + set_seed(2025 + rank) + + # [0, BLK_SZ, 2 * BLK_SZ, 3 * BLK_SZ, ..., seq_len - 1] + rank = dist.get_rank() + world_size = dist.get_world_size() + + # ----------------- generate identical tensors on every rank -------------- + if rank == 0: + rand_or_one = ( + torch.randn if not cfg.ones else lambda s, **k: torch.ones(*s, **k) + ) + q = rand_or_one( + (cfg.batch_size, cfg.seq_len, cfg.num_qo_heads, cfg.head_dim), + dtype=dtype, + device=device, + ) + k = rand_or_one( + (cfg.batch_size, cfg.seq_len, cfg.num_kv_heads, cfg.head_dim), + dtype=dtype, + device=device, + ) + v = rand_or_one( + (cfg.batch_size, cfg.seq_len, cfg.num_kv_heads, cfg.head_dim), + dtype=dtype, + device=device, + ) + dout = rand_or_one( + (cfg.batch_size, cfg.seq_len, cfg.num_qo_heads, cfg.head_dim), + dtype=dtype, + device=device, + ) + else: + # placeholders that will be overwritten by broadcast + shape_q = (cfg.batch_size, cfg.seq_len, cfg.num_qo_heads, cfg.head_dim) + shape_kv = (cfg.batch_size, cfg.seq_len, cfg.num_kv_heads, cfg.head_dim) + q = torch.empty(shape_q, device=device, dtype=dtype) + k = torch.empty(shape_kv, device=device, dtype=dtype) + v = torch.empty(shape_kv, device=device, dtype=dtype) + dout 
= torch.empty(shape_q, device=device, dtype=dtype) + + # Make every rank see the same data + dist.broadcast(q, src=0) + dist.broadcast(k, src=0) + dist.broadcast(v, src=0) + dist.broadcast(dout, src=0) + + # ----------------- slice local context ----------------------------------- + local_ctx = cfg.seq_len // world_size + sl = slice(rank * local_ctx, (rank + 1) * local_ctx) + + q_local = q[:, sl].clone().detach().requires_grad_() + k_local = k[:, sl].clone().detach().requires_grad_() + v_local = v[:, sl].clone().detach().requires_grad_() + dout_local = dout[:, sl].clone() + + # ----------------- forward / backward on the candidate kernel ------------ + out_local = moba_zigzag_func( + q_local, k_local, v_local, + layer_idx=0, + global_seq_len=cfg.seq_len, + moba_chunk_size=cfg.moba_chunk_size, + moba_topk=cfg.moba_topk, + ) + torch.autograd.backward(out_local, dout_local) + + # ----------------- gather outputs & grads for reference comparison ------- + out_gather = [torch.empty_like(out_local) for _ in range(world_size)] + dist.all_gather(out_gather, out_local) + final_out = torch.cat(out_gather, dim=1) + + grads = [] + for g in (q_local.grad, k_local.grad, v_local.grad): + tmp = [torch.empty_like(g) for _ in range(world_size)] + dist.all_gather(tmp, g) + grads.append(torch.cat(tmp, dim=1)) + + # --------------------------------------- + if rank == 0: + q_ref = q.detach().clone().requires_grad_() + k_ref = k.detach().clone().requires_grad_() + v_ref = v.detach().clone().requires_grad_() + + out_ref = moba_attn_func( + q_ref, k_ref, v_ref, + global_seq_len=cfg.seq_len, + moba_chunk_size=cfg.moba_chunk_size, + moba_topk=cfg.moba_topk, + ) + torch.autograd.backward(out_ref, dout) + ref_grads = (q_ref.grad, k_ref.grad, v_ref.grad) + + # ----------------- assertions ---------------------------------------- + if check_correctness_by_row( + cfg.seq_len, final_out, out_ref, "forward output", ATOL=_ATOL, RTOL=_RTOL + ): + check_correctness_by_row( + cfg.seq_len, grads[0], ref_grads[0], "Q-grad", ATOL=_ATOL, RTOL=_RTOL + ) + check_correctness_by_row( + cfg.seq_len, grads[1], ref_grads[1], "K-grad", + ATOL=_ATOL, RTOL=_RTOL + ) + check_correctness_by_row( + cfg.seq_len, grads[2], ref_grads[2], "V-grad", + ATOL=_ATOL, RTOL=_RTOL + ) + dist.destroy_process_group() + + +# ------------- pytest entry-point -------------------------------------------- +def test_moba_kernels( + seq_len: int = 4096, + batch_size: int = 1, + head_dim: int = 64, + ones: bool = True, + num_qkv_head_pair: tuple[int, int]=(2, 2), + moba_chunk_size: int = 512, + moba_topk: int = 8, +): + """ + Compare every sparse kernel against the dense Flash-Attention reference on + both forward pass and input-gradient w.r.t Q/K/V. + """ + port = str(random.randint(12000, 20000)) + cfg = SimpleNamespace( + batch_size=batch_size, + seq_len=seq_len, + head_dim=head_dim, + ones=ones, + num_qo_heads=num_qkv_head_pair[0], + num_kv_heads=num_qkv_head_pair[1], + moba_chunk_size=moba_chunk_size, + moba_topk=moba_topk, + ) + + print(f"=" * 80) + print(f"Testing MoBA (w. 
Zigzag) with configuration:\n{cfg}") + print(f"=" * 80) + mp.spawn( + _run_worker, + args=(_WORLD_SIZE, port, cfg), + nprocs=_WORLD_SIZE, + join=True, + ) + +if __name__ == "__main__": + test_moba_kernels( + seq_len=16384, + batch_size=1, + head_dim=128, + ones=False, + num_qkv_head_pair=(4, 1), + + moba_chunk_size=128, + moba_topk=8, + ) \ No newline at end of file diff --git a/minference/dist_ops/test/xattn_ring_test.py b/minference/dist_ops/test/xattn_ring_test.py new file mode 100644 index 0000000..3ac5359 --- /dev/null +++ b/minference/dist_ops/test/xattn_ring_test.py @@ -0,0 +1,200 @@ +from __future__ import annotations + +import os +import pytest +import random +from types import SimpleNamespace +from typing import Callable + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +from minference.ops.utils import set_seed, check_correctness_by_row +from minference.dist_ops.xattn_zigzag import xattn_zigzag_func +from minference.ops.xattention_fa import xattn_flash_attn_func + +# ------------- constants ------------------------------------------------------ +_ATOL = 1e-1 +_RTOL = 1e-1 +_WORLD_SIZE = 4 + +# ------------- helpers -------------------------------------------------------- +def _init_process_group(rank: int, world_size: int, port: str) -> None: + """Initialise NCCL backend for the current worker.""" + os.environ.update( + { + "MASTER_ADDR": "127.0.0.1", + "MASTER_PORT": port, + "RANK": str(rank), + "WORLD_SIZE": str(world_size), + "LOCAL_RANK": str(rank % min(world_size, torch.cuda.device_count())), + "LOCAL_WORLD_SIZE": str(min(world_size, torch.cuda.device_count())), + } + ) + dist.init_process_group("nccl", rank=rank, world_size=world_size) + + + +def _run_worker( + rank: int, + world_size: int, + port: str, + cfg: SimpleNamespace, +) -> None: + """Worker function executed in every spawned GPU process.""" + _init_process_group(rank, world_size, port) + + # Short-hand variables + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + dtype = torch.bfloat16 + set_seed(2025 + rank) + + # ----------------- generate identical tensors on every rank -------------- + if rank == 0: + q = torch.randn( + (cfg.batch_size, cfg.seq_len, cfg.num_qo_heads, cfg.head_dim), + dtype=dtype, + device=device, + ) + k = torch.randn( + (cfg.batch_size, cfg.seq_len, cfg.num_kv_heads, cfg.head_dim), + dtype=dtype, + device=device, + ) + v = torch.randn( + (cfg.batch_size, cfg.seq_len, cfg.num_kv_heads, cfg.head_dim), + dtype=dtype, + device=device, + ) + dout = torch.randn( + (cfg.batch_size, cfg.seq_len, cfg.num_qo_heads, cfg.head_dim), + dtype=dtype, + device=device, + ) + else: + # placeholders that will be overwritten by broadcast + shape_q = (cfg.batch_size, cfg.seq_len, cfg.num_qo_heads, cfg.head_dim) + shape_kv = (cfg.batch_size, cfg.seq_len, cfg.num_kv_heads, cfg.head_dim) + q = torch.empty(shape_q, device=device, dtype=dtype) + k = torch.empty(shape_kv, device=device, dtype=dtype) + v = torch.empty(shape_kv, device=device, dtype=dtype) + dout = torch.empty(shape_q, device=device, dtype=dtype) + + # Make every rank see the same data + dist.broadcast(q, src=0) + dist.broadcast(k, src=0) + dist.broadcast(v, src=0) + dist.broadcast(dout, src=0) + + # ----------------- slice local context ----------------------------------- + local_ctx = cfg.seq_len // world_size + sl = slice(rank * local_ctx, (rank + 1) * local_ctx) + + q_local = q[:, sl].clone().detach().requires_grad_() + k_local = k[:, sl].clone().detach().requires_grad_() + v_local = v[:, 
sl].clone().detach().requires_grad_() + dout_local = dout[:, sl].clone() + + # ----------------- forward / backward on the candidate kernel ------------ + out_local = xattn_zigzag_func( + q_local, k_local, v_local, + layer_idx=0, + xattn_params=cfg.xattn_params, + granularity=128, + ) + torch.autograd.backward(out_local, dout_local) + + # ----------------- gather outputs & grads for reference comparison ------- + out_gather = [torch.empty_like(out_local) for _ in range(world_size)] + dist.all_gather(out_gather, out_local) + final_out = torch.cat(out_gather, dim=1) + + grads = [] + for g in (q_local.grad, k_local.grad, v_local.grad): + tmp = [torch.empty_like(g) for _ in range(world_size)] + dist.all_gather(tmp, g) + grads.append(torch.cat(tmp, dim=1)) + + # --------------------------------------- + if rank == 0: + q_ref = q.detach().clone().requires_grad_() + k_ref = k.detach().clone().requires_grad_() + v_ref = v.detach().clone().requires_grad_() + + single_machine_params = cfg.xattn_params.copy() + single_machine_params["chunk_size"] = cfg.seq_len // _WORLD_SIZE + out_ref = xattn_flash_attn_func( + q_ref, k_ref, v_ref, + head_indices=list(range(cfg.num_qo_heads)), + xattn_params=single_machine_params, + granularity=128, + ) + torch.autograd.backward(out_ref, dout) + ref_grads = (q_ref.grad, k_ref.grad, v_ref.grad) + + torch.testing.assert_close( + final_out, out_ref, atol=_ATOL, rtol=_RTOL, msg="forward output mismatch" + ) + for got, ref, name in zip( + grads, + ref_grads, + ("Q-grad", "K-grad", "V-grad"), + ): + torch.testing.assert_close(got, ref, atol=_ATOL, rtol=_RTOL, msg=f"{name} mismatch") + + dist.destroy_process_group() + + +# ------------- pytest entry-point -------------------------------------------- +@pytest.mark.skipif(torch.cuda.device_count() < _WORLD_SIZE, reason="Not enough GPUs") +@pytest.mark.parametrize("seq_len", [131072, 262144, 524288]) +@pytest.mark.parametrize("head_dim", [64, 128]) +@pytest.mark.parametrize("num_qkv_head_pair", [(4, 1), (4, 4)]) +@pytest.mark.parametrize("stride", [16, 32]) +@pytest.mark.parametrize("threshold", [0.9, 0.95]) +def test_xattention_kernels( + seq_len: int, + head_dim: int, + num_qkv_head_pair: tuple[int, int], + stride: int, + threshold: float, +): + """ + Compare every sparse kernel against the dense Flash-Attention reference on + both forward pass and input-gradient w.r.t Q/K/V. + """ + + port = str(random.randint(12000, 20000)) + xattn_params = { + "stride": stride, + "norm": 1, + "softmax": True, + "threshold": threshold, + "select_mode": "inverse", + "use_triton": True, + "causal": True, + "kdb": 1, + "keep_sink": False, + "keep_recent": False + } + cfg = SimpleNamespace( + batch_size=1, + seq_len=seq_len, + head_dim=head_dim, + num_qo_heads=num_qkv_head_pair[0], + num_kv_heads=num_qkv_head_pair[1], + xattn_params=xattn_params, + ) + + print(f"=" * 80) + print(f"Testing XAttention (w. Zigzag) with configuration:\n{cfg}") + print(f"=" * 80) + mp.spawn( + _run_worker, + args=(_WORLD_SIZE, port, cfg), + nprocs=_WORLD_SIZE, + join=True, + ) + diff --git a/minference/dist_ops/test/xattn_ring_test_raw.py b/minference/dist_ops/test/xattn_ring_test_raw.py new file mode 100644 index 0000000..473960f --- /dev/null +++ b/minference/dist_ops/test/xattn_ring_test_raw.py @@ -0,0 +1,231 @@ +# tests/test_minference_sparse_attention.py +""" +Distributed correctness tests for Minference sparse-attention kernels. + +Run with: + pytest -q -s tests/test_minference_sparse_attention.py +or manually choose GPUs, e.g. 
+ CUDA_VISIBLE_DEVICES=0,1 pytest -q -s … + +The test spawns one process per GPU with torch.multiprocessing, so it does +**not** require `pytest-xdist`. It will be skipped automatically if you have +fewer than two visible CUDA devices. +""" +from __future__ import annotations + +import os +import random +from types import SimpleNamespace +from typing import Callable + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +from minference.ops.utils import set_seed, check_correctness_by_row +from minference.dist_ops.xattn_zigzag import xattn_zigzag_func +from minference.ops.xattention_fa import xattn_flash_attn_func + +# ------------- constants ------------------------------------------------------ +_ATOL = 1e-1 +_RTOL = 1e-1 +_WORLD_SIZE = 4 + +# ------------- helpers -------------------------------------------------------- +def _init_process_group(rank: int, world_size: int, port: str) -> None: + """Initialise NCCL backend for the current worker.""" + os.environ.update( + { + "MASTER_ADDR": "127.0.0.1", + "MASTER_PORT": port, + "RANK": str(rank), + "WORLD_SIZE": str(world_size), + "LOCAL_RANK": str(rank % min(world_size, torch.cuda.device_count())), + "LOCAL_WORLD_SIZE": str(min(world_size, torch.cuda.device_count())), + } + ) + dist.init_process_group("nccl", rank=rank, world_size=world_size) + + + +def _run_worker( + rank: int, + world_size: int, + port: str, + cfg: SimpleNamespace, +) -> None: + """Worker function executed in every spawned GPU process.""" + _init_process_group(rank, world_size, port) + + # Short-hand variables + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + dtype = torch.bfloat16 + set_seed(2025 + rank) + + # ----------------- generate identical tensors on every rank -------------- + if rank == 0: + rand_or_one = ( + torch.randn if not cfg.ones else lambda s, **k: torch.ones(*s, **k) + ) + q = rand_or_one( + (cfg.batch_size, cfg.seq_len, cfg.num_qo_heads, cfg.head_dim), + dtype=dtype, + device=device, + ) + k = rand_or_one( + (cfg.batch_size, cfg.seq_len, cfg.num_kv_heads, cfg.head_dim), + dtype=dtype, + device=device, + ) + v = rand_or_one( + (cfg.batch_size, cfg.seq_len, cfg.num_kv_heads, cfg.head_dim), + dtype=dtype, + device=device, + ) + dout = rand_or_one( + (cfg.batch_size, cfg.seq_len, cfg.num_qo_heads, cfg.head_dim), + dtype=dtype, + device=device, + ) + else: + # placeholders that will be overwritten by broadcast + shape_q = (cfg.batch_size, cfg.seq_len, cfg.num_qo_heads, cfg.head_dim) + shape_kv = (cfg.batch_size, cfg.seq_len, cfg.num_kv_heads, cfg.head_dim) + q = torch.empty(shape_q, device=device, dtype=dtype) + k = torch.empty(shape_kv, device=device, dtype=dtype) + v = torch.empty(shape_kv, device=device, dtype=dtype) + dout = torch.empty(shape_q, device=device, dtype=dtype) + + # Make every rank see the same data + dist.broadcast(q, src=0) + dist.broadcast(k, src=0) + dist.broadcast(v, src=0) + dist.broadcast(dout, src=0) + + # ----------------- slice local context ----------------------------------- + local_ctx = cfg.seq_len // world_size + sl = slice(rank * local_ctx, (rank + 1) * local_ctx) + + q_local = q[:, sl].clone().detach().requires_grad_() + k_local = k[:, sl].clone().detach().requires_grad_() + v_local = v[:, sl].clone().detach().requires_grad_() + dout_local = dout[:, sl].clone() + + # ----------------- forward / backward on the candidate kernel ------------ + out_local = xattn_zigzag_func( + q_local, k_local, v_local, + layer_idx=0, + xattn_params=cfg.xattn_params, + granularity=128, 
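+ # same 128-token granularity as the dense xattn_flash_attn_func reference below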
+ ) + torch.autograd.backward(out_local, dout_local) + + # ----------------- gather outputs & grads for reference comparison ------- + out_gather = [torch.empty_like(out_local) for _ in range(world_size)] + dist.all_gather(out_gather, out_local) + final_out = torch.cat(out_gather, dim=1) + + grads = [] + for g in (q_local.grad, k_local.grad, v_local.grad): + tmp = [torch.empty_like(g) for _ in range(world_size)] + dist.all_gather(tmp, g) + grads.append(torch.cat(tmp, dim=1)) + + # --------------------------------------- + if rank == 0: + q_ref = q.detach().clone().requires_grad_() + k_ref = k.detach().clone().requires_grad_() + v_ref = v.detach().clone().requires_grad_() + + single_machine_params = cfg.xattn_params.copy() + single_machine_params["chunk_size"] = cfg.seq_len // _WORLD_SIZE + out_ref = xattn_flash_attn_func( + q_ref, k_ref, v_ref, + head_indices=list(range(cfg.num_qo_heads)), + xattn_params=single_machine_params, + granularity=128, + ) + torch.autograd.backward(out_ref, dout) + ref_grads = (q_ref.grad, k_ref.grad, v_ref.grad) + + # ----------------- assertions ---------------------------------------- + if check_correctness_by_row( + cfg.seq_len, final_out, out_ref, "forward output", ATOL=_ATOL, RTOL=_RTOL + ): + check_correctness_by_row( + cfg.seq_len, grads[0], ref_grads[0], "Q-grad", ATOL=_ATOL, RTOL=_RTOL + ) + check_correctness_by_row( + cfg.seq_len, grads[1], ref_grads[1], "K-grad", + ATOL=_ATOL, RTOL=_RTOL + ) + check_correctness_by_row( + cfg.seq_len, grads[2], ref_grads[2], "V-grad", + ATOL=_ATOL, RTOL=_RTOL + ) + + dist.destroy_process_group() + + +# ------------- pytest entry-point -------------------------------------------- +def test_xattention_kernels( + seq_len: int = 4096, + batch_sz: int = 1, + head_dim: int = 64, + ones: bool = True, + num_qo_heads: int = 2, + num_kv_heads: int = 2, + + stride: int = 16, + threshold: float = 0.9, +): + """ + Compare every sparse kernel against the dense Flash-Attention reference on + both forward pass and input-gradient w.r.t Q/K/V. + """ + port = str(random.randint(12000, 20000)) + xattn_params = { + "stride": stride, + "norm": 1, + "softmax": True, + "threshold": threshold, + "select_mode": "inverse", + "use_triton": True, + "causal": True, + "kdb": 1, + "keep_sink": False, + "keep_recent": False + } + cfg = SimpleNamespace( + batch_size=batch_sz, + seq_len=seq_len, + head_dim=head_dim, + ones=ones, + num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + xattn_params=xattn_params, + ) + + print(f"=" * 80) + print(f"Testing XAttention (w. Zigzag) with configuration:\n{cfg}") + print(f"=" * 80) + mp.spawn( + _run_worker, + args=(_WORLD_SIZE, port, cfg), + nprocs=_WORLD_SIZE, + join=True, + ) + +if __name__ == "__main__": + # Run the test with default parameters + test_xattention_kernels( + seq_len=512 * 1024, + batch_sz=1, + head_dim=64, + ones=False, + num_qo_heads=4, + num_kv_heads=1, + stride=16, + threshold=0.95, + ) \ No newline at end of file diff --git a/minference/dist_ops/utils.py b/minference/dist_ops/utils.py new file mode 100644 index 0000000..213e403 --- /dev/null +++ b/minference/dist_ops/utils.py @@ -0,0 +1,497 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
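The helpers below combine per-step partial attention results with an online-softmax update, out <- out - sigmoid(block_lse - lse) * (out - block_out) and lse <- lse - logsigmoid(lse - block_lse). A minimal self-contained sketch of that merge identity (illustrative tensor names and sizes, not part of this module):

import torch

# Compare a dense reference against merging two key/value halves with the same
# identity used by update_out_and_lse further down in this file.
torch.manual_seed(0)
q, k, v = torch.randn(4, 8), torch.randn(16, 8), torch.randn(16, 8)

def partial_attn(q, k, v):
    scores = q @ k.t()                                  # scale omitted for brevity
    return torch.softmax(scores, -1) @ v, torch.logsumexp(scores, -1)

ref, _ = partial_attn(q, k, v)                          # attention over all 16 keys
out0, lse0 = partial_attn(q, k[:8], v[:8])              # first half of the keys
out1, lse1 = partial_attn(q, k[8:], v[8:])              # second half of the keys

merged = out0 - torch.sigmoid(lse1 - lse0).unsqueeze(-1) * (out0 - out1)
assert torch.allclose(merged, ref, atol=1e-5)           # matches the dense result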
+ +# Credits: This logger implementation is inspired by project https://github.com/zhuzilin/ring-flash-attention +import os +import math +import torch +import inspect +import operator +import torch.nn.functional as F +import torch.distributed as dist + +import triton +import triton.language as tl + +from functools import reduce, cache +from typing import Optional, Tuple, List, Dict +from torch.distributed.distributed_c10d import P2POp + +PROCESS_GROUPS: Dict[str, dist.ProcessGroup] = {} + + +@cache +def _get_default_args(func): + spec = inspect.getfullargspec(func) + defaults = spec.defaults if spec.defaults is not None else () + padded_defaults = (None,) * (len(spec.args) - len(defaults)) + defaults + args = dict(zip(spec.args, padded_defaults)) + if "softcap" in args: + args["softcap"] = 0.0 + return args + + +def get_default_args(func): + if inspect.isfunction(func): + return _get_default_args(func) + else: + # Use the origin _init_fn in CustomOpDef + return _get_default_args(func._init_fn) + + +# copy from megatron/core/utils.py +class GlobalMemoryBuffer: + """Global buffer to avoid dynamic memory allocations. + Caller should ensure that buffers of the same name + are not used concurrently.""" + + def __init__(self): + self.buffer = {} + + def get_tensor(self, tensor_shape, dtype, name): + required_len = reduce(operator.mul, tensor_shape, 1) + if ( + self.buffer.get((name, dtype), None) is None + or self.buffer[(name, dtype)].numel() < required_len + ): + self.buffer[(name, dtype)] = torch.empty( + required_len, dtype=dtype, device=torch.cuda.current_device(), requires_grad=False + ) + + return self.buffer[(name, dtype)][0:required_len].view(*tensor_shape) + +@triton.jit +def _update_out_and_lse_kernel( + Out0, Lse0, Out1, Lse1, + stride_oz0, stride_om0, stride_oh0, stride_od0, + stride_lz0, stride_lm0, stride_lh0, + stride_oz1, stride_om1, stride_oh1, stride_od1, + stride_lz1, stride_lm1, stride_lh1, + num_tokens, + BLOCK_M: tl.constexpr, BLOCK_D: tl.constexpr, +): + start_m = tl.program_id(0) + head_idx = tl.program_id(1) + batch_idx = tl.program_id(2) + + if start_m * BLOCK_M >= num_tokens: + return + + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, BLOCK_D) + m_mask = offs_m < num_tokens + + o0_ptrs = Out0 + batch_idx * stride_oz0 + head_idx * stride_oh0 + offs_m[:, None] * stride_om0 + offs_d[None, :] * stride_od0 + o1_ptrs = Out1 + batch_idx * stride_oz1 + head_idx * stride_oh1 + offs_m[:, None] * stride_om1 + offs_d[None, :] * stride_od1 + lse0_ptrs = Lse0 + batch_idx * stride_lz0 + head_idx * stride_lh0 + offs_m * stride_lm0 + lse1_ptrs = Lse1 + batch_idx * stride_lz1 + head_idx * stride_lh1 + offs_m * stride_lm1 + + lse0 = tl.load(lse0_ptrs, mask=m_mask, other=float("-inf")) + lse1 = tl.load(lse1_ptrs, mask=m_mask, other=float("-inf")) + o0 = tl.load(o0_ptrs, mask=m_mask[:, None], other=0.).to(tl.float32) + o1 = tl.load(o1_ptrs, mask=m_mask[:, None], other=0.).to(tl.float32) + + m_mask &= (lse0 - lse1) < 88.0 + + theta = tl.math.exp(lse0 - lse1) + alpha0 = 1 / (1 + 1 / theta) + alpha1 = 1 / (1 + theta) + o = alpha0[:, None] * o0 + alpha1[:, None] * o1 + lse = lse1 - tl.math.log(alpha1) + + tl.store(o0_ptrs, o.to(Out0.type.element_ty), mask=m_mask[:, None]) + tl.store(lse0_ptrs, lse, mask=m_mask) + + +def _update_out_and_lse_triton( + out: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] + lse: torch.Tensor, # [batch_size, num_tokens, num_heads, 1] + block_out: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] + block_lse: 
torch.Tensor, # [batch_size, num_heads, num_tokens] => [batch_size, num_tokens, num_heads, 1] + step_idx: Optional[int] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) + batch_size, num_tokens, num_heads, head_dim = out.shape + block_M = 128 + block_D = head_dim + _update_out_and_lse_kernel[(triton.cdiv(num_tokens, block_M), num_heads, batch_size)]( + out, lse, block_out, block_lse, + out.stride(0), out.stride(1), out.stride(2), out.stride(3), + lse.stride(0), lse.stride(1), lse.stride(2), + block_out.stride(0), block_out.stride(1), block_out.stride(2), block_out.stride(3), + block_lse.stride(0), block_lse.stride(1), block_lse.stride(2), + num_tokens, BLOCK_M=block_M, BLOCK_D=block_D, + num_warps=4, num_stages=1, + ) + return out, lse + + +@torch.jit.script +def _update_out_and_lse_torch( + out: torch.Tensor, + lse: torch.Tensor, + block_out: torch.Tensor, + block_lse: torch.Tensor, + step_idx: Optional[int] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + + block_out = block_out.to(torch.float32) + block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) + + # new_lse = lse + torch.log(1 + torch.exp(block_lse - lse)) + # torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out + # For additional context and discussion, please refer to: + # https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795 + out = out - F.sigmoid(block_lse - lse) * (out - block_out) + lse = lse - F.logsigmoid(lse - block_lse) + + return out, lse + + +def update_out_and_lse( + out: Optional[torch.Tensor], + lse: Optional[torch.Tensor], + block_out: torch.Tensor, + block_lse: torch.Tensor, + slice_=None, + step_idx: Optional[int] = None, + use_triton_kernel: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + if use_triton_kernel: + _update_out_and_lse = _update_out_and_lse_triton + else: + _update_out_and_lse = _update_out_and_lse_torch + if out is None: + if slice_ is not None: + raise RuntimeError("first update_out_and_lse should not pass slice_ args") + out = block_out.to(torch.float32) + lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) + elif slice_ is not None: + slice_out, slice_lse = out[slice_], lse[slice_] + slice_out, slice_lse = _update_out_and_lse( + slice_out, slice_lse, block_out, block_lse, + step_idx=step_idx + ) + out[slice_], lse[slice_] = slice_out, slice_lse + else: + out, lse = _update_out_and_lse( + out, lse, block_out, block_lse, + step_idx=step_idx + ) + + return out, lse + + +class RingComm: + def __init__( + self, + process_group: dist.ProcessGroup, + zigzag: bool = False, + ring_list: Optional[list] = None, + ): + self._process_group = process_group + self._ops: List[P2POp] = [] + self.rank = dist.get_rank(self._process_group) + self.world_size = dist.get_world_size(self._process_group) + self._reqs = None + + if ring_list is not None: + curr_idx = ring_list.index(self.rank) + self.send_rank = ring_list[(curr_idx + 1) % len(ring_list)] + self.recv_rank = ring_list[(curr_idx - 1 + len(ring_list)) % len(ring_list)] + elif zigzag: + parts = self.world_size // 2 + self.ring_list = [] + for i in range(parts): + self.ring_list.extend([i, self.world_size - i - 1]) + self.revert_rank = self.ring_list.index(self.rank) + offset = ((dist.get_rank() // self.world_size) * self.world_size) + self.send_rank = self.ring_list[(self.revert_rank + 1) % self.world_size] + offset + self.recv_rank = self.ring_list[(self.revert_rank - 1) % self.world_size] + offset + else: + self.send_rank = 
(self.rank + 1) % self.world_size + self.recv_rank = (self.rank - 1) % self.world_size + + def send_recv( + self, + to_send: torch.Tensor, + recv_tensor: Optional[torch.Tensor] = None, + step_idx: int = 0, + fwd: int = 1, + ) -> torch.Tensor: + if recv_tensor is None: + res = torch.empty_like(to_send) + else: + res = recv_tensor + + send_op = dist.P2POp( + dist.isend, to_send, self.send_rank, group=self._process_group, + tag=2 * (step_idx * (self.rank + 1)) + fwd, + ) + recv_op = dist.P2POp( + dist.irecv, res, self.recv_rank, group=self._process_group, + tag=2 * (step_idx * (self.rank + 1)) + fwd, + ) + + self._ops.append(send_op) + self._ops.append(recv_op) + return res + + def commit(self): + if self._reqs is not None: + raise RuntimeError("commit called twice") + self._reqs = dist.batch_isend_irecv(self._ops) + + def wait(self): + if self._reqs is None: + raise RuntimeError("wait called before commit") + for req in self._reqs: + req.wait() + self._reqs = None + self._ops = [] + + + def send_recv_kv( + self, + k: torch.Tensor, + v: torch.Tensor, + k_buffer: Optional[torch.Tensor] = None, + v_buffer: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + next_k, next_v = self.send_recv(k, k_buffer), self.send_recv(v, v_buffer) + self.commit() + return next_k, next_v + + def send_recv_kv_offsets( + self, + k: torch.Tensor, + v: torch.Tensor, + kv_seq_offsets: torch.Tensor, + k_buffer: Optional[torch.Tensor] = None, + v_buffer: Optional[torch.Tensor] = None, + kv_seq_offsets_buffer: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + next_k, next_v = self.send_recv(k, k_buffer), self.send_recv(v, v_buffer) + next_kv_seq_offsets = self.send_recv(kv_seq_offsets, kv_seq_offsets_buffer) + + self.commit() + return next_k, next_v, next_kv_seq_offsets + +def shuffle_zigzag_input(to_send: torch.Tensor, + dim: int = 1, + process_group: dist.ProcessGroup = None): + dim %= len(to_send.shape) + + if not to_send.is_contiguous(): + to_send = to_send.contiguous() + + # We must use outplace, otherwise it will raise error at backward due to inplace operations. + # We can not change to_send directly and create a new tensor to store the result. 
+ to_send_f = torch.zeros_like(to_send) + + # assume the input sequence length is 8, and computation runs on 4 GPUs + # the seq is represented as [0 1 2 3 4 5 6 7], world size is 4 + # the input status before `shuffle_zigzag_input` is + # - gpu A: [0 1] + # - gpu B: [2 3] + # - gpu C: [4 5] + # - gpu D: [6 7] + # the value of `to_send_slice` is + # - gpu A: [1] + # - gpu B: [3] + # - gpu C: [5] + # - gpu D: [7] + block_seq_len = to_send.shape[dim] // 2 + left_slicer = [slice(None)] * dim + [slice(None, block_seq_len)] + right_slicer = [slice(None)] * dim + [slice(block_seq_len, None)] + to_send_slice = to_send[right_slicer].contiguous() + + rank = dist.get_rank(process_group) + world_size = dist.get_world_size(process_group) + + res = torch.zeros_like(to_send_slice) + + _ops = [] + offset = ((dist.get_rank() // world_size) * world_size) + # rank src_rank + # 0 3 + # 1 2 + # 2 1 + # 3 0 + src_rank = (world_size - rank - 1) % world_size + offset + send_op = dist.P2POp( + dist.isend, to_send_slice, src_rank, group=process_group + ) + recv_op = dist.P2POp( + dist.irecv, res, src_rank, group=process_group) + + _ops.append(send_op) + _ops.append(recv_op) + + response = dist.batch_isend_irecv(_ops) + for resp in response: + resp.wait() + + if rank >= world_size // 2: # D: 6 7, -> 1 6 + to_send_f[right_slicer] = to_send[left_slicer] + to_send_f[left_slicer] = res + else: # A: 0 1, -> 0 7 + to_send_f[left_slicer] = to_send[left_slicer] + to_send_f[right_slicer] = res + # after shuffle, the status of `to_send_f` + # GPU A: [0 7] + # GPU B: [2 5] + # GPU C: [3 4] + # GPU D: [1 6] + + return to_send_f + + +def recover_zigzag_output(to_send: torch.Tensor, + dim: int = 1, + process_group: dist.ProcessGroup = None): + dim %= len(to_send.shape) + + if not to_send.is_contiguous(): + to_send = to_send.contiguous() + + to_send_f = torch.zeros_like(to_send) + + block_seq_len = to_send.shape[dim] // 2 + left_slicer = [slice(None)] * dim + [slice(None, block_seq_len)] + right_slicer = [slice(None)] * dim + [slice(block_seq_len, None)] + + rank = dist.get_rank(process_group) + world_size = dist.get_world_size(process_group) + + if rank >= world_size // 2: + to_send_slice = to_send[left_slicer].contiguous() + else: + to_send_slice = to_send[right_slicer].contiguous() + res = torch.zeros_like(to_send_slice) + + assert to_send_slice.is_contiguous() + assert res.is_contiguous() + + _ops = [] + offset = ((dist.get_rank() // world_size) * world_size) + src_rank = (world_size - rank - 1) % world_size + offset + send_op = dist.P2POp( + dist.isend, to_send_slice, src_rank, group=process_group + ) + recv_op = dist.P2POp( + dist.irecv, res, src_rank, group=process_group) + + _ops.append(send_op) + _ops.append(recv_op) + + response = dist.batch_isend_irecv(_ops) + for resp in response: + resp.wait() + + if rank >= world_size // 2: + to_send_f[left_slicer] = to_send[right_slicer] + to_send_f[right_slicer] = res + else: + to_send_f[left_slicer] = to_send[left_slicer] + to_send_f[right_slicer] = res + + return to_send_f.contiguous() + + +def shuffle_block_mask_zigzag( + block_mask: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks, num_blocks] + num_blocks_per_chunk: int, + group: dist.ProcessGroup, +): + rank = dist.get_rank(group) + world_size = dist.get_world_size(group) + + # --------------------------------------- + # Shuffle Query chunks + block_mask = shuffle_zigzag_input(to_send=block_mask, dim=-2, process_group=group) # [batch_size, num_qo_heads, num_blocks_local, num_blocks] + + # 
--------------------------------------- + # Shuffle Key chunks + ring_list = RingComm(group, zigzag=True).ring_list + ring_index = ring_list.index(rank) + + shuffled_block_mask_list = [] + for i in range(world_size): + rank_src = ring_list[(ring_index - i) % world_size] + + curr_chunk_index = 2 * rank_src + rev_chunk_index = (2 * world_size - 1 - curr_chunk_index) + if curr_chunk_index > rev_chunk_index: + curr_chunk_index, rev_chunk_index = rev_chunk_index, curr_chunk_index + + shuffled_block_mask_list.append( + torch.cat( + [ + block_mask[..., curr_chunk_index * num_blocks_per_chunk : (curr_chunk_index + 1) * num_blocks_per_chunk], + block_mask[..., rev_chunk_index * num_blocks_per_chunk : (rev_chunk_index + 1) * num_blocks_per_chunk] + ], dim=-1 + ) + ) + block_mask = torch.stack(shuffled_block_mask_list, dim=0).contiguous() # [world_size, batch_size, num_qo_heads, num_blocks_local, num_blocks_local] + return block_mask + + +def shuffle_striped_input(to_send: torch.Tensor, # [B, N / W, H, D] + granularity: int = 1, + dim: int = 1, + process_group: dist.ProcessGroup = None): + # 00, 01, 02, 03, 04, 05, 06, 07 => 00, 04, 08, 12, 16, 20, 24, 28 + # 08, 09, 10, 11, 12, 13, 14, 15 => 01, 05, 09, 13, 17, 21, 25, 29 + # 16, 17, 18, 19, 20, 21, 22, 23 => 02, 06, 10, 14, 18, 22, 26, 30 + # 24, 25, 26, 27, 28, 39, 30, 31 => 03, 07, 11, 15, 19, 23, 27, 31 + shape = to_send.shape + dim %= len(shape) + world_size = dist.get_world_size(process_group) + input_reshape = to_send.reshape((*shape[:dim], -1, world_size * granularity, *shape[dim+1:])) + input_list = [x.contiguous() for x in input_reshape.split(granularity, dim=dim+1)] # [N / W / (W * G), W*, G] + output_list = [torch.empty_like(x) for x in input_list] # [W*, N / W / (W * G), G] + + + dist.all_to_all(output_list, input_list, group=process_group) + return torch.stack(output_list, dim=dim).reshape(shape).contiguous() + + +def recover_striped_output(to_send: torch.Tensor, # [B, N / W, H, D] + granularity: int = 1, + dim: int = 1, + process_group: dist.ProcessGroup = None): + # 00, 04, 08, 12, 16, 20, 24, 28 => 00, 01, 02, 03, 04, 05, 06, 07 + # 01, 05, 09, 13, 17, 21, 25, 29 => 08, 09, 10, 11, 12, 13, 14, 15 + # 02, 06, 10, 14, 18, 22, 26, 30 => 16, 17, 18, 19, 20, 21, 22, 23 + # 03, 07, 11, 15, 19, 23, 27, 31 => 24, 25, 26, 27, 28, 39, 30, 31 + shape = to_send.shape + dim %= len(shape) + world_size = dist.get_world_size(process_group) + + input_reshape = to_send.reshape((*shape[:dim], world_size, -1, granularity, *shape[dim+1:])) + input_list = [x.squeeze(dim).contiguous() for x in input_reshape.split(1, dim=dim)] # [W*, N / W / (W * G), G] + output_list = [torch.empty_like(x) for x in input_list] # [N / W / (W * G), W*, G] + + dist.all_to_all(output_list, input_list, group=process_group) + return torch.stack(output_list, dim=dim+1).reshape(shape).contiguous() + +# -------------------------------------------------------------------- +# Double-Ring Related +def get_inner_ring(group: dist.ProcessGroup): + rank = dist.get_rank(group) + local_rank = int(os.environ["LOCAL_RANK"]) + local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE")) + assert rank % local_world_size == local_rank + return [i + (rank - local_rank) for i in range(local_world_size)] + + +def get_outer_ring(group: dist.ProcessGroup): + rank = dist.get_rank(group) + world_size = dist.get_world_size(group) + local_rank = int(os.environ["LOCAL_RANK"]) + local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE")) + assert rank % local_world_size == local_rank + return [i * 
local_world_size + local_rank for i in range(world_size // local_world_size)] + + diff --git a/minference/dist_ops/xattn_zigzag.py b/minference/dist_ops/xattn_zigzag.py new file mode 100644 index 0000000..5dfb641 --- /dev/null +++ b/minference/dist_ops/xattn_zigzag.py @@ -0,0 +1,510 @@ +import os +import math +import torch +import triton +import torch.distributed as dist +from typing import List, Tuple, Dict, Any, Optional + +from .utils import ( + RingComm, update_out_and_lse, + shuffle_zigzag_input, recover_zigzag_output, + shuffle_block_mask_zigzag, +) + +from minference.ops.utils import use_triton +from minference.ops.op_utils.xattn_utils import LN2, find_blocks_chunked +from minference.ops.op_utils.vertical_slash_utils import convert_blockmask +from minference.ops.xattention_fa import flat_group_gemm_fuse_reshape, softmax_fuse_block_sum +from minference.ops.pit_sparse_flash_attention_v3 import block_attn_fwd, block_attn_bwd, triton_block_attn_fwd, triton_block_attn_bwd + + +def xattn_zigzag_estimate( + query_states: torch.Tensor, # (batch_size, num_q_head, q_len, head_dim) + key_states: torch.Tensor, # (batch_size, num_kv_head, k_len, head_dim) + block_size, + stride, + norm=1, + softmax=True, + threshold=0.9, + select_mode="inverse", + use_triton=True, + causal=True, + kdb: int = 1, + keep_sink=False, + keep_recent=False, + group: dist.group = None, +) -> torch.Tensor: + batch_size, num_kv_head, k_len_local, head_dim = key_states.shape + batch_size, num_q_head, q_len_local, head_dim = query_states.shape + + rank = dist.get_rank(group) + world_size = dist.get_world_size(group) + + k_gather_list = [torch.empty_like(key_states) for _ in range(world_size)] + dist.all_gather(k_gather_list, key_states.contiguous(), group=group) + k_gathered = torch.cat(k_gather_list, dim=2) + k_len = k_gathered.shape[2] + + if num_q_head > num_kv_head: + k_gathered = torch.repeat_interleave(k_gathered.contiguous(), num_q_head // num_kv_head, dim=1) + + chunk_size = q_len_local // 2 + q_chunk_num = 2 + q_block_num = q_len_local // block_size + q_block_num_per_chunk = chunk_size // block_size + + # assert num_kv_head == num_q_head + attn_sum_list = [] + simple_mask_list = [] + + num_strides_in_k = k_len // stride + num_strides_per_chunk = chunk_size // stride + num_strides_per_block = block_size // stride + num_blocks_per_chunk = num_strides_per_chunk // num_strides_per_block + + attn_weight_slices = [None, None] + for chunk_idx in range(q_chunk_num): + global_chunk_idx = rank * 2 + chunk_idx + + # Local start index + q_chunk_start = chunk_idx * chunk_size + q_chunk_end = (chunk_idx + 1) * chunk_size + + # Global start index (stride-level) + q_chunk_start_stride_global = global_chunk_idx * num_strides_per_chunk + q_chunk_end_stride_global = (global_chunk_idx + 1) * num_strides_per_chunk + + # attn_weights_slice: (batch_size, num_heads, chunk_size // stride, kv_len // stride) + # (i.e. 
the attention sum of each SxS stride block) + # This step is agnostic to block size and just computes the attention sum in each stride block + attn_weight_slice = flat_group_gemm_fuse_reshape( + # query_states, key_states, stride, chunk_start, chunk_end, is_causal=True + query_states[:, :, q_chunk_start : q_chunk_end, :,], + k_gathered, + stride, + q_chunk_start_stride_global, q_chunk_end_stride_global, + is_causal=causal, + ) + attn_weight_slices[chunk_idx] = attn_weight_slice + del k_gathered, k_gather_list + + for chunk_idx in range(q_chunk_num): + global_chunk_idx = rank * 2 + chunk_idx + + # Local start index + q_chunk_start = chunk_idx * chunk_size + q_chunk_end = (chunk_idx + 1) * chunk_size + + # Global start index (block-level) + q_block_start = global_chunk_idx * q_block_num_per_chunk + q_block_end = (global_chunk_idx + 1) * q_block_num_per_chunk + + # Global start index (stride-level) + q_chunk_start_stride_global = global_chunk_idx * num_strides_per_chunk + q_chunk_end_stride_global = (global_chunk_idx + 1) * num_strides_per_chunk + + attn_weight_slice = attn_weight_slices[chunk_idx] + + # (batch_size, num_heads, q_block_num, k_block_num), + attn_sum = softmax_fuse_block_sum( + attn_weight_slice, # (batch_size, num_heads, chunk_size // stride, kv_len // stride) + num_strides_per_block, + min(4096, num_strides_per_block), + q_chunk_start_stride_global, q_chunk_end_stride_global, + num_strides_in_k, + 1 / LN2 / math.sqrt(head_dim) / stride / norm, + is_causal=causal, + ) + + # (batch_size, head_num, num_blocks_per_chunk, block_num) + simple_mask = find_blocks_chunked( + attn_sum, + global_chunk_idx * num_blocks_per_chunk, + threshold, + None, + decoding=False, + mode="prefill", + causal=causal, + ) + + del attn_weight_slice + if causal: + simple_mask[:, :, :, q_block_start:q_block_end] = torch.where( + torch.tril( + torch.ones( + q_block_num_per_chunk, q_block_num_per_chunk, + dtype=bool, device=key_states.device + ), + diagonal=0, + ), + simple_mask[:, :, :, q_block_start:q_block_end], + False, + ) + simple_mask[:, :, :, q_block_end:] = 0 + if keep_sink: + simple_mask[:, :, 0, :] = True + if keep_recent: + eye_matrix = torch.eye(q_block_num_per_chunk, device=simple_mask.device, dtype=bool) + eye_matrix_expanded = ( + eye_matrix.unsqueeze(0) + .unsqueeze(0) + .expand(1, num_kv_head, q_block_num_per_chunk, q_block_num_per_chunk) + ) + simple_mask[:, :, :, q_block_start:q_block_end] = torch.where( + eye_matrix_expanded, True, simple_mask[:, :, :, q_block_start:q_block_end] + ) + + attn_sum_list.append(attn_sum) + simple_mask_list.append(simple_mask) + + attn_sums = torch.cat(attn_sum_list, dim=-2) + simple_masks = torch.cat(simple_mask_list, dim=-2) # (batch_size, head_num, q_local_block_num, k_global_block_num) + return attn_sums, simple_masks + +def xattn_zigzag_forward( + process_group: dist.ProcessGroup, + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + block_mask: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks, num_blocks] + layer_idx: int, + softmax_scale: float, + granularity: int = 128, + block_idx: Optional[torch.Tensor] = None, + block_cnt: Optional[torch.Tensor] = None, +): + comm = RingComm(process_group, zigzag=True) + out, lse = None, None + next_k, next_v = None, None + + for step in range(comm.world_size): + if step + 1 != comm.world_size: + next_k, next_v = comm.send_recv_kv(k, v) + + # 
[batch_size, num_qo_heads, num_blocks_local, num_blocks_local] + block_mask_step = block_mask[step] + block_causal = step == 0 + + if use_triton(): + # TODO: block_mask here needs to be converted to block_idx before passing to triton + block_out, block_lse = triton_block_attn_fwd( + q, k, v, + block_idx=block_idx[step], block_cnt=block_cnt[step], + softmax_scale=softmax_scale, + granularity=granularity, + causal=block_causal, + step=step, + ) + else: + block_out, block_lse = block_attn_fwd( + q, k, v, + block_mask=block_mask_step, + softmax_scale=softmax_scale, + granularity=granularity, + causal=block_causal, + step_idx=step, + ) + + out, lse = update_out_and_lse(out, lse, block_out, block_lse, use_triton_kernel=False) + if step + 1 != comm.world_size: + comm.wait() + k, v = next_k, next_v + + out = out.to(q.dtype) + lse = lse.squeeze(dim=-1).transpose(1, 2) + return out, lse + +def xattn_zigzag_backward( + process_group: dist.ProcessGroup, + dout: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + out: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + softmax_lse: torch.Tensor, # [batch_size, num_qo_heads, num_tokens] + layer_idx: int, + softmax_scale: float, + block_mask: torch.Tensor, # [world_size, batch_size, num_qo_heads, num_blocks, num_blocks] + granularity: int = 128, + block_idx: Optional[torch.Tensor] = None, # [world_size, batch_size, num_qo_heads, num_blocks_local, num_blocks] + block_cnt: Optional[torch.Tensor] = None, # [world_size, batch_size, num_qo_heads, num_blocks_local] +): + kv_comm = RingComm(process_group, zigzag=True) + d_kv_comm = RingComm(process_group, zigzag=True) + + dq, dk, dv = None, None, None + next_dk, next_dv = None, None + next_k, next_v = None, None + dk_comm_buffer, dv_comm_buffer = None, None + + for step in range(kv_comm.world_size): + if step + 1 != kv_comm.world_size: + next_k, next_v = kv_comm.send_recv_kv(k, v) + + block_causal = step == 0 + block_mask_step = block_mask[step] + + # -------------------------------- + # Block Mask + if use_triton(): + step_dq, step_dk, step_dv = triton_block_attn_bwd( + dout, q, k, v, out, + softmax_lse, softmax_scale, + block_idx[step], block_cnt[step], + granularity=granularity, + deterministic=False, + causal=block_causal, + step=step, + ) + else: + step_dq, step_dk, step_dv = block_attn_bwd( + dout, q, k, v, out, + softmax_lse, softmax_scale, + block_mask_step, + granularity=granularity, + deterministic=False, + causal=block_causal, + ) + + # Update dQ, dK, dV + if step == 0: + # TODO: check if float32 is necessary + dq = step_dq.to(torch.float32) + dk = step_dk.to(torch.float32) + dv = step_dv.to(torch.float32) + else: + d_kv_comm.wait() + dk_comm_buffer, dv_comm_buffer = dk, dv + dk, dv = next_dk, next_dv + + dq += step_dq + dk += step_dk + dv += step_dv + + if step + 1 != kv_comm.world_size: + kv_comm.wait() + k, v = next_k, next_v + next_dk, next_dv = d_kv_comm.send_recv_kv( + dk, dv, dk_comm_buffer, dv_comm_buffer + ) + + d_kv_comm.wait() + return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) + +class XAttnZigzagFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer_idx, + xattn_params, # Dict[str, Any] + granularity, + causal, + softmax_scale, + return_softmax, + 
deterministic, + group, + ): + if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) + + # ---------------------------------------------- + # Index Building + # block_mask [batch_size, num_qo_heads, num_blocks_local, num_blocks] + _, block_mask = xattn_zigzag_estimate( + q.transpose(1, 2), k.transpose(1, 2), + block_size=granularity, + **xattn_params + ) + + # ------------------------------------------------------------------ + # QKV Shuffling + q = shuffle_zigzag_input(to_send=q, dim=1, process_group=group) + k = shuffle_zigzag_input(to_send=k, dim=1, process_group=group) + v = shuffle_zigzag_input(to_send=v, dim=1, process_group=group) + + # ------------------------------------------------------------------ + # Index Shuffling + block_mask = shuffle_block_mask_zigzag( + block_mask, num_blocks_per_chunk=q.shape[1] // 2 // granularity, + group=group + ).to(q.device) + if use_triton(): + block_idx, block_cnt = convert_blockmask(block_mask, block_size_M=granularity, block_size_N=64) + else: + block_idx, block_cnt = None, None + block_mask = block_mask.contiguous() + + # ---------------------------------------------- + # Compute + out, softmax_lse = xattn_zigzag_forward( + group, + q, k, v, + block_mask, + layer_idx, + softmax_scale, + granularity=granularity, + block_idx=block_idx, block_cnt=block_cnt, + ) + + # ---------------------------------------------- + # Recover outputs + recovered_out = recover_zigzag_output(out, dim=1, process_group=group) + if return_softmax: + recovered_softmax_lse = recover_zigzag_output(softmax_lse, dim=2, process_group=group) + + # ------------------------------- + # Variale Saving + if use_triton(): + ctx.save_for_backward(q, k, v, out, softmax_lse, block_mask, block_idx, block_cnt) + else: + ctx.save_for_backward(q, k, v, out, softmax_lse, block_mask) + ctx.softmax_scale = softmax_scale + ctx.granularity = granularity + ctx.group = group + ctx.layer_idx = layer_idx + + # ------------------------------- + # Recover outputs + if return_softmax: + return (recovered_out, recovered_softmax_lse, None) + return recovered_out + + @staticmethod + def backward(ctx, dout, *args): + if use_triton(): + q, k, v, out, softmax_lse, block_mask, block_idx, block_cnt = ctx.saved_tensors + else: + q, k, v, out, softmax_lse, block_mask = ctx.saved_tensors + block_idx, block_cnt = None, None + softmax_scale = ctx.softmax_scale + granularity = ctx.granularity + layer_idx = ctx.layer_idx + group = ctx.group + + + dout = shuffle_zigzag_input(to_send=dout, dim=1, process_group=group) + + # ---------------------------------------------- + # Compute + dq, dk, dv = xattn_zigzag_backward( + group, + dout, q, k, v, + out, softmax_lse, + layer_idx, + softmax_scale, + block_mask, + granularity, + block_idx=block_idx, block_cnt=block_cnt, + ) + + dq = recover_zigzag_output(dq, dim=1, process_group=group) + dk = recover_zigzag_output(dk, dim=1, process_group=group) + dv = recover_zigzag_output(dv, dim=1, process_group=group) + return dq, dk, dv, None, None, None, None, None, None, None, None, None + + +def xattn_zigzag_qkvpacked_func( + qkv: torch.Tensor, # [batch_size, num_tokens, 3, num_heads, head_dim] + layer_idx: int, + xattn_params: Dict[str, Any], + granularity: int = 128, + dropout_p: int = 0.0, + softmax_scale: float = None, + causal: bool = True, + window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window + alibi_slopes: Tuple[float, float] = None, + deterministic: bool = False, + return_attn_probs: bool = False, + group: dist.ProcessGroup = None, +): + 
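    # qkv packs Q, K and V along dim 2: qkv[:, :, 0] is Q, qkv[:, :, 1] is K and
    # qkv[:, :, 2] is V, each of shape [batch_size, num_tokens, num_heads, head_dim].
    # Only the configuration enforced by the asserts below is supported: causal
    # attention, no dropout, full attention window, no ALiBi slopes and a
    # non-deterministic backward pass.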
assert causal + assert dropout_p == 0 + assert window_size == (-1, -1) + assert alibi_slopes is None + assert not deterministic + return XAttnZigzagFunc.apply( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + layer_idx, + xattn_params, + granularity, + causal, + softmax_scale, + return_attn_probs, + deterministic, + group, + ) + + +def xattn_zigzag_kvpacked_func( + q: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] + kv: torch.Tensor, # [batch_size, num_tokens, 2, num_heads, head_dim] + layer_idx: int, + xattn_params: Dict[str, Any], + granularity: int = 128, + dropout_p: int = 0.0, + softmax_scale: float = None, + causal: bool = True, + window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window + alibi_slopes: Tuple[float, float] = None, + deterministic: bool = False, + return_attn_probs: bool = False, + group: dist.ProcessGroup = None, +): + assert causal + assert dropout_p == 0 + assert window_size == (-1, -1) + assert alibi_slopes is None + assert not deterministic + + return XAttnZigzagFunc.apply( + q, + kv[:, :, 0], + kv[:, :, 1], + layer_idx, + xattn_params, + granularity, + causal, + softmax_scale, + return_attn_probs, + deterministic, + group, + ) + + +def xattn_zigzag_func( # the one used for nnscaler training + q: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_heads, head_dim] + layer_idx: int, + xattn_params: Dict[str, Any], + granularity: int = 128, + dropout_p: int = 0.0, + softmax_scale: float = None, + causal: bool = True, + window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window + alibi_slopes: Tuple[float, float] = None, + deterministic: bool = False, + return_attn_probs: bool = False, + group: dist.ProcessGroup = None, +) -> torch.Tensor: + assert causal + assert dropout_p == 0 + assert window_size == (-1, -1) + assert alibi_slopes is None + assert not deterministic + + return XAttnZigzagFunc.apply( + q, k, v, + layer_idx, + xattn_params, + granularity, + causal, + softmax_scale, + return_attn_probs, + deterministic, + group, + ) diff --git a/minference/dist_ops/zigzag_attention.py b/minference/dist_ops/zigzag_attention.py new file mode 100644 index 0000000..a5e4ab2 --- /dev/null +++ b/minference/dist_ops/zigzag_attention.py @@ -0,0 +1,412 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
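This module runs ring flash-attention over the zigzag-shuffled layout produced by shuffle_zigzag_input. A pure-Python sketch of that layout (zigzag_layout is a hypothetical helper, no communication involved; chunk indices run from 0 to 2 * world_size - 1), reproducing the [0 7 2 5 3 4 1 6] ordering referenced in the module docstring below:

def zigzag_layout(world_size):
    # Rank r starts with chunks [2r, 2r + 1] and swaps one half-chunk with its
    # mirrored partner rank (world_size - 1 - r), mirroring shuffle_zigzag_input.
    layout = []
    for rank in range(world_size):
        partner_right = 2 * (world_size - 1 - rank) + 1
        if rank < world_size // 2:
            layout.append([2 * rank, partner_right])     # keep the left chunk in place
        else:
            layout.append([partner_right, 2 * rank])     # own left chunk moves to the right slot
    return layout

print(zigzag_layout(4))  # [[0, 7], [2, 5], [3, 4], [1, 6]]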
+ +# Credits: This logger implementation is inspired by project https://github.com/zhuzilin/ring-flash-attention +import os +import copy +import torch +import torch.distributed as dist + +from time import perf_counter +from typing import List, Tuple, Dict +from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward + +from .utils import ( + RingComm, update_out_and_lse, shuffle_zigzag_input, + recover_zigzag_output, get_default_args +) + +def zigzag_ring_flash_attn_forward( + process_group, + q: torch.Tensor, # [B, S, H, D] + k: torch.Tensor, + v: torch.Tensor, + layer_idx: int, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + assert causal == True, "zigzag ring is meaningless for causal=False" + comm = RingComm(process_group, zigzag=True) + + bsz, seq_len, num_heads, head_dim = q.shape + block_seq_len = q.shape[1] // 2 + q1 = q[:, block_seq_len:] + + out = None + lse = None + next_k, next_v = None, None + + def forward(q, k, v, causal): + params = get_default_args(_flash_attn_forward).copy() + params.update( + { + "q": q, + "k": k, + "v": v, + "dropout_p": dropout_p, + "softmax_scale": softmax_scale, + "causal": causal, + "alibi_slopes": alibi_slopes, + "return_softmax": True and dropout_p > 0, + } + ) + if "window_size" in params: + params.update({"window_size": window_size}) + else: + params.update( + { + "window_size_left": window_size[0], + "window_size_right": window_size[1], + } + ) + outputs = _flash_attn_forward(**params) + if len(outputs) == 8: + block_out, _, _, _, _, block_lse, _, _ = outputs + else: + assert len(outputs) == 4 + block_out, block_lse, _, _ = outputs + return block_out, block_lse + + + for step in range(comm.world_size): + if step + 1 != comm.world_size: + next_k, next_v = comm.send_recv_kv(k, v) + + if step == 0: + # Do softmax(QK^T / sqrt(d_k))V on the currently hold K and V + # and record the output and the LSE + block_out, block_lse = forward(q, k, v, causal=True) + out, lse = update_out_and_lse(out, lse, block_out, block_lse) + elif step <= comm.revert_rank: + k0 = k[:, :block_seq_len] + v0 = v[:, :block_seq_len] + block_out, block_lse = forward(q, k0, v0, causal=False) + out, lse = update_out_and_lse(out, lse, block_out, block_lse) + else: + block_out, block_lse = forward(q1, k, v, causal=False) + out, lse = update_out_and_lse( + out, lse, + block_out, + block_lse, + slice_=(slice(None), slice(block_seq_len, None)), + ) + + if step + 1 != comm.world_size: + comm.wait() + k, v = next_k, next_v + + out = out.to(q.dtype) + lse = lse.squeeze(dim=-1).transpose(1, 2) + return out, lse + +def zigzag_ring_flash_attn_backward( + process_group, + dout, + q, k, v, out, + layer_idx: int, + softmax_lse, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + assert causal == True, "zigzag ring is meaningless for causal=False" + bsz, seq_len, num_heads, head_dim = q.shape + + kv_comm = RingComm(process_group, zigzag=True) + d_kv_comm = RingComm(process_group, zigzag=True) + + dq, dk, dv = None, None, None + next_dk, next_dv = None, None + next_k, next_v = None, None + dk_comm_buffer, dv_comm_buffer = None, None + + dout1 = dout.chunk(2, dim=1)[1] + q1 = q.chunk(2, dim=1)[1] + out1 = out.chunk(2, dim=1)[1] + softmax_lse1 = softmax_lse.chunk(2, dim=2)[1].contiguous() + block_seq_len = q.shape[1] // 2 + + # repeatly allocating buffer may be slow... 
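    # The three scratch buffers below are therefore allocated once per backward
    # call and re-sliced per ring step (e.g. dq_buffer[:, :seqlen_q]) instead of
    # being re-allocated.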
+ dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) + dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) + dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) + + def backward(dout_, q_, k_, v_, out_, softmax_lse_, causal_): + seqlen_q = q_.shape[1] + seqlen_kv = k_.shape[1] + params = get_default_args(_flash_attn_backward).copy() + params.update( + { + "dout": dout_, + "q": q_, + "k": k_, + "v": v_, + "out": out_, + "softmax_lse": softmax_lse_, + "dq": dq_buffer[:, :seqlen_q], + "dk": dk_buffer[:, :seqlen_kv], + "dv": dv_buffer[:, :seqlen_kv], + "dropout_p": dropout_p, + "softmax_scale": softmax_scale, + "causal": causal_, + "alibi_slopes": alibi_slopes, + "deterministic": deterministic, + } + ) + if "window_size" in params: + params.update({"window_size": window_size}) + else: + params.update( + { + "window_size_left": window_size[0], + "window_size_right": window_size[1], + } + ) + params.update({"rng_state": torch.zeros((2, ), dtype=torch.int64, device=q.device)}) + + _flash_attn_backward(**params) + + for step in range(kv_comm.world_size): + if step + 1 != kv_comm.world_size: + next_k, next_v = kv_comm.send_recv_kv(k, v) + + # ----------------------------------------------------------- + if step == 0: + backward(dout, q, k, v, out, softmax_lse, causal_=True) + dq = dq_buffer.to(torch.float32) + dk = dk_buffer.to(torch.float32) + dv = dv_buffer.to(torch.float32) + else: + if step <= kv_comm.revert_rank: + k0 = k[:, :block_seq_len] + v0 = v[:, :block_seq_len] + backward(dout, q, k0, v0, out, softmax_lse, causal_=False) + else: + backward(dout1, q1, k, v, out1, softmax_lse1, causal_=False) + + d_kv_comm.wait() + dk_comm_buffer, dv_comm_buffer = dk, dv + dk, dv = next_dk, next_dv + + if step <= kv_comm.revert_rank: + dq += dq_buffer + dk[:, :block_seq_len] += dk_buffer[:, :block_seq_len] + dv[:, :block_seq_len] += dv_buffer[:, :block_seq_len] + else: + dq[:, block_seq_len:] += dq_buffer[:, :block_seq_len] + dk += dk_buffer + dv += dv_buffer + + # ----------------------------------------------------------- + if step + 1 != kv_comm.world_size: + kv_comm.wait() + k, v = next_k, next_v + + next_dk, next_dv = d_kv_comm.send_recv_kv( + dk, dv, dk_comm_buffer, dv_comm_buffer, + ) + + d_kv_comm.wait() + return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) + +''' +In nnscaler, sequence are stored in the initial order, e.g., [0 1 2 3 4 5 6 7]. +However, zigzag ring flash attention requires the sequence to be in the order of [0 7 2 5 3 4 1 6]. 
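With 4 ranks that means rank 0 holds chunks [0 7], rank 1 holds [2 5], rank 2 holds [3 4] and rank 3 holds [1 6].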
+As a result: +- in forward, we need to shuffle q, k, v and recover the out +- in backward, we need to shuffle dout and recover the dq, dk, dv +''' +class ZigZagRingFlashAttnFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, k, v, + layer_idx, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_softmax, + group, + ): + assert alibi_slopes is None + if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) + + # ---------------------------------------------- + # Shuffle + q = shuffle_zigzag_input(to_send=q, dim=1, process_group=group) + k = shuffle_zigzag_input(to_send=k, dim=1, process_group=group) + v = shuffle_zigzag_input(to_send=v, dim=1, process_group=group) + k, v = k.contiguous(), v.contiguous() + + # ---------------------------------------------- + # Compute + out, softmax_lse = zigzag_ring_flash_attn_forward( + group, + q, k, v, + layer_idx, + softmax_scale=softmax_scale, + dropout_p=dropout_p, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=False, + ) + + # ---------------------------------------------- + # Recover outputs + recovered_out = recover_zigzag_output(out, dim=1, process_group=group) + if return_softmax: + recovered_softmax_lse = recover_zigzag_output(softmax_lse, dim=2, process_group=group) + + # ------------------------------ + # Saving tensors + ctx.save_for_backward(q, k, v, out, softmax_lse) + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + ctx.group = group + ctx.layer_idx = layer_idx + ctx.return_softmax = return_softmax + + # ---------------------------------------------- + # Output and return + if return_softmax: + return (recovered_out, recovered_softmax_lse, None) + return recovered_out + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse = ctx.saved_tensors + layer_idx = ctx.layer_idx + dropout_p, softmax_scale, causal, window_size, alibi_slopes, deterministic, return_softmax, group = ( + ctx.dropout_p, ctx.softmax_scale, ctx.causal, ctx.window_size, + ctx.alibi_slopes, ctx.deterministic, ctx.return_softmax, + ctx.group + ) + + # ---------------------------------------------- + # Shuffle + dout = shuffle_zigzag_input(to_send=dout, dim=1, process_group=group) + + # ---------------------------------------------- + # Compute + dq, dk, dv = zigzag_ring_flash_attn_backward( + group, + dout, + q, k, v, out, + layer_idx, + softmax_lse, + softmax_scale=softmax_scale, + dropout_p=dropout_p, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + ) + + + # ---------------------------------------------- + # Recover + dq = recover_zigzag_output(dq, dim=1, process_group=group) + dk = recover_zigzag_output(dk, dim=1, process_group=group) + dv = recover_zigzag_output(dv, dim=1, process_group=group) + + return dq, dk, dv, None, None, None, None, None, None, None, None, None + + +def zigzag_ring_flash_attn_qkvpacked_func( + qkv, + layer_idx, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return ZigZagRingFlashAttnFunc.apply( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + layer_idx, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def 
zigzag_ring_flash_attn_kvpacked_func( + q, + kv, + layer_idx, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return ZigZagRingFlashAttnFunc.apply( + q, + kv[:, :, 0], + kv[:, :, 1], + layer_idx, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def zigzag_ring_flash_attn_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer_idx, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return ZigZagRingFlashAttnFunc.apply( + q, k, v, + layer_idx, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) diff --git a/minference/ops/moba.py b/minference/ops/moba.py new file mode 100644 index 0000000..d9d860a --- /dev/null +++ b/minference/ops/moba.py @@ -0,0 +1,630 @@ +"""A clean version of moba implementation for educational purposes""" +import math +import torch + +from einops import rearrange +from typing import Union, Tuple, Callable, Optional +from flash_attn import flash_attn_varlen_func, flash_attn_func +from flash_attn.flash_attn_interface import ( + _flash_attn_varlen_forward, + _flash_attn_varlen_backward, +) + +from .op_utils.moba_utils import calc_chunks + +def hf_to_fa(x: torch.Tensor): + """ + Args: + x (torch.Tensor): [batch, heads, seqlen, head_dim] + + Returns: + torch.Tensor: [batch * seqlen, heads, head_dim] + """ + return x.permute(0, 2, 1, 3).reshape(-1, x.shape[1], x.shape[3]) + + +def moba_attn_varlen_naive( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + moba_chunk_size: int, + moba_topk: int, +) -> torch.Tensor: + """Implement the moba brute-force setting for reference + + Args: + q (torch.Tensor): [seqlen, head, head_dim] + k (torch.Tensor): [seqlen, head, head_dim] + v (torch.Tensor): [seqlen, head, head_dim] + cu_seqlens (torch.Tensor): the cumulative sequence length tensor, same definition in flash attn + max_seqlen (int): the max sequence length of the batch, same definition in flash attn + + Returns: + attn_output (torch.Tensor): [seqlen, head, head_dim] + """ + + # qkv shape = [ S, H, D ] + batch = cu_seqlens.numel() - 1 + softmax_scale = q.shape[-1] ** (-0.5) + + o = torch.zeros_like(q) + for batch_idx in range(batch): + batch_start = cu_seqlens[batch_idx].item() + batch_end = cu_seqlens[batch_idx + 1].item() + # get qkv of this batch + q_ = q[batch_start:batch_end] + k_ = k[batch_start:batch_end] + v_ = v[batch_start:batch_end] + o_ = o[batch_start:batch_end] + # calc key gate weight + key_gate_weight = [] + batch_size = batch_end - batch_start + num_block = math.ceil(batch_size / moba_chunk_size) + for block_idx in range(0, num_block): + block_start = block_idx * moba_chunk_size + block_end = min(batch_size, block_start + moba_chunk_size) + key_gate_weight.append(k_[block_start:block_end].mean(dim=0, keepdim=True)) + key_gate_weight = torch.cat(key_gate_weight, dim=0) # [ N, H, D ] + # calc & mask gate + # use fp32 to avoid precision issue in bf16 + q_ = q_.type(torch.float32) + key_gate_weight = key_gate_weight.type(torch.float32) + gate = torch.einsum("shd,nhd->hsn", q_, key_gate_weight) # [ H, S, N ] + key_gate_weight = key_gate_weight.type_as(k) + q_ = q_.type_as(k) + for i in range(num_block): + # select the 
future Qs that can attend to KV chunk i + gate[:, : (i + 1) * moba_chunk_size, i] = float("-inf") + gate[:, i * moba_chunk_size : (i + 1) * moba_chunk_size, i] = float("inf") + # gate_top_k_idx = gate_top_k_val = [ H S K ] + gate_top_k_val, gate_top_k_idx = torch.topk( + gate, k=min(moba_topk, num_block), dim=-1, largest=True, sorted=False + ) + gate_top_k_val, _ = gate_top_k_val.min(dim=-1) # [ H, S ] + need_attend = gate >= gate_top_k_val.unsqueeze(-1) + # add gate_idx_mask in case of there is cornercases of same topk val been selected + gate_idx_mask = torch.zeros( + need_attend.shape, dtype=torch.bool, device=q.device + ) + gate_idx_mask = gate_idx_mask.scatter_(dim=-1, index=gate_top_k_idx, value=True) + need_attend = torch.logical_and(need_attend, gate_idx_mask) + gate[need_attend] = 0 + gate[~need_attend] = -float("inf") + gate = gate.repeat_interleave(moba_chunk_size, dim=-1)[ + :, :, :batch_size + ] # [ H, S, S ] + gate.masked_fill_( + torch.ones_like(gate, dtype=torch.bool).tril().logical_not(), -float("inf") + ) + # print(f"moba_naive | gate ({gate.shape}): {gate}") + + # calc qk = qk^t + q_ = q_.type(torch.float32) + k_ = k_.type(torch.float32) + v_ = v_.type(torch.float32) + qk = torch.einsum("xhd,yhd->hxy", q_, k_) + # mask + qk += gate + qk *= softmax_scale + # calc o + p = qk.softmax(dim=-1) + o_ += torch.einsum("hxy,yhd->xhd", p, v_) + o = o.type_as(q) + + return o + + + + +class MixedAttention(torch.autograd.Function): + + @staticmethod + def forward( + ctx, + q, + k, + v, + self_attn_cu_seqlen, + moba_q, + moba_kv, + moba_cu_seqlen_q, + moba_cu_seqlen_kv, + max_seqlen, + moba_chunk_size, + moba_q_sh_indices, + return_lse, + ): + ctx.max_seqlen = max_seqlen + ctx.moba_chunk_size = moba_chunk_size + ctx.softmax_scale = softmax_scale = q.shape[-1] ** (-0.5) + + # self attn + self_attn_out_sh, self_attn_lse_hs, _, _ = ( + _flash_attn_varlen_forward( + q=q, + k=k, + v=v, + cu_seqlens_q=self_attn_cu_seqlen, + cu_seqlens_k=self_attn_cu_seqlen, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + softmax_scale=softmax_scale, + causal=True, + dropout_p=0.0, + ) + ) + + moba_attn_out, moba_attn_lse_hs, _, _ = _flash_attn_varlen_forward( + q=moba_q, + k=moba_kv[:, 0], + v=moba_kv[:, 1], + cu_seqlens_q=moba_cu_seqlen_q, + cu_seqlens_k=moba_cu_seqlen_kv, + max_seqlen_q=max_seqlen, + max_seqlen_k=moba_chunk_size, + softmax_scale=softmax_scale, + causal=False, + dropout_p=0.0, + ) + + # convert lse shape hs -> sh ( follow the legacy mix attn logic ) + self_attn_lse_sh = self_attn_lse_hs.t().contiguous() + moba_attn_lse = moba_attn_lse_hs.t().contiguous() + + # output buffer [S, H, D], same shape as q + output = torch.zeros( + (q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32 + ) + + # flatten vS & H for index ops + output_2d = output.view(-1, q.shape[2]) + + # calc mixed_lse + # minus max lse to avoid exp explosion + max_lse_1d = self_attn_lse_sh.view(-1) + max_lse_1d = max_lse_1d.index_reduce( + 0, moba_q_sh_indices, moba_attn_lse.view(-1), "amax" + ) + self_attn_lse_sh = self_attn_lse_sh - max_lse_1d.view_as(self_attn_lse_sh) + moba_attn_lse = ( + moba_attn_lse.view(-1) + .sub(max_lse_1d.index_select(0, moba_q_sh_indices)) + .reshape_as(moba_attn_lse) + ) + + mixed_attn_se_sh = self_attn_lse_sh.exp() + moba_attn_se = moba_attn_lse.exp() + + mixed_attn_se_sh.view(-1).index_add_( + 0, moba_q_sh_indices, moba_attn_se.view(-1) + ) + mixed_attn_lse_sh = mixed_attn_se_sh.log() + + # add attn output + factor = (self_attn_lse_sh - mixed_attn_lse_sh).exp() # [ vS, 
H ] + self_attn_out_sh = self_attn_out_sh * factor.unsqueeze(-1) + output_2d += self_attn_out_sh.reshape_as(output_2d) + + # add moba output + mixed_attn_lse = ( + mixed_attn_lse_sh.view(-1) + .index_select(0, moba_q_sh_indices) + .view_as(moba_attn_lse) + ) + factor = (moba_attn_lse - mixed_attn_lse).exp() # [ vS, H ] + moba_attn_out = moba_attn_out * factor.unsqueeze(-1) + raw_attn_out = moba_attn_out.view(-1, moba_attn_out.shape[-1]) + output_2d.index_add_(0, moba_q_sh_indices, raw_attn_out) + output = output.to(q.dtype) + + + # add back max lse + mixed_attn_lse_sh = mixed_attn_lse_sh + max_lse_1d.view_as(mixed_attn_se_sh) + + + ctx.save_for_backward( + output, + mixed_attn_lse_sh, + q, + k, + v, + self_attn_cu_seqlen, + moba_q, + moba_kv, + moba_cu_seqlen_q, + moba_cu_seqlen_kv, + moba_q_sh_indices, + ) + ctx.return_lse = return_lse + + if return_lse: + return output, mixed_attn_lse_sh + else: + return output + + @staticmethod + def backward(ctx, d_output, *args): + + max_seqlen = ctx.max_seqlen + moba_chunk_size = ctx.moba_chunk_size + softmax_scale = ctx.softmax_scale + + ( + output, + mixed_attn_vlse_sh, + q, + k, + v, + self_attn_cu_seqlen, + moba_q, + moba_kv, + moba_cu_seqlen_q, + moba_cu_seqlen_kv, + moba_q_sh_indices, + ) = ctx.saved_tensors + + d_output = d_output.contiguous() + + dq, dk, dv = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v) + _flash_attn_varlen_backward( + dout=d_output, + q=q, + k=k, + v=v, + out=output, + softmax_lse=mixed_attn_vlse_sh.t().contiguous(), + dq=dq, + dk=dk, + dv=dv, + cu_seqlens_q=self_attn_cu_seqlen, + cu_seqlens_k=self_attn_cu_seqlen, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + softmax_scale=softmax_scale, + causal=True, + dropout_p=0.0, + window_size_left=-1, + window_size_right=-1, + softcap=0.0, + alibi_slopes=None, + deterministic=True, + ) + + headdim = q.shape[-1] + d_moba_output = ( + d_output.view(-1, headdim).index_select(0, moba_q_sh_indices).unsqueeze(1) + ) + moba_output = ( + output.view(-1, headdim).index_select(0, moba_q_sh_indices).unsqueeze(1) + ) + + mixed_attn_vlse = ( + mixed_attn_vlse_sh.view(-1).index_select(0, moba_q_sh_indices).view(1, -1) + ) + + dmq = torch.zeros_like(moba_q, dtype=moba_q.dtype, device=moba_q.device) + dmk = torch.zeros_like(moba_kv[:, 0], dtype=moba_kv.dtype, device=moba_kv.device) + dmv = torch.zeros_like(moba_kv[:, 1], dtype=moba_kv.dtype, device=moba_kv.device) + _flash_attn_varlen_backward( + dout=d_moba_output, + q=moba_q, + k=moba_kv[:, 0], + v=moba_kv[:, 1], + out=moba_output, + softmax_lse=mixed_attn_vlse, + dq=dmq, + dk=dmk, + dv=dmv, + cu_seqlens_q=moba_cu_seqlen_q, + cu_seqlens_k=moba_cu_seqlen_kv, + max_seqlen_q=max_seqlen, + max_seqlen_k=moba_chunk_size, + softmax_scale=softmax_scale, + causal=False, + dropout_p=0.0, + window_size_left=-1, + window_size_right=-1, + softcap=0.0, + alibi_slopes=None, + deterministic=True, + ) + + dmkv = torch.stack((dmk, dmv), dim=1) + return dq, dk, dv, None, dmq, dmkv, None, None, None, None, None, None + + +def moba_attn_varlen( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + moba_chunk_size: int, + moba_topk: int, + return_lse: bool=False, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """An efficient version of moba implementation with triton kernels and flash-attn, the core logic: + 1. 
Calculate the chunks and the number of chunks, n = floor(data_size / chunk_size) + - tokens in the tail chunk are reserved for self attn + - tokens in other chunks will be processed in later steps + 2. K in each chunk will calculate mean value as the representative k, and Q will attend to these representative + k to get the gate logit, which will be used to select topk chunks + 3. Select the topk chunks and get the dense q for each kv chunk pair and do the varlen attention + 4. Combine the varlen attn and self attn results via online softmax to get the final result + + Args: + q (torch.Tensor): [seqlen, head, head_dim] + k (torch.Tensor): [seqlen, head, head_dim] + v (torch.Tensor): [seqlen, head, head_dim] + cu_seqlens (torch.Tensor): the cumulative sequence length tensor, same definition in flash attn + max_seqlen (int): the max sequence length of the batch, same definition in flash attn + + Returns: + attn_output (torch.Tensor): [seqlen, head, head_dim] + """ + head_group_size = q.shape[1] // k.shape[1] + if head_group_size > 1: + k = torch.repeat_interleave(k, head_group_size, dim=1) + v = torch.repeat_interleave(v, head_group_size, dim=1) + + # --------------------------------------------------------------------------------------------- + kv = torch.stack((k, v), dim=1) # stack along a new dimension -> [S, 2, H, D] + + """ some basic variables """ + # qkv shape = [ S, H, D ] + seqlen, num_head, head_dim = q.shape + + """ prepare chunk meta """ + ( + cu_chunk, + filtered_chunk_indices, + num_filtered_chunk, + chunk_to_batch, + ) = calc_chunks(cu_seqlens, moba_chunk_size) + + # we will adjust selective topk to moba_topk - 1, as the last chunk is always chosen + moba_topk = min(moba_topk - 1, num_filtered_chunk) + need_moba_attn = moba_topk > 0 + + # corner case: if no moba attn needed, just return self attn + if not need_moba_attn: + return flash_attn_varlen_func( + q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, causal=True + ) + + self_attn_cu_seqlen = cu_chunk + + # filtered_kv is a dense matrix that only contains filtered chunk of kv + filtered_kv_indices = torch.arange( + 0, moba_chunk_size, dtype=torch.int32, device=q.device + )[None, :].repeat(num_filtered_chunk, 1) + filtered_kv_indices += cu_chunk[filtered_chunk_indices][:, None] + + # select the elements of KV corresponding to all chunks that are not filtered out + filtered_kv = kv.index_select(0, filtered_kv_indices.view(-1)) + + """ calc key_gate_weight and gate """ + # key_gate_weight [ F_N_CHUNK, HEAD, HEAD_DIM ] + key_gate_weight = ( + filtered_kv[:, 0] # K + .view(num_filtered_chunk, moba_chunk_size, num_head, head_dim) + .mean(dim=1) # mean pooling along chunk size + .float() + ) + q = q.type(torch.float32) # float logit on the fly for better gate logit perception + key_gate_weight = key_gate_weight.type( + torch.float32 + ) # float logit for better gate logit perception + gate = torch.einsum( + "nhd,shd->nhs", key_gate_weight, q + ) # gate [ F_N_CHUNK, HEAD, SEQ ] + key_gate_weight = key_gate_weight.type_as(k) + q = q.type_as(k) + + # pose process gate, masking unchosen batch and apply causal mask to current chunk + gate_seq_idx = torch.arange(0, seqlen, device=q.device, dtype=torch.int32)[ + None, : + ].repeat(num_filtered_chunk, 1) + chunk_end = cu_chunk[filtered_chunk_indices + 1] + batch_end = cu_seqlens[chunk_to_batch[filtered_chunk_indices] + 1] + gate_chunk_end_mask = gate_seq_idx < chunk_end[:, None] + gate_batch_end_mask = gate_seq_idx >= batch_end[:, None] + gate_inf_mask = gate_chunk_end_mask | 
gate_batch_end_mask + gate.masked_fill_(gate_inf_mask.unsqueeze(1), -float("inf")) + + """ find moba q that needs moba attn """ + # find topk chunks + # gate_mask [ N_CHUNK, HEAD, SEQ ], true indicates that needs attention + _, gate_top_k_idx = torch.topk(gate, k=moba_topk, dim=0, largest=True, sorted=False) + + # apply causal mask + gate_mask = torch.logical_not(gate.isinf()) + + # select topk chunks + gate_idx_mask = torch.zeros(gate_mask.shape, dtype=torch.bool, device=q.device) + gate_idx_mask = gate_idx_mask.scatter_(dim=0, index=gate_top_k_idx, value=True) + + # gate_mask has the shape [ N_CHUNK, HEAD, SEQ ]. + # For each chunk, the sequence-dimension indices will be True if it belongs to the top-K chunks + gate_mask = torch.logical_and(gate_mask, gate_idx_mask) + # --------------------------------------------------------------------------------------------- + + + # varlen trick: combining all q index that needs moba attn + # the result will be like [ C0H0 ][ C0H1 ][ C0H2 ][ ... ][ CnHm ] + # torch.nonzero (as_tuple=True): Returns a tuple of 1-D tensors, one for each dimension in input, each containing the indices (in that dimension) of all non-zero elements of input . + # if input has n-dimension, the resulting tuple will have n tensors of size z, where z is the number of non-zero elements in input. + # (i-th values of all n tuple elements represent the indices of the i-th non-zero element in each dimension) + # using index [-1] => indices of HS (combined sequence) dimension that contains non-zero elements + moba_q_indices = gate_mask.reshape(gate_mask.shape[0], -1) # [ N, HS ] + + moba_q_indices = moba_q_indices.nonzero(as_tuple=True)[-1] # [HS indices] * N (total size: all non-zero elements in HS dimension) + + + # moba_seqlen_q indicates that how many q chunks are selected for each kv chunk - head + moba_seqlen_q = gate_mask.sum(dim=-1).flatten() + + # select all q that needs moba attn based on the moba_q_indices + moba_q = rearrange(q, "s h d -> ( h s ) d").index_select( + 0, moba_q_indices + ) # [ selected_S, D ] + moba_q = moba_q.unsqueeze(1) + + # moba_q_sh_indices represents the position in the origin q tensor of each q token inside moba_q + moba_q_sh_indices = moba_q_indices % seqlen * num_head + moba_q_indices // seqlen + + """ prepare moba kv """ + # Since moba_q is organized as HS * N, we need to reorganize kv to adapt to q + + # cut off zero experts + q_zero_mask = moba_seqlen_q == 0 + valid_expert_mask = ~q_zero_mask + zero_expert_count = q_zero_mask.sum() + + # only keep the kv that has q select > 0 + if zero_expert_count > 0: + moba_seqlen_q = moba_seqlen_q[valid_expert_mask] + + + # moba cu_seqlen for flash attn + moba_cu_seqlen_q = torch.cat( + ( + torch.tensor([0], device=q.device, dtype=moba_seqlen_q.dtype), + moba_seqlen_q.cumsum(dim=0), + ), + dim=0, + ).to(torch.int32) + + # ----------------------------------------------- + moba_kv = rearrange(filtered_kv, "s x h d -> h s x d") # here `x` only stands for a dimension (stack dimension for KV) + + moba_kv = moba_kv.split(moba_chunk_size, dim=1) + moba_kv = torch.cat(moba_kv, dim=0) # [num_selected_chunks, H x S // moba_chunk_size, D] + + if zero_expert_count > 0: + assert valid_expert_mask.sum() == moba_kv.shape[0] - zero_expert_count + moba_kv = moba_kv[ + valid_expert_mask + ] # cut off zero Q expert from kv , or the grad may be nan + + moba_kv = moba_kv.flatten(start_dim=0, end_dim=1).unsqueeze(2) + + moba_cu_seqlen_kv = ( + torch.arange( + 0, + num_filtered_chunk * num_head + 1 - zero_expert_count, + 
dtype=torch.int32, + device=q.device, + ) + * moba_chunk_size + ) + + # Shape check + assert ( + moba_cu_seqlen_kv.shape == moba_cu_seqlen_q.shape + ), f"moba_cu_seqlen_kv.shape != moba_cu_seqlen_q.shape {moba_cu_seqlen_kv.shape} != {moba_cu_seqlen_q.shape}" + + # Wrapping up the flash attn call and online softmax dlse inside MixedAttention class + return MixedAttention.apply( + q, k, v, + self_attn_cu_seqlen, + moba_q, + moba_kv, + moba_cu_seqlen_q, + moba_cu_seqlen_kv, + max_seqlen, + moba_chunk_size, + moba_q_sh_indices, + return_lse + ) + + +def moba_attn_func( + q: torch.Tensor, # [batch, q_len, q_heads, head_dim] + k: torch.Tensor, + v: torch.Tensor, + global_seq_len: int, + moba_chunk_size: int, + moba_topk: int, + **kwargs, +): + batch_size = q.shape[0] + cu_seqlens = torch.cumsum( + torch.tensor([0] + [global_seq_len] * batch_size, device=q.device), + dim=0, + dtype=torch.int32, + ) + + q_3d, k_3d, v_3d = \ + q.reshape(-1, q.shape[2], q.shape[3]), \ + k.reshape(-1, k.shape[2], k.shape[3]), \ + v.reshape(-1, v.shape[2], v.shape[3]) + + # output: [batch_size, global_seq_len, q_heads, head_dim] + return moba_attn_varlen( + q_3d, k_3d, v_3d, + cu_seqlens, + global_seq_len, + moba_chunk_size, + moba_topk, + ).view(q.shape) + + +def moba_layer( + moba_impl: Callable, + moba_chunk_size: int, + moba_topk: int, + module: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + *args, + dropout: float = 0.0, + scaling: Optional[float] = None, + **kwargs, +) -> Tuple[torch.Tensor, None]: + """ + Args: + query (torch.Tensor): [batch, q_heads, q_len, head_dim] + key (torch.Tensor): [batch, kv_heads, kv_len, head_dim] + value (torch.Tensor): [batch, kv_heads, kv_len, head_dim] + + Returns: + attn_output (torch.Tensor): [batch, q_len, q_heads, head_dim] + attn_weights (None): not needed + """ + assert module.is_causal + batch, q_heads, q_len, head_dim = query.shape + _, kv_heads, kv_len, _ = key.shape + if q_len == kv_len: + # prefill phase + query = hf_to_fa(query) + key = hf_to_fa(key) + value = hf_to_fa(value) + kv_replicas = q_heads // kv_heads + key = torch.repeat_interleave(key, kv_replicas, dim=1) + value = torch.repeat_interleave(value, kv_replicas, dim=1) + cu_seqlens_k = torch.cumsum( + torch.tensor([0] + [kv_len] * batch, device=query.device), + dim=0, + dtype=torch.int32, + ) + out = moba_impl( + q=query, + k=key, + v=value, + cu_seqlens=cu_seqlens_k, + max_seqlen=kv_len, + moba_chunk_size=moba_chunk_size, + moba_topk=moba_topk, + ) + else: + # decode phase + # TODO release paged attn implementation + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + out = flash_attn_func(query, key, value, dropout, scaling, True) + return out, None diff --git a/minference/ops/op_utils/__init__.py b/minference/ops/op_utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/minference/ops/op_utils/moba_utils.py b/minference/ops/op_utils/moba_utils.py new file mode 100644 index 0000000..d829a83 --- /dev/null +++ b/minference/ops/op_utils/moba_utils.py @@ -0,0 +1,365 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
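As a quick orientation for the wrapper APIs above: `moba_attn_func` flattens dense `[batch, seqlen, heads, head_dim]` tensors, builds a uniform `cu_seqlens`, and dispatches to `moba_attn_varlen`. The sketch below shows one plausible call; the sizes, dtype, and import path are hypothetical, and it assumes equal-length sequences on a CUDA device with a flash-attn-compatible dtype.

import torch
# from minference.ops.moba import moba_attn_func  # hypothetical import path

batch, seqlen, num_heads, head_dim = 2, 8192, 32, 128  # hypothetical sizes
q = torch.randn(batch, seqlen, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Each query attends to its own chunk (self attention) plus the top-(moba_topk - 1)
# other chunks chosen by the mean-pooled key gate.
out = moba_attn_func(
    q, k, v,
    global_seq_len=seqlen,
    moba_chunk_size=4096,
    moba_topk=3,
)
assert out.shape == q.shape  # [batch, seqlen, num_heads, head_dim]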
+import torch +import torch.distributed as dist + +from functools import lru_cache +from dataclasses import dataclass + +def tensor_4d_to_3d(tensor: torch.Tensor) -> torch.Tensor: + """Convert a 4D tensor to a 3D tensor by collapsing the first two dimensions.""" + if tensor.ndim != 4: + raise ValueError("Input tensor must be 4D.") + return tensor.reshape(tensor.shape[0] * tensor.shape[1], tensor.shape[2], tensor.shape[3]) + + +@dataclass +class MoBAConfig: + moba_chunk_size: int + moba_topk: int + +def shuffle_input_all( + to_send: torch.Tensor, # [S, H, D] + seq_offset: torch.Tensor, # [2] + gate_mask: torch.Tensor = None, # [num_chunks, H, S] + process_group: dist.ProcessGroup = None + ): + orig_ndim = to_send.ndim + if orig_ndim == 3: to_send = to_send.unsqueeze(0) + + rank = dist.get_rank(process_group) + world_size = dist.get_world_size(process_group) + + if not to_send.is_contiguous(): + to_send = to_send.contiguous() + block_seq_len = to_send.shape[1] // 2 + + seq_offset_val = seq_offset.detach().cpu().item() + seq_offsets = torch.Tensor([seq_offset_val, seq_offset_val + block_seq_len]).to(to_send.device) + + # We must use outplace, otherwise it will raise error at backward due to inplace operations. + # We can not change to_send directly and create a new tensor to store the result. + to_send_f = torch.zeros_like(to_send) + to_send_gate_mask = torch.zeros_like(gate_mask) + to_send_offset = seq_offsets[1] + + # assume the input sequence length is 8, and computation runs on 4 GPUs + # the seq is represented as [0 1 2 3 4 5 6 7], world size is 4 + # the input status before `shuffle_input` is + # - gpu A: [0 1] + # - gpu B: [2 3] + # - gpu C: [4 5] + # - gpu D: [6 7] + # the value of `to_send_slice` is + # - gpu A: [1] + # - gpu B: [3] + # - gpu C: [5] + # - gpu D: [7] + to_send_slice = to_send[:, block_seq_len:].contiguous() + to_send_gate_mask_slice = gate_mask[..., block_seq_len:].contiguous() + + res = torch.zeros_like(to_send_slice) + res_gate_mask = torch.zeros_like(to_send_gate_mask_slice) + res_offset= torch.zeros_like(to_send_offset) + + _ops = [] + offset = ((dist.get_rank() // world_size) * world_size) + # rank src_rank + # 0 3 + # 1 2 + # 2 1 + # 3 0 + src_rank = (world_size - rank - 1) % world_size + offset + send_op = dist.P2POp( + dist.isend, to_send_slice, src_rank, group=process_group + ) + send_gate_mask_op = dist.P2POp( + dist.isend, to_send_gate_mask_slice, src_rank, group=process_group + ) + send_offset_op = dist.P2POp( + dist.isend, to_send_offset, src_rank, group=process_group + ) + _ops.append(send_op) + _ops.append(send_gate_mask_op) + _ops.append(send_offset_op) + + recv_op = dist.P2POp( + dist.irecv, res, src_rank, group=process_group) + recv_gate_mask_op = dist.P2POp( + dist.irecv, res_gate_mask, src_rank, group=process_group + ) + recv_offset_op = dist.P2POp( + dist.irecv, res_offset, src_rank, group=process_group + ) + _ops.append(recv_op) + _ops.append(recv_gate_mask_op) + _ops.append(recv_offset_op) + + # response = dist.dist.batch_isend_irecv(_ops) + response = dist.batch_isend_irecv(_ops) + for resp in response: + resp.wait() + + if rank >= world_size // 2: # D: 6 7, -> 1 6 + to_send_f[:, block_seq_len:] = to_send[:, :block_seq_len] + to_send_f[:, :block_seq_len, ...] 
= res + + to_send_gate_mask[..., block_seq_len:] = gate_mask[..., :block_seq_len] + to_send_gate_mask[..., :block_seq_len] = res_gate_mask + + seq_offsets[1] = seq_offsets[0] + seq_offsets[0] = res_offset + else: # A: 0 1, -> 0 7 + to_send_f[:, :block_seq_len] = to_send[:, :block_seq_len] + to_send_f[:, block_seq_len:, ...] = res + + to_send_gate_mask[..., :block_seq_len] = gate_mask[..., :block_seq_len] + to_send_gate_mask[..., block_seq_len:] = res_gate_mask + + seq_offsets[1] = res_offset + + # after shuffle, the status of `to_send_f` + # GPU A: [0 7] + # GPU B: [2 5] + # GPU C: [3 4] + # GPU D: [1 6] + return ( + to_send_f if orig_ndim != 3 else to_send_f.squeeze(0), + seq_offsets, + to_send_gate_mask, + ) + + +def shuffle_input_only( + to_send: torch.Tensor, # [S, H, D] + process_group: dist.ProcessGroup = None + ) -> torch.Tensor: + orig_ndim = to_send.ndim + if orig_ndim == 3: to_send = to_send.unsqueeze(0) + + rank = dist.get_rank(process_group) + world_size = dist.get_world_size(process_group) + + if not to_send.is_contiguous(): + to_send = to_send.contiguous() + block_seq_len = to_send.shape[1] // 2 + + # We must use outplace, otherwise it will raise error at backward due to inplace operations. + # We can not change to_send directly and create a new tensor to store the result. + to_send_f = torch.zeros_like(to_send) + + to_send_slice = to_send[:, block_seq_len:].contiguous() + res = torch.zeros_like(to_send_slice) + + _ops = [] + offset = ((dist.get_rank() // world_size) * world_size) + + src_rank = (world_size - rank - 1) % world_size + offset + send_op = dist.P2POp( + dist.isend, to_send_slice, src_rank, group=process_group + ) + _ops.append(send_op) + + recv_op = dist.P2POp( + dist.irecv, res, src_rank, group=process_group) + _ops.append(recv_op) + + # response = dist.dist.batch_isend_irecv(_ops) + response = dist.batch_isend_irecv(_ops) + for resp in response: + resp.wait() + + if rank >= world_size // 2: # D: 6 7, -> 1 6 + to_send_f[:, block_seq_len:] = to_send[:, :block_seq_len] + to_send_f[:, :block_seq_len, ...] = res + else: # A: 0 1, -> 0 7 + to_send_f[:, :block_seq_len] = to_send[:, :block_seq_len] + to_send_f[:, block_seq_len:, ...] 
= res + return to_send_f if orig_ndim != 3 else to_send_f.squeeze(0) + +@lru_cache(maxsize=16) +def calc_chunks(cu_seqlen, moba_chunk_size): + """calc chunks that needs moba attention""" + + # batch_sizes[batch_idx] = batch size ( seqlen ) of batch idx + # example: [seq_len] + batch_sizes = cu_seqlen[1:] - cu_seqlen[:-1] + + # batch_num_chunk[batch_idx] = how many chunk in batch idx + # example: [number of all chunks with chunk size equal to moba_chunk_size + 1 (the one with smaller size)] + batch_num_chunk = (batch_sizes + (moba_chunk_size - 1)) // moba_chunk_size + + # cu_num_chunk[batch_idx] = first chunk id of this batch + # example: [1, 1] + cu_num_chunk = torch.ones( + batch_num_chunk.numel() + 1, + device=cu_seqlen.device, + dtype=batch_num_chunk.dtype, + ) + # example: [1, 1 + num of chunks] + cu_num_chunk[1:] = batch_num_chunk.cumsum(dim=0) + + # total chunk ( for all batch ) + # example: 1 + num of chunks + num_chunk = cu_num_chunk[-1] + + # chunk_sizes[chunk_idx] = chunk_size of chunk idx + chunk_sizes = torch.full( + (num_chunk + 1,), moba_chunk_size, dtype=torch.int32, device=cu_seqlen.device + ) + chunk_sizes[0] = 0 # for calc cu chunk + batch_last_chunk_size = batch_sizes - (batch_num_chunk - 1) * moba_chunk_size + chunk_sizes[cu_num_chunk[1:]] = batch_last_chunk_size + # example chunk_sizes: [0, moba_chunk_size, ..., moba_chunk_size, batch_last_chunk_size] + + + # cu_chunk[chunk_idx] = the start chunk offset of chunk idx + # example: [0, moba_chunk_size, ..., seq_len] + cu_chunk = chunk_sizes.cumsum(dim=-1, dtype=torch.int32) + + + # chunk_to_batch[chunk_idx] = batch idx of the chunk idx + # example: [0, 0, 0, ...., 0] + chunk_to_batch = torch.zeros( + (num_chunk,), dtype=torch.int32, device=cu_seqlen.device + ) + + # example: [0, 0, 0, ... 
, 0] (if there are multiple samples in the batch, the index of the starting chunk of each batch from the 1st sample will be 1) + # but if there is only one batch, cu_num_chunk[1:-1] will be empty and no element will be assigned with 1 (all correspond to 0-th sample) + chunk_to_batch[cu_num_chunk[1:-1]] = 1 + + # example: [0, 0, 0, ..., 0] + chunk_to_batch = chunk_to_batch.cumsum(dim=0, dtype=torch.int32) + + """ filter chunks that need moba attn """ + # filter chunks ( remove last chunk of each batch ) + # filtered_chunk_indices: chunk index list that excludes the last chunk of each batch + chunk_to_remove = cu_num_chunk[1:] - 1 # example: number of chunks (num_chunk - 1) + # print(f"calc_chunks | chunk_to_remove: {chunk_to_remove}") + + chunk_to_remain = torch.ones( + (num_chunk, ), dtype=torch.bool, device=cu_seqlen.device + ) + chunk_to_remain[chunk_to_remove] = False # example: + filtered_chunk_indices = chunk_to_remain.nonzero(as_tuple=True)[0] + num_filtered_chunk = len(filtered_chunk_indices) + + return ( + cu_chunk, + filtered_chunk_indices, + num_filtered_chunk, + chunk_to_batch, + ) + + +def compute_moba_gate( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_offset: torch.Tensor, + cu_seqlens: torch.Tensor, + moba_chunk_size: int, + moba_topk: int, +): + if len(q.shape) == 4: + q, k, v = \ + tensor_4d_to_3d(q), \ + tensor_4d_to_3d(k), \ + tensor_4d_to_3d(v) + seq_offset: int = seq_offset.detach().cpu().item() + seqlen_block, num_head, head_dim = q.shape + _, k_num_head, _ = k.shape + if num_head > k_num_head: + k = torch.repeat_interleave(k, num_head // k_num_head, dim=1) + v = torch.repeat_interleave(v, num_head // k_num_head, dim=1) + + # --------------------------------------------------------------------------------------------- + kv = torch.stack((k, v), dim=1) # [ blk_S, 2, H, D ] + + world_size = dist.get_world_size() + kv_list = [torch.zeros_like(kv, dtype=q.dtype, device=q.device) for _ in range(world_size)] + dist.all_gather(kv_list, kv) + kv_gathered = torch.cat(kv_list, dim=0) # [ S, 2, H, D ] + + + """ some basic variables """ + # qkv shape = [ S, H, D ] + block_size = q.shape[0] + seqlen, _, num_head, head_dim = kv_gathered.shape + + """ prepare chunk meta """ + ( + cu_chunk, # example: [0, moba_chunk_size, ..., seq_len] + filtered_chunk_indices, # example: [0, 1, 2, ..., num_filtered_chunk-1] (i.e. except the last chunk) + num_filtered_chunk, # example: num_filtered_chunk + chunk_to_batch, # example: [0, 0, ... 
,0] (for batch_size=1) with size 1 + real num of chunks + ) = calc_chunks(cu_seqlens, moba_chunk_size) + + # we will adjust selective topk to moba_topk - 1, as the last chunk is always chosen + moba_topk = min(moba_topk - 1, num_filtered_chunk) + assert moba_topk > 0, "moba_topk should be greater than 0" + + # filtered_kv is a dense matrix that only contains filtered chunk of kv + filtered_kv_indices = torch.arange( + 0, moba_chunk_size, dtype=torch.int32, device=q.device + )[None, :].repeat(num_filtered_chunk, 1) + filtered_kv_indices += cu_chunk[filtered_chunk_indices][:, None] + + # select the elements of KV corresponding to all chunks that are not filtered out + filtered_kv = kv_gathered.index_select(0, filtered_kv_indices.view(-1)) + + """ calc key_gate_weight and gate """ + # key_gate_weight [ F_N_CHUNK, HEAD, HEAD_DIM ] + key_gate_weight = ( + filtered_kv[:, 0] # K + .view(num_filtered_chunk, moba_chunk_size, num_head, head_dim) + .mean(dim=1) # mean pooling along chunk size + .float() + ) + # print(f"Rank {dist.get_rank()} | compute_moba_gate | key_gate_weight shape: {key_gate_weight.shape}") + + q = q.type(torch.float32) # float logit on the fly for better gate logit perception + key_gate_weight = key_gate_weight.type( + torch.float32 + ) # float logit for better gate logit perception + gate = torch.einsum( + "nhd,shd->nhs", key_gate_weight, q + ) # gate [ F_N_CHUNK, HEAD, SEQ_BLOCK] + key_gate_weight = key_gate_weight.type_as(k) + q = q.type_as(k) + + # pose process gate, masking unchosen batch and apply causal mask to current chunk + gate_seq_idx = torch.arange( + seq_offset, min(seq_offset + block_size, seqlen), device=q.device, dtype=torch.int32 + )[None, :].repeat(num_filtered_chunk, 1) + chunk_end = cu_chunk[filtered_chunk_indices + 1] + batch_end = cu_seqlens[chunk_to_batch[filtered_chunk_indices] + 1] + gate_chunk_end_mask = gate_seq_idx < chunk_end[:, None] + gate_batch_end_mask = gate_seq_idx >= batch_end[:, None] + gate_inf_mask = gate_chunk_end_mask | gate_batch_end_mask + gate.masked_fill_(gate_inf_mask.unsqueeze(1), -float("inf")) + # print(f"Rank {dist.get_rank()} | compute_moba_gate | gate shape before topK: {gate.shape}") + + """ find moba q that needs moba attn """ + # find topk chunks + # gate_top_k_idx with shape [TOP_K, HEAD, SEQ_BLOCK] + _, gate_top_k_idx = torch.topk(gate, k=moba_topk, dim=0, largest=True, sorted=False) + # apply causal mask + gate_mask = torch.logical_not(gate.isinf()) + + # select topk chunks + gate_idx_mask = torch.zeros(gate_mask.shape, dtype=torch.bool, device=q.device) + gate_idx_mask = gate_idx_mask.scatter_(dim=0, index=gate_top_k_idx, value=True) + + # [ F_N_CHUNK, HEAD, SEQ_BLOCK] + gate_mask = torch.logical_and(gate_mask, gate_idx_mask).contiguous() + + return ( + # gate_mask does not need to be gathered because + # each device only needs the gate_mask corresponding to the current query block + gate_mask, + cu_chunk, + filtered_chunk_indices, + num_filtered_chunk, + chunk_to_batch + ) diff --git a/minference/ops/op_utils/vertical_slash_utils.py b/minference/ops/op_utils/vertical_slash_utils.py new file mode 100644 index 0000000..eba1700 --- /dev/null +++ b/minference/ops/op_utils/vertical_slash_utils.py @@ -0,0 +1,770 @@ +import os +from typing import List + +import torch +import torch.distributed as dist + +import triton +import triton.language as tl + + +@triton.jit +def _triton_extract_kv_kernel( + local_k, local_v, bar_k, bar_v, v_idx, v_cnt, + stride_lz, stride_ln, stride_lh, stride_ld, + stride_bz, stride_bn, stride_bh, 
stride_bd, + stride_iz, stride_ih, stride_in, + stride_cz, stride_ch, stride_cr, + step, num_tokens, num_qo_heads, num_kv_heads, + BLOCK_N: tl.constexpr, BLOCK_D: tl.constexpr, +): + start_n = tl.program_id(0) + qo_head_idx = tl.program_id(1) + batch_idx = tl.program_id(2) + kv_head_idx = qo_head_idx // (num_qo_heads // num_kv_heads) + + v_cnt_ptr = v_cnt + batch_idx * stride_cz + qo_head_idx * stride_ch + min_n = tl.load(v_cnt_ptr + step * stride_cr) + max_n = tl.load(v_cnt_ptr + (step + 1) * stride_cr) + start_n = start_n * BLOCK_N + end_n = start_n + BLOCK_N + if start_n >= max_n or end_n <= min_n: + return + + offs_d = tl.arange(0, BLOCK_D) + offs_n = start_n + tl.arange(0, BLOCK_N) + mask_n = (offs_n >= min_n) & (offs_n < max_n) + + v_idx_ptr = v_idx + batch_idx * stride_iz + qo_head_idx * stride_ih + local_k_ptr = local_k + batch_idx * stride_lz + kv_head_idx * stride_lh + offs_d[None, :] * stride_ld + local_v_ptr = local_v + batch_idx * stride_lz + kv_head_idx * stride_lh + offs_d[None, :] * stride_ld + bar_k_ptr = bar_k + batch_idx * stride_bz + qo_head_idx * stride_bh + offs_d[None, :] * stride_bd + bar_v_ptr = bar_v + batch_idx * stride_bz + qo_head_idx * stride_bh + offs_d[None, :] * stride_bd + + # idx = tl.load(v_idx_ptr + offs_n * stride_in, mask=mask_n, other=0) - step * num_tokens + idx = tl.load(v_idx_ptr + offs_n * stride_in, mask=mask_n, other=0) % num_tokens + k = tl.load(local_k_ptr + idx[:, None] * stride_ln, mask=mask_n[:, None], other=0.) + v = tl.load(local_v_ptr + idx[:, None] * stride_ln, mask=mask_n[:, None], other=0.) + tl.store(bar_k_ptr + offs_n[:, None] * stride_bn, k, mask=mask_n[:, None]) + tl.store(bar_v_ptr + offs_n[:, None] * stride_bn, v, mask=mask_n[:, None]) + + +def extract_kv( + local_k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + local_v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + bar_k: torch.Tensor, # [batch_size, max_v_size, num_qo_heads, head_dim] + bar_v: torch.Tensor, # [batch_size, max_v_size, num_qo_heads, head_dim] + v_idx: torch.Tensor, # [batch_size, num_qo_heads, max_v_size] + v_cnt: torch.Tensor, # [batch_size, num_qo_heads, world_size + 1] + step: int, +): + batch_size, max_v_size, num_qo_heads, head_dim = bar_k.shape + _, num_tokens, num_kv_heads, _ = local_k.shape + block_N = 128 + block_D = head_dim + _triton_extract_kv_kernel[(triton.cdiv(max_v_size, block_N), num_qo_heads, batch_size)]( + local_k, local_v, bar_k, bar_v, v_idx, v_cnt, + local_k.stride(0), local_k.stride(1), local_k.stride(2), local_k.stride(3), + bar_k.stride(0), bar_k.stride(1), bar_k.stride(2), bar_k.stride(3), + v_idx.stride(0), v_idx.stride(1), v_idx.stride(2), + v_cnt.stride(0), v_cnt.stride(1), v_cnt.stride(2), + step, num_tokens, num_qo_heads, num_kv_heads, + BLOCK_N=block_N, BLOCK_D=block_D, + num_warps=4, num_stages=1, + ) + + +@triton.jit +def _triton_merge_kv_kernel( + local_k, local_v, bar_k, bar_v, v_idx, v_cnt, + stride_lz, stride_ln, stride_lh, stride_ld, + stride_bz, stride_bn, stride_bh, stride_bd, + stride_iz, stride_ih, stride_in, + stride_cz, stride_ch, stride_cr, + step, num_tokens, num_qo_heads, num_kv_heads, + BLOCK_N: tl.constexpr, BLOCK_D: tl.constexpr, +): + start_n = tl.program_id(0) + qo_head_idx = tl.program_id(1) + batch_idx = tl.program_id(2) + kv_head_idx = qo_head_idx // (num_qo_heads // num_kv_heads) + + v_cnt_ptr = v_cnt + batch_idx * stride_cz + qo_head_idx * stride_ch + min_n = tl.load(v_cnt_ptr + step * stride_cr) + max_n = tl.load(v_cnt_ptr + (step + 1) * stride_cr) + start_n = 
start_n * BLOCK_N + end_n = start_n + BLOCK_N + if start_n >= max_n or end_n <= min_n: + return + + offs_d = tl.arange(0, BLOCK_D) + offs_n = start_n + tl.arange(0, BLOCK_N) + mask_n = (offs_n >= min_n) & (offs_n < max_n) + + v_idx_ptr = v_idx + batch_idx * stride_iz + qo_head_idx * stride_ih + local_k_ptr = local_k + batch_idx * stride_lz + kv_head_idx * stride_lh + offs_d[None, :] * stride_ld + local_v_ptr = local_v + batch_idx * stride_lz + kv_head_idx * stride_lh + offs_d[None, :] * stride_ld + bar_k_ptr = bar_k + batch_idx * stride_bz + qo_head_idx * stride_bh + offs_d[None, :] * stride_bd + bar_v_ptr = bar_v + batch_idx * stride_bz + qo_head_idx * stride_bh + offs_d[None, :] * stride_bd + + # idx = tl.load(v_idx_ptr + offs_n * stride_in, mask=mask_n, other=0) - step * num_tokens + idx = tl.load(v_idx_ptr + offs_n * stride_in, mask=mask_n, other=0) % num_tokens + k = tl.load(bar_k_ptr + offs_n[:, None] * stride_bn, mask=mask_n[:, None], other=0.).to(local_k.type.element_ty) + v = tl.load(bar_v_ptr + offs_n[:, None] * stride_bn, mask=mask_n[:, None], other=0.).to(local_v.type.element_ty) + tl.atomic_add(local_k_ptr + idx[:, None] * stride_ln, k, mask=mask_n[:, None], sem="relaxed") + tl.atomic_add(local_v_ptr + idx[:, None] * stride_ln, v, mask=mask_n[:, None], sem="relaxed") + + +def merge_kv( + local_k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + local_v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + bar_k: torch.Tensor, # [batch_size, max_v_size, num_qo_heads, head_dim] + bar_v: torch.Tensor, # [batch_size, max_v_size, num_qo_heads, head_dim] + v_idx: torch.Tensor, # [batch_size, num_qo_heads, max_v_size] + v_cnt: torch.Tensor, # [batch_size, num_qo_heads, world_size + 1] + step: int, +): + batch_size, max_v_size, num_qo_heads, head_dim = bar_k.shape + _, num_tokens, num_kv_heads, _ = local_k.shape + block_N = 128 + block_D = head_dim + _triton_merge_kv_kernel[(triton.cdiv(max_v_size, block_N), num_qo_heads, batch_size)]( + local_k, local_v, bar_k, bar_v, v_idx, v_cnt, + local_k.stride(0), local_k.stride(1), local_k.stride(2), local_k.stride(3), + bar_k.stride(0), bar_k.stride(1), bar_k.stride(2), bar_k.stride(3), + v_idx.stride(0), v_idx.stride(1), v_idx.stride(2), + v_cnt.stride(0), v_cnt.stride(1), v_cnt.stride(2), + step, num_tokens, num_qo_heads, num_kv_heads, + BLOCK_N=block_N, BLOCK_D=block_D, + num_warps=4, num_stages=1, + ) + + +# triton.cdiv(world_size * num_blocks, BLOCK_N), num_heads, batch_size +# block_mask: [batch_size, num_heads, num_blocks_global] +@triton.jit +def _calc_block_mask_kernel( + s_idx, block_mask, + stride_sz, stride_sh, stride_sk, + stride_bz, stride_bh, stride_bn, + max_s_size, num_tokens, granularity, + BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, +): + batch_idx = tl.program_id(2) + head_idx = tl.program_id(1) + group_idx = tl.program_id(0) + + block_offs = tl.arange(0, BLOCK_N) + slash_offs = tl.arange(0, BLOCK_K) + + s_idx_ptr = s_idx + batch_idx * stride_sz + head_idx * stride_sh + block_mask_ptr = block_mask + batch_idx * stride_bz + head_idx * stride_bh + block_idx = group_idx * BLOCK_N + block_offs + + blocks = tl.zeros([BLOCK_N], dtype=tl.uint8) + for s_off in range(0, max_s_size, BLOCK_K): + s = tl.load(s_idx_ptr + (s_off + slash_offs) * stride_sk) + left = (num_tokens - granularity - s) // granularity + right = (num_tokens - 1 - s) // granularity + + # mask is generated by checking if a block's index falls between the calculated ranges + blocks |= tl.max((block_idx[None, :] >= left[:, None]) & 
(block_idx[None, :] <= right[:, None]), 0).to(tl.uint8) + + b_mask = (group_idx * BLOCK_N + block_offs) * granularity < num_tokens + tl.store(block_mask_ptr + (group_idx * BLOCK_N + block_offs) * stride_bn, blocks, mask=b_mask) + + +@triton.jit +def _striped_convert_indices_kernel( + last_row_mask, v_idx, v_cnt, + block_mask, bar_idx, bar_pos, bar_cnt, + stride_rz, stride_rh, stride_rn, + stride_vz, stride_vh, stride_vk, + stride_nz, stride_nh, stride_nt, + stride_bt, stride_bz, stride_bh, stride_bm, stride_bn, + stride_iz, stride_ih, stride_im, stride_ik, + stride_cz, stride_ch, stride_cm, stride_ct, + max_v_size, num_blocks, granularity, world_size, rank, + BLOCK_N: tl.constexpr, +): + batch_idx = tl.program_id(2) + head_idx = tl.program_id(1) + block_idx_q_local = tl.program_id(0) + + block_idx_q_global = block_idx_q_local * world_size + rank + + num_tokens_local = num_blocks * granularity + num_blocks_global = world_size * num_blocks + shift = num_blocks_global - 1 - block_idx_q_global + + block_offs = tl.arange(0, BLOCK_N) + + last_row_mask_ptr = last_row_mask + batch_idx * stride_rz + head_idx * stride_rh + v_idx_ptr = v_idx + batch_idx * stride_vz + head_idx * stride_vh + v_cnt_ptr = v_cnt + batch_idx * stride_nz + head_idx * stride_nh + block_mask_ptr = block_mask + batch_idx * stride_bz + head_idx * stride_bh + block_idx_q_local * stride_bm + bar_idx_ptr = bar_idx + batch_idx * stride_iz + head_idx * stride_ih + block_idx_q_local * stride_im + bar_pos_ptr = bar_pos + batch_idx * stride_iz + head_idx * stride_ih + block_idx_q_local * stride_im + bar_cnt_ptr = bar_cnt + batch_idx * stride_cz + head_idx * stride_ch + block_idx_q_local * stride_cm + + cnt_valid = 0 + cnt_all = 0 + v_off = 0 + v = tl.load(v_idx_ptr + cnt_all * stride_vk) + cnt_all += 1 + + tl.store(bar_cnt_ptr, cnt_valid) + bar_cnt_ptr += stride_ct + if block_idx_q_local == tl.num_programs(0) - 1: + tl.store(v_cnt_ptr, cnt_all - 1) + v_cnt_ptr += stride_nt + + for step in range(world_size): + for block_off_k in range(0, num_blocks, BLOCK_N): + block_idx_k_local = block_off_k + block_offs + block_idx_k_global = (block_off_k + block_offs) * world_size + step + mask_local = tl.load( + last_row_mask_ptr + (block_idx_k_global + shift) * stride_rn, + mask=(block_idx_k_global + shift < num_blocks_global), + other=0, + ) + tl.store( + block_mask_ptr + block_idx_k_local * stride_bn, + mask_local, + mask=(block_idx_k_local < num_blocks), + ) + block_left = v_off + block_idx_k_local * granularity + block_right = block_left + granularity + max_blocks = block_idx_q_local + 1 if step <= rank else block_idx_q_local + v_max = v_off + min(block_off_k + BLOCK_N, max_blocks) * granularity + while v < v_max and cnt_all < max_v_size: + if tl.max(((v >= block_left) & (v < block_right)) & (~mask_local), 0): + tl.store(bar_idx_ptr + cnt_valid * stride_ik, v - v_off) + tl.store(bar_pos_ptr + cnt_valid * stride_ik, cnt_all - 1) + cnt_valid += 1 + v = tl.load(v_idx_ptr + cnt_all * stride_vk) + cnt_all += 1 + block_mask_ptr += stride_bt + tl.store(bar_cnt_ptr, cnt_valid) + bar_cnt_ptr += stride_ct + v_off += num_tokens_local + if block_idx_q_local == tl.num_programs(0) - 1: + tl.store(v_cnt_ptr, cnt_all - 1) + v_cnt_ptr += stride_nt + + +@triton.jit +def _zigzag_convert_indices_kernel( + last_row_mask, v_idx, v_cnt, + block_mask, bar_idx, bar_pos, bar_cnt, + stride_rz, stride_rh, stride_rn, + stride_vz, stride_vh, stride_vk, + stride_nz, stride_nh, stride_nt, + stride_bt, stride_bz, stride_bh, stride_bm, stride_bn, + stride_iz, stride_ih, 
stride_im, stride_ik, + stride_cz, stride_ch, stride_cm, stride_ct, + max_v_size, num_blocks, granularity, world_size, rank, + BLOCK_N: tl.constexpr, +): + batch_idx = tl.program_id(2) + head_idx = tl.program_id(1) + block_idx_q_local = tl.program_id(0) + + if rank < world_size // 2: + revert_rank = rank * 2 + else: + revert_rank = (world_size - 1 - rank) * 2 + 1 + if block_idx_q_local < num_blocks // 2: + block_idx_q_global = revert_rank * (num_blocks // 2) + block_idx_q_local + else: + block_idx_q_global = (world_size * 2 - 1 - revert_rank) * (num_blocks // 2) + block_idx_q_local - (num_blocks // 2) + + num_blocks_global = world_size * num_blocks + shift = num_blocks_global - 1 - block_idx_q_global + + block_offs = tl.arange(0, BLOCK_N) + + last_row_mask_ptr = last_row_mask + batch_idx * stride_rz + head_idx * stride_rh + v_idx_ptr = v_idx + batch_idx * stride_vz + head_idx * stride_vh + v_cnt_ptr = v_cnt + batch_idx * stride_nz + head_idx * stride_nh + block_mask_ptr = block_mask + batch_idx * stride_bz + head_idx * stride_bh + block_idx_q_local * stride_bm + bar_idx_ptr = bar_idx + batch_idx * stride_iz + head_idx * stride_ih + block_idx_q_local * stride_im + bar_pos_ptr = bar_pos + batch_idx * stride_iz + head_idx * stride_ih + block_idx_q_local * stride_im + bar_cnt_ptr = bar_cnt + batch_idx * stride_cz + head_idx * stride_ch + block_idx_q_local * stride_cm + + cnt_valid = 0 + cnt_all = 0 + v = tl.load(v_idx_ptr + cnt_all * stride_vk) + cnt_all += 1 + + tl.store(bar_cnt_ptr, cnt_valid) + bar_cnt_ptr += stride_ct + if block_idx_q_local == tl.num_programs(0) - 1: + tl.store(v_cnt_ptr, cnt_all - 1) + v_cnt_ptr += stride_nt + + for step in range(world_size): + v_off = step * num_blocks * granularity + v_end = v_off + num_blocks * granularity + for block_off_k in range(0, num_blocks, BLOCK_N): + block_idx_k_local = block_off_k + block_offs + # assert BLOCK_N <= num_blocks // 2 + if block_off_k < num_blocks // 2: + v_off_global = step * (num_blocks // 2) * granularity + block_idx_k_global = step * (num_blocks // 2) + block_idx_k_local + else: + v_off_global = (world_size * 2 - 2 - step) * (num_blocks // 2) * granularity + block_idx_k_global = (world_size * 2 - 1 - step) * (num_blocks // 2) + block_idx_k_local - (num_blocks // 2) + mask_local = tl.load( + last_row_mask_ptr + (block_idx_k_global + shift) * stride_rn, + mask=(block_idx_k_global + shift < num_blocks_global), + other=0, + ) + tl.store( + block_mask_ptr + block_idx_k_local * stride_bn, + mask_local, + mask=(block_idx_k_local < num_blocks), + ) + # block_left = block_idx_k_global * granularity - v_off_global + v_off + # block_right = block_left + granularity + block_left = v_off + block_idx_k_local * granularity + block_right = block_left + granularity + v_max = (block_idx_q_global + 1) * granularity - v_off_global + v_off + while v < v_end and cnt_all <= max_v_size: + if v < v_max: + if tl.max(((v >= block_left) & (v < block_right)) & (~mask_local), 0): + tl.store(bar_idx_ptr + cnt_valid * stride_ik, v - v_off) + tl.store(bar_pos_ptr + cnt_valid * stride_ik, cnt_all - 1) + cnt_valid += 1 + v = tl.load(v_idx_ptr + cnt_all * stride_vk) + cnt_all += 1 + block_mask_ptr += stride_bt + tl.store(bar_cnt_ptr, cnt_valid) + bar_cnt_ptr += stride_ct + if block_idx_q_local == tl.num_programs(0) - 1: + tl.store(v_cnt_ptr, cnt_all - 1) + v_cnt_ptr += stride_nt + + +def convert_indices( + v_idx: torch.Tensor, # [batch_size, num_heads, max_v_size] + s_idx: torch.Tensor, # [batch_size, num_heads, max_s_size] + world_size: int, + rank: int, + 
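    # The remaining arguments describe the per-rank block grid: `num_blocks` is the local number of
    # `granularity`-sized blocks on this rank (callers pass triton.cdiv(num_tokens_local, granularity)),
    # `num_tokens` defaults to the global token count world_size * num_blocks * granularity when left
    # as None, and the stripe_/zigzag_transform flags state which sequence-parallel layout the
    # vertical/slash indices were computed under, selecting the striped or zigzag conversion kernel.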
num_blocks: int, + granularity: int, + num_tokens: int = None, + stripe_transform: bool = False, + zigzag_transform: bool = False, +): + num_blocks_global = world_size * num_blocks + if num_tokens is None: + # Note that for each invokation of `convert_indices`, `num_tokens` is None and becomes the **global number of tokens** + num_tokens = num_blocks_global * granularity + batch_size, num_heads, max_v_size = v_idx.shape + batch_size, num_heads, max_s_size = s_idx.shape + last_row_mask = torch.zeros((batch_size, num_heads, num_blocks_global), dtype=torch.bool, device=s_idx.device) + + BLOCK_N, BLOCK_K = 128, 128 + assert max_s_size <= BLOCK_K * BLOCK_K, f"max_s_size={max_s_size} > BLOCK_K * BLOCK_K={BLOCK_K * BLOCK_K}" + _calc_block_mask_kernel[(triton.cdiv(num_blocks_global, BLOCK_N), num_heads, batch_size)]( + s_idx, last_row_mask, + s_idx.stride(0), s_idx.stride(1), s_idx.stride(2), + last_row_mask.stride(0), last_row_mask.stride(1), last_row_mask.stride(2), + max_s_size, num_tokens, granularity, + BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, + num_warps=4, num_stages=2, + ) + + block_mask = torch.zeros((world_size, batch_size, num_heads, num_blocks, num_blocks), dtype=torch.bool, device=v_idx.device) + bar_idx = torch.zeros((batch_size, num_heads, num_blocks, max_v_size), dtype=torch.int32, device=v_idx.device) + bar_cnt = torch.empty((batch_size, num_heads, num_blocks, world_size + 1), dtype=torch.int32, device=v_idx.device) + v_cnt = torch.empty((batch_size, num_heads, world_size + 1), dtype=torch.int32, device=v_idx.device) + bar_pos = torch.zeros_like(bar_idx) + if zigzag_transform: + convert_indices_kernel = _zigzag_convert_indices_kernel + assert num_blocks % 2 == 0 + BLOCK_N = max(num_blocks // 2, 128) + else: + convert_indices_kernel = _striped_convert_indices_kernel + BLOCK_N = 128 + convert_indices_kernel[(num_blocks, num_heads, batch_size)]( + last_row_mask, v_idx, v_cnt, block_mask, bar_idx, bar_pos, bar_cnt, + last_row_mask.stride(0), last_row_mask.stride(1), last_row_mask.stride(2), + v_idx.stride(0), v_idx.stride(1), v_idx.stride(2), + v_cnt.stride(0), v_cnt.stride(1), v_cnt.stride(2), + block_mask.stride(0), block_mask.stride(1), block_mask.stride(2), block_mask.stride(3), block_mask.stride(4), + bar_idx.stride(0), bar_idx.stride(1), bar_idx.stride(2), bar_idx.stride(3), + bar_cnt.stride(0), bar_cnt.stride(1), bar_cnt.stride(2), bar_cnt.stride(3), + max_v_size, num_blocks, granularity, world_size, rank, BLOCK_N=BLOCK_N, + num_warps=1, num_stages=1, + ) + + return block_mask, bar_idx, bar_cnt, bar_pos, v_cnt + + +def _torch_convert_indices( + v_idx: torch.Tensor, # [batch_size, num_heads, max_v_size] + s_idx: torch.Tensor, # [batch_size, num_heads, max_s_size] + world_size: int, + rank: int, + num_blocks: int, + granularity: int, +): + batch_size, num_heads, max_v_size = v_idx.shape + num_tokens = world_size * num_blocks * granularity + block_mask = torch.zeros((world_size, batch_size, num_heads, num_blocks, num_blocks), dtype=torch.bool, device=v_idx.device) + bar_idx = torch.zeros((batch_size, num_heads, num_blocks, max_v_size), dtype=torch.int32, device=v_idx.device) + bar_cnt = torch.zeros((batch_size, num_heads, num_blocks, world_size + 1), dtype=torch.int32, device=v_idx.device) + for batch_idx in range(batch_size): + for head_idx in range(num_heads): + for block_idx_q in range(num_blocks): + block_idx_q_global = block_idx_q * world_size + rank + cnt_all, cnt_valid = 0, 0 + for step in range(world_size): + for block_idx_k in range(block_idx_q + 1): + block_idx_k_global = 
block_idx_k * world_size + step + s_min = max((block_idx_q_global - block_idx_k_global - 1) * granularity, 0) + s_max = (block_idx_q_global - block_idx_k_global + 1) * granularity + flag = torch.any((s_idx[batch_idx, head_idx] > s_min) & (s_idx[batch_idx, head_idx] < s_max)) + block_mask[step, batch_idx, head_idx, block_idx_q, block_idx_k] = flag + v_min = (step * num_blocks + block_idx_k) * granularity + max_blocks = block_idx_q + 1 if step <= rank else block_idx_q + v_max = (step * num_blocks + min(block_idx_k + 1, max_blocks)) * granularity + while cnt_all < max_v_size and v_idx[batch_idx, head_idx, cnt_all] < v_min: + cnt_all += 1 + while cnt_all < max_v_size and v_idx[batch_idx, head_idx, cnt_all] < v_max: + if not flag: + bar_idx[batch_idx, head_idx, block_idx_q, cnt_valid] = \ + v_idx[batch_idx, head_idx, cnt_all] - step * num_blocks * granularity + cnt_valid += 1 + cnt_all += 1 + bar_cnt[batch_idx, head_idx, block_idx_q, step + 1] = cnt_valid + return block_mask, bar_idx, bar_cnt + + + +def sum_all_diagonal_matrix(mat: torch.Tensor): + b, h, m, n = mat.shape + + # Pads the matrix on left and right (on the last dimension) + mat_padded = torch.nn.functional.pad(mat, (m, m), "constant", 0.) # shape: [b, h, m, 2 * m + n] + # Change the strides + mat_strided = mat_padded.as_strided((b, h, m, m + n), (m * (2 * m + n) * h, m * (2 * m + n), 2 * m + n + 1, 1)) + # Sums the resulting matrix's columns + sum_diags = torch.sum(mat_strided, 2) # shape: [b, h, m + n] + return sum_diags[:, :, 1:].contiguous() + +def calc_index( + q: torch.Tensor, + k: torch.Tensor, + v_size: List[int], + s_size: List[int], + last_q_size: int = 64, + sink_tokens: int = 30, + sliding_window: int = 100, + group: dist.group = None, + stripe_transform: bool = False, + zigzag_transform: bool = False, + granularity: int = 128, +): + # TODO: adapt naturely striped inputs + # TODO: flex-prefill (top-P) + # TODO: reduce bubble + # TODO: support total_num_tokens % world_size != 0 + batch_size, num_tokens, num_qo_heads, head_dim = q.shape + num_kv_heads = k.shape[2] + rank = dist.get_rank(group) + world_size = dist.get_world_size(group) + + if all([type(x) is list for x in v_size]) and all([type(x) is list for x in s_size]): + flex_prefill = True + v_p = [x[0] for x in v_size] + v_size = [x[1] for x in v_size] + s_p = [x[0] for x in s_size] + s_size = [x[1] for x in s_size] + else: + flex_prefill = False + assert all([type(x) is int for x in v_size]) and all([type(x) is int for x in s_size]) + + max_v_size = min(triton.cdiv(max(v_size), 128), num_tokens // 128) * 128 + max_s_size = min(triton.cdiv(max(s_size), 128), num_tokens // 128) * 128 + + last_rank = world_size - 1 + if rank == last_rank: + last_q = q[:, -last_q_size:, :, :].detach().clone().reshape((batch_size, last_q_size, num_kv_heads, -1, head_dim)) + else: + last_q = torch.zeros((batch_size, last_q_size, num_kv_heads, num_qo_heads // num_kv_heads, head_dim), device=q.device, dtype=q.dtype) + + dist.broadcast(last_q, src=last_rank, group=group, async_op=False) + + qk = torch.einsum('bmghd, bngd -> bghmn', last_q, k) * (k.shape[-1] ** -0.5) + qk = qk.reshape((batch_size, num_qo_heads, last_q_size, num_tokens)) + + if rank == last_rank: + # Causal Mask, requires num_tokens // world_size >= last_q + arange = torch.arange(last_q_size, device=k.device) + mask = arange[None, None, :, None] >= arange[None, None, None, :] + qk[:, :, :, -last_q_size:] = torch.where(mask, qk[:, :, :, -last_q_size:], -torch.inf) + if flex_prefill: # qk = torch.softmax(qk, dim=-1) / last_q_size 
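        # Sequence-parallel note: each rank only holds a shard of K, so the row softmax that the
        # commented-out single-device line would compute has to be assembled from per-rank partial
        # statistics, i.e. softmax(x)_i = exp(x_i - m) / sum_j exp(x_j - m) with m taken over the
        # *global* row. The block below all-gathers the local maxima to obtain m, exponentiates the
        # shifted local scores, all-gathers the local denominators, and divides by
        # (global_sum * last_q_size), so every rank ends up with its slice of the globally
        # normalized, length-averaged attention scores.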
+ qk_max = torch.max(qk, dim=-1, keepdim=True).values + qk_max_list = [torch.empty_like(qk_max) for _ in range(world_size)] + dist.all_gather(qk_max_list, qk_max, group=group, async_op=False) + qk_max = torch.max(torch.stack(qk_max_list), dim=0).values + qk = torch.exp(qk - qk_max) + qk_sum = torch.sum(qk, dim=-1, keepdim=True) + qk_sum_list = [torch.empty_like(qk_sum) for _ in range(world_size)] + dist.all_gather(qk_sum_list, qk_sum, group=group, async_op=False) + qk_sum = torch.sum(torch.stack(qk_sum_list), dim=0) + qk /= (qk_sum * last_q_size) + + v_gather_rank = 0 + vertical = qk.sum(-2, keepdim=False) # [B, H, N_LOCAL] + if rank == 0 and not flex_prefill: + vertical[..., :sink_tokens] = torch.inf + if rank == v_gather_rank: + gathered_vertical = [torch.empty_like(vertical) for _ in range(world_size)] + else: + gathered_vertical = None + dist.gather(vertical, gathered_vertical, dst=v_gather_rank, group=group, async_op=False) + + if rank == v_gather_rank: + vertical: torch.Tensor = torch.cat(gathered_vertical, dim=-1) + if stripe_transform: + vertical = vertical.reshape((batch_size, num_qo_heads, -1, world_size, granularity)) + vertical = vertical.swapaxes(2, 3) + vertical = vertical.reshape((batch_size, num_qo_heads, -1)) + elif zigzag_transform: + vertical = vertical.reshape((batch_size, num_qo_heads, 2, world_size, -1)) + chunks = [] + for step in range(world_size): + chunks.append(vertical[:, :, 0, step]) + chunks.append(vertical[:, :, 1, world_size - 1 - step]) + vertical = torch.concat(chunks, dim=2).reshape((batch_size, num_qo_heads, -1)) + + v_topk = torch.topk(vertical, max_v_size, -1, sorted=True) + v_indices = v_topk.indices.to(torch.int32) + if flex_prefill: + v_cumsum = v_topk.values.cumsum_(dim=-1) + v_size = (v_cumsum < torch.tensor(v_p, device=k.device)[None, :, None]).sum(dim=-1, keepdim=True) + else: + v_size = torch.tensor(v_size, device=k.device)[None, :, None] + v_arange = torch.arange(max_v_size, device=k.device) + v_indices.masked_fill_(v_arange[None, None, :] >= v_size, num_tokens * world_size) + v_indices = v_indices.sort(dim=-1, descending=False).values + else: + v_indices = torch.empty((batch_size, num_qo_heads, max_v_size), dtype=torch.int32, device=k.device) + dist.broadcast(v_indices, src=v_gather_rank, group=group, async_op=False) # async + + s_gather_rank = 0 + slash = sum_all_diagonal_matrix(qk) # shape: [B, H, N_LOCAL + LAST_Q_SIZE - 1] + if rank == world_size - 1 and not flex_prefill: + # -> index starting from the left bottom corner to right upper corner + # (sliding_window) from -(last_q_size-1) is the sliding window close to the main diagonal + slash[..., -(last_q_size - 1 + sliding_window):] = torch.inf + + + if rank == s_gather_rank: + gathered_slash = [torch.empty_like(slash) for _ in range(world_size)] + else: + gathered_slash = None + dist.gather(slash, gathered_slash, dst=s_gather_rank, group=group, async_op=False) + + if rank == s_gather_rank: + slash = gathered_slash[0] + for next_slash in gathered_slash[1:]: + slash[..., -last_q_size + 1:] += next_slash[..., :last_q_size - 1] + slash = torch.cat((slash, next_slash[..., last_q_size - 1:]), dim=-1) + + # slash presents the sum of attention from 0-th to (num_tokens_global - last_q_size - 1), where 0 represents the diagonal at bottom left corner + slash = slash[..., :-last_q_size + 1] + s_topk = torch.topk(slash, max_s_size, -1, sorted=True) + + # s_indices contain indices starting from the right upper corner to left bottom corner + s_indices = (num_tokens * world_size - 1) - 
s_topk.indices.to(torch.int32) + if flex_prefill: + s_cumsum = s_topk.values.cumsum_(dim=-1) + s_size = (s_cumsum < torch.tensor(s_p, device=k.device)[None, :, None]).sum(dim=-1, keepdim=True) + else: + s_size = torch.tensor(s_size, device=k.device)[None, :, None] + s_arange = torch.arange(max_s_size, device=k.device) + s_indices.masked_fill_(s_arange[None, None, :] >= s_size, -1) + s_indices = s_indices.sort(dim=-1, descending=True).values + else: + s_indices = torch.empty((batch_size, num_qo_heads, max_s_size), dtype=torch.int32, device=k.device) + dist.broadcast(s_indices, src=s_gather_rank, group=group, async_op=False) + + return v_indices.to(torch.int32), s_indices.to(torch.int32) + +def calc_index_local( + q: torch.Tensor, + k: torch.Tensor, + v_size: List[int], + s_size: List[int], + last_q_size: int = 64, + sink_tokens: int = 30, + sliding_window: int = 100, + group: dist.group = None, + stripe_transform: bool = False, + zigzag_transform: bool = False, + granularity: int = 128, +): + batch_size, num_tokens, num_qo_heads, head_dim = q.shape + num_kv_heads = k.shape[2] + + if all([type(x) is list for x in v_size]) and all([type(x) is list for x in s_size]): + flex_prefill = True + v_p = [x[0] for x in v_size] + v_size = [x[1] for x in v_size] + s_p = [x[0] for x in s_size] + s_size = [x[1] for x in s_size] + else: + flex_prefill = False + assert all([type(x) is int for x in v_size]) and all([type(x) is int for x in s_size]) + + qk = torch.einsum( + f'bmghd, bngd -> bghmn', + q[:, -last_q_size:, :, :].reshape((batch_size, last_q_size, num_kv_heads, -1, head_dim)), + k, + ).reshape((batch_size, num_qo_heads, last_q_size, num_tokens)) * (head_dim ** -0.5) + + arange = torch.arange(last_q_size, device=k.device) + mask = arange[None, None, :, None] >= arange[None, None, None, :] + qk[:, :, :, -last_q_size:] = torch.where(mask, qk[:, :, :, -last_q_size:], -torch.inf) + if flex_prefill: + qk = torch.softmax(qk, dim=-1) / last_q_size + + max_v_size = min(max(v_size), num_tokens) + max_v_size = triton.cdiv(max_v_size, 128) * 128 + vertical = qk.sum(-2, keepdim=False) + if not flex_prefill: + vertical[..., :sink_tokens] = torch.inf + if stripe_transform: + vertical = vertical.reshape((batch_size, num_qo_heads, -1, dist.get_world_size(group), granularity)) + vertical = vertical.swapaxes(2, 3) + vertical = vertical.reshape((batch_size, num_qo_heads, -1)) + elif zigzag_transform: + vertical = vertical.reshape((batch_size, num_qo_heads, 2, dist.get_world_size(group), -1)) + chunks = [] + for step in range(dist.get_world_size(group)): + chunks.append(vertical[:, :, 0, step]) + chunks.append(vertical[:, :, 1, dist.get_world_size(group) - 1 - step]) + vertical = torch.concat(chunks, dim=2).reshape((batch_size, num_qo_heads, -1)) + v_topk = torch.topk(vertical, max_v_size, -1, sorted=True) + v_indices = v_topk.indices + if flex_prefill: + v_cumsum = v_topk.values.cumsum_(dim=-1) + v_size = (v_cumsum < torch.tensor(v_p, device=k.device)[None, :, None]).sum(dim=-1, keepdim=True) + else: + v_size = torch.tensor(v_size, device=k.device)[None, :, None] + + max_s_size = min(max(s_size), num_tokens) + max_s_size = triton.cdiv(max_s_size, 128) * 128 + slash = sum_all_diagonal_matrix(qk)[..., :-last_q_size + 1] + if not flex_prefill: + slash[..., -sliding_window:] = torch.inf + s_topk = torch.topk(slash, max_s_size, -1, sorted=True) + s_indices = (num_tokens - 1) - s_topk.indices + if flex_prefill: + s_cumsum = s_topk.values.cumsum_(dim=-1) + s_size = (s_cumsum < torch.tensor(s_p, device=k.device)[None, :, 
None]).sum(dim=-1, keepdim=True) + else: + s_size = torch.tensor(s_size, device=k.device)[None, :, None] + + v_arange = torch.arange(max_v_size, device=k.device) + v_idx = v_indices.to(torch.int32).reshape((batch_size, num_qo_heads, -1)) + v_idx.masked_fill_(v_arange[None, None, :] >= v_size, 2147483647) + v_idx = v_idx.sort(dim=-1, descending=False).values + + s_arange = torch.arange(max_s_size, device=k.device) + s_idx = s_indices.to(torch.int32).reshape((batch_size, num_qo_heads, -1)) + s_idx.masked_fill_(s_arange[None, None, :] >= s_size, -1) + s_idx = s_idx.sort(dim=-1, descending=True).values + + return v_idx, s_idx + +def build_index_local( + q: torch.Tensor, + k: torch.Tensor, + v_size: List[int], + s_size: List[int], + num_tokens: int, + granularity: int, + world_size: int = 1, + rank: int = 0, +): + if type(v_size) is list: + assert len(v_size) == q.shape[2] + assert len(s_size) == q.shape[2] + v_idx, s_idx = calc_index_local(q, k, v_size, s_size, last_q_size=64) + else: + v_idx, s_idx = v_size, s_size + + num_blocks = triton.cdiv(num_tokens, granularity) + block_mask, bar_idx, bar_cnt, _, _ = convert_indices(v_idx, s_idx, world_size, rank, num_blocks, granularity) + block_mask = block_mask[rank] + return block_mask, bar_idx, bar_cnt + +def build_index( + q: torch.Tensor, + k: torch.Tensor, + v_size: List[int], + s_size: List[int], + num_tokens: int, # num_tokens_local + granularity: int, + stripe_transform: bool = True, + zigzag_transform: bool = False, + group: dist.group = None, +): + """ + Input: (all inputs correspond to the local part for each rank) + q: shape [batch_size, num_tokens_local, num_qo_heads, head_dim] + k: shape [batch_size, num_tokens_local, num_kv_heads, head_dim] + v_size: shape [num_qo_heads] + s_size: shape [num_qo_heads] + num_tokens: number of tokens in the local part of QK + Returns: + block_mask: shape [world_size, batch_size, num_heads, num_blocks, num_blocks] + bar_idx: shape [batch_size, num_heads, num_blocks, max_v_size] + bar_cnt: shape [batch_size, num_heads, num_blocks, world_size + 1], each entry is the cumulative number of selected bars corresponding a rank + """ + rank = dist.get_rank(group) + world_size = dist.get_world_size(group) + + if isinstance(v_size, list): + v_idx, s_idx = calc_index( + q, k, v_size, s_size, last_q_size=64, group=group, + stripe_transform=stripe_transform, + zigzag_transform=zigzag_transform, + granularity=granularity + ) + else: + v_idx, s_idx = v_size, s_size + + num_blocks = triton.cdiv(num_tokens, granularity) # num_blocks_local + + # Note that block_mask is a 5D tensor with shape [world_size, batch_size, num_heads, num_blocks, num_blocks] + # with each block_mask[i] is to a mask corresponding the num_tokens_local x num_tokens_local matmul for each step + block_mask, bar_idx, bar_cnt, bar_pos, v_cnt = convert_indices( + v_idx, s_idx, world_size, rank, num_blocks, granularity, + stripe_transform=stripe_transform, + zigzag_transform=zigzag_transform, + ) + return block_mask, bar_idx, bar_cnt, bar_pos, v_idx, v_cnt + +def convert_blockmask( + blockmask: torch.Tensor, # [world_size, batch_size, num_heads, num_blocks, num_blocks] + block_size_M: int, + block_size_N: int, +): + ratio = block_size_M // block_size_N + original_shape = blockmask.shape + blockmask = blockmask.to(dtype=torch.uint8) + blockmask = blockmask.unsqueeze(-1).tile([1] * len(original_shape) + [ratio]).reshape((*original_shape[:-1], -1)) + + # now block_mask is [world_size, batch_size, num_heads, num_blocks, num_blocks * ratio] + nonzero_val, 
nonzero_idx = blockmask.sort(dim=-1, stable=True, descending=True) + + nonzero_rowcnt = blockmask.sum(dim=-1, dtype=torch.int32) + return nonzero_idx.contiguous().to(dtype=torch.int32), nonzero_rowcnt.contiguous() + diff --git a/minference/ops/op_utils/xattn_utils.py b/minference/ops/op_utils/xattn_utils.py new file mode 100644 index 0000000..760e927 --- /dev/null +++ b/minference/ops/op_utils/xattn_utils.py @@ -0,0 +1,590 @@ +import torch +import triton +import triton.language as tl +import torch.distributed as dist + +LN2 = 1 / 1.4426950408889634 +def create_causal_mask(batch_size, head_num, block_size, block_num, divide_block_num): + """ + Creates a causal attention mask used in transformer-based models. + + Parameters: + - batch_size (int): The number of sequences in the batch. + - head_num (int): The number of attention heads. + - block_size (int): The size of each block in the sequence. + - block_num (int): The total number of blocks in the sequence. + - divide_block_num (int): The block index at which causality is applied. + + Returns: + - torch.Tensor: A mask tensor of shape (batch_size, head_num, block_size, total_size) + where total_size = block_size * block_num. The mask enforces causal attention by + setting certain positions to `-inf` to prevent information leakage from future tokens. + """ + divide_block_num += 1 + if divide_block_num < 1 or divide_block_num > block_num: + raise ValueError( + f"divide_block_num ({divide_block_num}) must be between 1 and block_num ({block_num})." + ) + + total_size = block_size * block_num + device = "cuda" + mask = torch.zeros(block_size, total_size, device=device) + if divide_block_num < block_num: + mask[:, divide_block_num * block_size :] = float("-inf") + + if divide_block_num - 1 < block_num: + start_col = (divide_block_num - 1) * block_size + end_col = start_col + block_size + upper_tri_mask = torch.triu( + torch.full((block_size, block_size), float("-inf"), device=device), + diagonal=1, + ) + mask[:, start_col:end_col] = upper_tri_mask + + mask = mask.unsqueeze(0).unsqueeze(0) + mask = mask.expand(batch_size, head_num, block_size, total_size) + return mask + +def find_blocks_chunked( + input_tensor: torch.Tensor, # (batch_size, num_heads, num_block_q, num_block_k) + current_index, # + threshold, num_to_choose, decoding: bool, mode: str = "both", causal=True +): + """ + Finds and selects relevant blocks of attention for transformer-based models based on a + threshold or a predefined number of blocks. + + Parameters: + - input_tensor (torch.Tensor): The input tensor of shape (batch_size, head_num, num_block_q, num_block_k). + - current_index (int): The current index in the sequence processing. + - threshold (float or None): A threshold value used to determine the minimum attention weight sum. + - num_to_choose (int or None): The number of blocks to be selected, ensuring sufficient information retrieval. + - decoding (bool): If True, operates in decoding mode; otherwise, it's in encoding mode. + - mode (str): Defines the processing mode, either 'both', 'prefill', or 'decode'. + - causal (bool): If True, applies causal masking to prevent future information leakage. + + Returns: + - torch.Tensor: A boolean mask of shape (batch_size, head_num, num_block_q, num_block_k), + indicating which blocks should be attended to. 
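    Example (illustrative numbers only): for one query block with pooled weights
    [0.30 (sink), 0.10, 0.05, 0.35, 0.20 (current block)] and threshold=0.8, the sink and the
    current (diagonal) block are always kept, contributing 0.50; remaining blocks are then added
    in descending order of weight until the running total reaches 0.8 * 1.00, so the 0.35 block
    is added (total 0.85) and the mask for this row becomes [True, False, False, True, True].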
+ """ + assert threshold is None or num_to_choose is None + batch_size, head_num, num_block_q, num_block_k = input_tensor.shape + input_tensor = input_tensor.to(float) + + total_sum = input_tensor.sum(dim=-1, keepdim=True) + if isinstance(threshold, torch.Tensor): + threshold = threshold.to(float) + required_sum = total_sum * threshold.unsqueeze(0).unsqueeze(-1).unsqueeze( + -1 + ).expand((batch_size, head_num, num_block_q, 1)).to(input_tensor.device) + else: + required_sum = total_sum * threshold + + + mask = torch.zeros_like(input_tensor, dtype=torch.bool) + mask[:, :, :, 0] = 1 + mask[:, :, :, current_index : current_index + num_block_q] = ( + torch.eye(num_block_q, device=mask.device) + .unsqueeze(0) + .unsqueeze(0) + .expand(1, head_num, num_block_q, num_block_q) + ) + # Note that other_values only contains the values of the current block + # (the sink blocks and diagonal are filled with 0) + other_values = input_tensor.masked_fill(mask, 0) + + + # Get sorted values + sorted_values, _ = torch.sort(other_values, dim=-1, descending=True) + sorted_values = sorted_values.to(input_tensor.device) + sorted_values = torch.cat( + [ + torch.zeros( + (batch_size, head_num, num_block_q, 1), device=input_tensor.device + ), + torch.where(mask, input_tensor, 0).sum(dim=-1, keepdim=True), # shape: (batch_size, head_num, num_block_q, 1) + sorted_values[:, :, :, :-2], # :-2 excludes the first and diagonal (which are marked 0 in other_values) + ], + dim=-1, + ) + + # Get sorted indices + # index will select the already-masked (sink and diagonal) at the beginning + _, index = torch.sort( + torch.where(mask, 100000 * (1 + input_tensor), input_tensor), + dim=-1, + descending=True, + ) + + # [batch_size, head_num, num_block_q, num_block_k] + cumulative_sum_without_self = torch.cat( + [ + torch.zeros( + (batch_size, head_num, num_block_q, 1), device=input_tensor.device + ), + sorted_values[:, :, :, 0:-1], + ], + dim=-1, + ).cumsum(dim=-1) + + # Mask for indices where cumulative sum is below the required threshold. 
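    # `cumulative_sum_without_self` is an exclusive prefix sum over [0, forced mass (sink + diagonal),
    # sorted remaining block weights], aligned with `index` (forced blocks first, then the rest in
    # descending order). A position is kept iff the mass accumulated *before* it is still below
    # required_sum, so the sink/diagonal blocks are always kept and remaining blocks are added in
    # descending order until the running total (including the forced mass) reaches threshold * row_sum.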
+ index_mask = cumulative_sum_without_self < required_sum + index = torch.where(index_mask, index, 0) + + mask = mask.view(batch_size, head_num * num_block_q, num_block_k) + index = index.view(batch_size, head_num * num_block_q, num_block_k) + mask[:, torch.arange(mask.shape[1], device=mask.device).unsqueeze(dim=-1), index] = True + mask = mask.view(batch_size, head_num, num_block_q, num_block_k) + + + assert bool((torch.where(mask,input_tensor,0).sum(dim=-1, keepdim=True) >= required_sum * 0.99).all()), \ + f"mask sum {torch.where(mask,input_tensor,0).sum(dim=-1, keepdim=True)} < required_sum {required_sum}" + + try: + if causal: + assert (~mask[:, :, :, current_index + num_block_q :]).all() + except: + mask[:, :, :, current_index + num_block_q :] = False + + if causal: + if decoding: + assert mask[:, :, :, 0].all() and mask[:, :, :, -1].all() + else: + lambda_mask = torch.zeros_like(input_tensor,dtype=bool,device=input_tensor.device) + lambda_mask[:,:,:,0] = 1 + lambda_mask[:,:,:,current_index:current_index+num_block_q] = torch.eye(num_block_q, device=lambda_mask.device).unsqueeze(0).unsqueeze(0).expand(1,head_num,num_block_q,num_block_q) + assert(torch.where(lambda_mask,mask,True).all()) + + return mask + + +def shuffle_zigzag_masks( + block_masks: torch.Tensor, # [batch_size, num_qo_heads, num_blocks_local, num_blocks] + process_group: dist.ProcessGroup = None + ): + dim = len(block_masks.shape) - 1 + if not block_masks.is_contiguous(): + block_masks = block_masks.contiguous() + + # We must use outplace, otherwise it will raise error at backward due to inplace operations. + # We can not change to_send directly and create a new tensor to store the result. + to_send_f = torch.zeros_like(block_masks) + + # assume the input sequence length is 8, and computation runs on 4 GPUs + # the seq is represented as [0 1 2 3 4 5 6 7], world size is 4 + # the input status before `shuffle_zigzag_input` is + # - gpu A: [0 1] + # - gpu B: [2 3] + # - gpu C: [4 5] + # - gpu D: [6 7] + # the value of `to_send_slice` is + # - gpu A: [1] + # - gpu B: [3] + # - gpu C: [5] + # - gpu D: [7] + block_seq_len = block_masks.shape[dim] // 2 + left_slicer = [slice(None)] * dim + [slice(None, block_seq_len)] + right_slicer = [slice(None)] * dim + [slice(block_seq_len, None)] + to_send_slice = block_masks[right_slicer].contiguous() + + rank = dist.get_rank(process_group) + world_size = dist.get_world_size(process_group) + + res = torch.zeros_like(to_send_slice) + + _ops = [] + offset = ((dist.get_rank() // world_size) * world_size) + # rank src_rank + # 0 3 + # 1 2 + # 2 1 + # 3 0 + src_rank = (world_size - rank - 1) % world_size + offset + send_op = dist.P2POp( + dist.isend, to_send_slice, src_rank, group=process_group + ) + recv_op = dist.P2POp( + dist.irecv, res, src_rank, group=process_group) + + _ops.append(send_op) + _ops.append(recv_op) + + response = dist.batch_isend_irecv(_ops) + for resp in response: + resp.wait() + + if rank >= world_size // 2: # D: 6 7, -> 1 6 + to_send_f[right_slicer] = block_masks[left_slicer] + to_send_f[left_slicer] = res + else: # A: 0 1, -> 0 7 + to_send_f[left_slicer] = block_masks[left_slicer] + to_send_f[right_slicer] = res + # after shuffle, the status of `to_send_f` + # GPU A: [0 7] + # GPU B: [2 5] + # GPU C: [3 4] + # GPU D: [1 6] + + return to_send_f + + + +@triton.jit +def softmax_fuse_block_sum_kernel_causal( + In, + Out, + scale, + input_stride_0, + input_stride_1, + input_stride_2, + output_stride_0, + output_stride_1, + output_stride_2, + real_q_len, + k_len, # we assume 
k_len is divisible by chunk size + chunk_start, + chunk_end, + segment_size: tl.constexpr, + block_size: tl.constexpr, +): + block_id = tl.program_id(0) + head_id = tl.program_id(1) + batch_id = tl.program_id(2) + + offs_q = tl.arange(0, block_size) + chunk_start + block_id * block_size + offs_k = tl.arange(0, segment_size) + + num_iters = k_len // segment_size + num_iters_before_causal = (chunk_start + (block_id + 1) * block_size - 1) // segment_size + + m_i = tl.zeros([block_size], dtype=tl.float32) - float("inf") + l_i = tl.zeros([block_size], dtype=tl.float32) + 1.0 + + input_ptr = In + batch_id * input_stride_0 + head_id * input_stride_1 + block_id * block_size * input_stride_2 + input_ptr = input_ptr + tl.arange(0, segment_size) + tl.arange(0, block_size)[:, None] * input_stride_2 + + output_ptr = Out + batch_id * output_stride_0 + head_id * output_stride_1 + block_id * output_stride_2 + output_ptr = output_ptr + tl.arange(0, segment_size // block_size) + + for iter in range(0, num_iters_before_causal): + X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale + m_local = tl.max(X, 1) + m_new = tl.maximum(m_i, m_local) + alpha = tl.math.exp2(m_i - m_new) + + X = X - m_new[:, None] + l_local = tl.sum(tl.math.exp2(X), 1) + l_i = l_i * alpha + l_local + + m_i = m_new + + for iter in range(num_iters_before_causal, num_iters_before_causal + 1): + X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale + mask = offs_q[:, None] >= (offs_k[None, :] + iter * segment_size) + X = tl.where(mask, X, -1.0e6) + m_local = tl.max(X, 1) + m_new = tl.maximum(m_i, m_local) + alpha = tl.math.exp2(m_i - m_new) + + X = X - m_new[:, None] + l_local = tl.sum(tl.math.exp2(X), 1) + l_i = l_i * alpha + l_local + + m_i = m_new + + l_i_inv = 1.0 / l_i + + sum_mask = offs_q[:, None] < real_q_len + + for iter in range(0, num_iters_before_causal): + X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale + X = tl.exp2(X - m_i[:, None]) * l_i_inv[:, None] + X = tl.where(sum_mask, X, 0) + X = tl.reshape(X, (block_size, segment_size // block_size, block_size)) + X = tl.sum(X, 2) + X = tl.sum(X, 0) + tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty)) + + for iter in range(num_iters_before_causal, num_iters_before_causal + 1): + X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale + mask = offs_q[:, None] >= (offs_k[None, :] + iter * segment_size) + X = tl.where(mask, X, -1.0e6) + X = tl.exp2(X - m_i[:, None]) * l_i_inv[:, None] + X = tl.where(sum_mask, X, 0) + X = tl.reshape(X, (block_size, segment_size // block_size, block_size)) + X = tl.sum(X, 2) + X = tl.sum(X, 0) + tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty)) + + for iter in range(num_iters_before_causal + 1, num_iters): + X = tl.zeros([segment_size // block_size], dtype=tl.float32) + tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty)) + + +@triton.jit +def softmax_fuse_block_sum_kernel_non_causal( + In, + Out, + scale, + input_stride_0, + input_stride_1, + input_stride_2, + output_stride_0, + output_stride_1, + output_stride_2, + real_q_len, + k_len, # we assume k_len is divisible by chunk size + chunk_start, + chunk_end, + segment_size: tl.constexpr, + block_size: tl.constexpr, +): + block_id = tl.program_id(0) + head_id = tl.program_id(1) + batch_id = tl.program_id(2) + + offs_q = tl.arange(0, block_size) + chunk_start + block_id * block_size + offs_k = tl.arange(0, segment_size) + + num_iters = k_len // 
segment_size + + m_i = tl.zeros([block_size], dtype=tl.float32) - float("inf") + l_i = tl.zeros([block_size], dtype=tl.float32) + 1.0 + + input_ptr = In + batch_id * input_stride_0 + head_id * input_stride_1 + block_id * block_size * input_stride_2 + input_ptr = input_ptr + tl.arange(0, segment_size) + tl.arange(0, block_size)[:, None] * input_stride_2 + + output_ptr = Out + batch_id * output_stride_0 + head_id * output_stride_1 + block_id * output_stride_2 + output_ptr = output_ptr + tl.arange(0, segment_size // block_size) + + for iter in range(0, num_iters): + X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale + m_local = tl.max(X, 1) + m_new = tl.maximum(m_i, m_local) + alpha = tl.math.exp2(m_i - m_new) + + X = X - m_new[:, None] + l_local = tl.sum(tl.math.exp2(X), 1) + l_i = l_i * alpha + l_local + + m_i = m_new + + l_i_inv = 1.0 / l_i + + sum_mask = offs_q[:, None] < real_q_len + + for iter in range(0, num_iters): + X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale + X = tl.exp2(X - m_i[:, None]) * l_i_inv[:, None] + X = tl.where(sum_mask, X, 0) + X = tl.reshape(X, (block_size, segment_size // block_size, block_size)) + X = tl.sum(X, 2) + X = tl.sum(X, 0) + tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty)) + +@triton.jit +def flat_group_gemm_kernel(Q, K, Out, + stride_qz, stride_qh, stride_qn, + stride_kz, stride_kh, stride_kn, + stride_oz, stride_oh, stride_on, + chunk_start, chunk_end, + H: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + ): + block_m = tl.program_id(0).to(tl.int64) + block_n = tl.program_id(1).to(tl.int64) + batch_id = tl.program_id(2).to(tl.int64) // H + head_id = tl.program_id(2).to(tl.int64) % H + + if chunk_start + (block_m + 1) * BLOCK_M <= block_n * BLOCK_N: + return + + Q_ptrs = Q + batch_id * stride_qz + head_id * stride_qh + block_m * BLOCK_M * stride_qn + K_ptrs = K + batch_id * stride_kz + head_id * stride_kh + block_n * BLOCK_N * stride_kn + + Q_ptrs = Q_ptrs + tl.arange(0, BLOCK_M)[:, None] * stride_qn + tl.arange(0, BLOCK_K)[None, :] + K_ptrs = K_ptrs + tl.arange(0, BLOCK_N)[None, :] * stride_kn + tl.arange(0, BLOCK_K)[:, None] + + num_iters = HEAD_DIM // BLOCK_K + o = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + + for iter in range(num_iters): + q = tl.load(Q_ptrs + iter * BLOCK_K) + k = tl.load(K_ptrs + iter * BLOCK_K) + o += tl.dot(q, k) + + O_ptrs = Out + batch_id * stride_oz + head_id * stride_oh + block_m * BLOCK_M * stride_on + block_n * BLOCK_N + O_ptrs = O_ptrs + tl.arange(0, BLOCK_M)[:, None] * stride_on + tl.arange(0, BLOCK_N)[None, :] + + tl.store(O_ptrs, o.to(Out.type.element_ty)) + +@triton.jit +def flat_group_gemm_fuse_reshape_kernel(Q, K, Out, + stride_qz, stride_qh, stride_qn, + stride_kz, stride_kh, stride_kn, + stride_oz, stride_oh, stride_on, + chunk_start, chunk_end, + H: tl.constexpr, + STRIDE: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + is_caual: tl.constexpr, + ): + block_m = tl.program_id(0).to(tl.int64) + block_n = tl.program_id(1).to(tl.int64) + batch_id = tl.program_id(2).to(tl.int64) // H + head_id = tl.program_id(2).to(tl.int64) % H + + if is_caual: + if chunk_start + (block_m + 1) * BLOCK_M <= block_n * BLOCK_N: + return + + Q_ptrs = Q + batch_id * stride_qz + head_id * stride_qh + block_m * BLOCK_M * STRIDE * stride_qn + K_ptrs = K + batch_id * stride_kz + head_id * stride_kh + block_n * BLOCK_N * STRIDE * stride_kn + + Q_ptrs = 
Q_ptrs + tl.arange(0, BLOCK_M)[:, None] * (stride_qn * STRIDE) + tl.arange(0, HEAD_DIM)[None, :] + stride_qn * (STRIDE - 1) + K_ptrs = K_ptrs + tl.arange(0, BLOCK_N)[None, :] * (stride_kn * STRIDE) + tl.arange(0, HEAD_DIM)[:, None] + + o = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + + for iter in range(STRIDE): + q = tl.load(Q_ptrs - iter * stride_qn) + k = tl.load(K_ptrs + iter * stride_kn) + o += tl.dot(q, k) + + O_ptrs = Out + batch_id * stride_oz + head_id * stride_oh + block_m * BLOCK_M * stride_on + block_n * BLOCK_N + O_ptrs = O_ptrs + tl.arange(0, BLOCK_M)[:, None] * stride_on + tl.arange(0, BLOCK_N)[None, :] + + tl.store(O_ptrs, o.to(Out.type.element_ty)) + + +def softmax_fuse_block_sum(attn_weights_slice, reshaped_block_size, segment_size, chunk_start, chunk_end, real_q_len, scale, is_causal=True): + batch_size, num_heads, q_len, k_len = attn_weights_slice.shape + assert q_len % reshaped_block_size == 0 + try: + assert k_len % segment_size == 0 + except: + assert False, f"xAttention error, k_len: {k_len}, segment size: {segment_size}" + assert segment_size % reshaped_block_size == 0 + assert attn_weights_slice.stride(-1) == 1 + + output = torch.empty((batch_size, num_heads, q_len // reshaped_block_size, k_len // reshaped_block_size), dtype=attn_weights_slice.dtype, device=attn_weights_slice.device) + + grid = (q_len // reshaped_block_size, num_heads, batch_size) + + if is_causal: + softmax_fuse_block_sum_kernel_causal[grid]( + attn_weights_slice, + output, + scale, + attn_weights_slice.stride(0), + attn_weights_slice.stride(1), + attn_weights_slice.stride(2), + output.stride(0), + output.stride(1), + output.stride(2), + real_q_len, + k_len, + chunk_start, + chunk_end, + segment_size, + reshaped_block_size, + ) + else: + softmax_fuse_block_sum_kernel_non_causal[grid]( + attn_weights_slice, + output, + scale, + attn_weights_slice.stride(0), + attn_weights_slice.stride(1), + attn_weights_slice.stride(2), + output.stride(0), + output.stride(1), + output.stride(2), + real_q_len, + k_len, + chunk_start, + chunk_end, + segment_size, + reshaped_block_size, + ) + + return output + +def flat_group_gemm(query_states, key_states, chunk_start, chunk_end): + batch_size, num_heads, q_len, head_dim = query_states.shape + kv_len = key_states.shape[2] + + output = torch.empty((batch_size, num_heads, q_len, kv_len), dtype=query_states.dtype, device=query_states.device) + BLOCK_M = 128 + BLOCK_N = 128 + BLOCK_K = 64 + + grid = (q_len // BLOCK_M, kv_len // BLOCK_N, batch_size * num_heads) + flat_group_gemm_kernel[grid]( + query_states, + key_states, + output, + query_states.stride(0), + query_states.stride(1), + query_states.stride(2), + key_states.stride(0), + key_states.stride(1), + key_states.stride(2), + output.stride(0), + output.stride(1), + output.stride(2), + chunk_start, + chunk_end, + num_heads, + head_dim, + BLOCK_M, + BLOCK_N, + BLOCK_K, + ) + + return output + +def flat_group_gemm_fuse_reshape(query_states, key_states, stride, chunk_start, chunk_end, is_causal=True): + batch_size, num_heads, q_len, head_dim = query_states.shape + kv_len = key_states.shape[2] + + assert (key_states.shape[0] == batch_size) + assert (key_states.shape[1] == num_heads) + assert (key_states.shape[3] == head_dim) + + output = torch.empty((batch_size, num_heads, q_len // stride, kv_len // stride), dtype=query_states.dtype, device=query_states.device) + BLOCK_M = 128 + BLOCK_N = 128 + assert (q_len % (stride * BLOCK_M) == 0), f"q_len={q_len}, stride={stride}, BLOCK_M={BLOCK_M}" + assert (kv_len % (stride * 
BLOCK_N) == 0), f"kv_len={kv_len}, stride={stride}, BLOCK_N={BLOCK_N}" + + grid = (q_len // stride // BLOCK_M, kv_len // stride // BLOCK_N, batch_size * num_heads) + flat_group_gemm_fuse_reshape_kernel[grid]( + query_states, + key_states, + output, + query_states.stride(0), + query_states.stride(1), + query_states.stride(2), + key_states.stride(0), + key_states.stride(1), + key_states.stride(2), + output.stride(0), + output.stride(1), + output.stride(2), + chunk_start, + chunk_end, + num_heads, + stride, + head_dim, + BLOCK_M, + BLOCK_N, + is_causal, + ) + + return output \ No newline at end of file diff --git a/minference/ops/pit_sparse_flash_attention_v3.py b/minference/ops/pit_sparse_flash_attention_v3.py new file mode 100644 index 0000000..ce59bd9 --- /dev/null +++ b/minference/ops/pit_sparse_flash_attention_v3.py @@ -0,0 +1,1490 @@ +import os +import sys +import math + +import torch +import torch.nn.functional as F +import torch.distributed as dist + +import triton +import triton.language as tl + +from typing import List, Tuple + +# Save current flags +if torch.version.hip is None: + original_flags = sys.getdlopenflags() + try: + sys.setdlopenflags(os.RTLD_LAZY | os.RTLD_GLOBAL) + import block_sparse_attn_cuda + from block_sparse_attn.block_sparse_attn_interface import convert_blockmask_row_reverse, convert_blockmask_col_reverse + # NOTE: Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_blockmask.h: add head_idx to blockmask_ptr + except ModuleNotFoundError as e: + print(f"[Warning] Failed to import block_sparse_attn_cuda: {e}") + finally: + # Restore original flags for future imports + sys.setdlopenflags(original_flags) + # NOTE: Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_blockmask.h: add head_idx to blockmask_ptr + +from .op_utils.vertical_slash_utils import build_index_local, convert_blockmask + +# ---------------------------------------------------------------------------- +# CUDA-based kernels (based on Block-Sparse-Attention) +def block_attn_fwd( + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + softmax_scale: float, + block_mask: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, num_blocks] + granularity: int, + causal: bool, + step_idx: int=-1, +): + batch_size, num_tokens, num_qo_heads, head_dim = q.shape + num_kv_heads = k.shape[2] + cu_seqlens = torch.arange(0, (batch_size + 1) * num_tokens, step=num_tokens, dtype=torch.int32, device=q.device) + head_mask_type = torch.ones((num_qo_heads, ), dtype=torch.int32, device=q.device) # Block-Sparse + streaming_info = torch.zeros((num_qo_heads * 2), dtype=torch.int32, device=q.device) + row_blockmask = convert_blockmask_row_reverse(block_mask, causal=True) + + p_dropout = 0.0 + out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = block_sparse_attn_cuda.fwd_block( + q.reshape((-1, num_qo_heads, head_dim)), + k.reshape((-1, num_kv_heads, head_dim)), + v.reshape((-1, num_kv_heads, head_dim)), + cu_seqlens, cu_seqlens, + granularity, granularity, + head_mask_type, + streaming_info, + row_blockmask, + num_tokens, num_tokens, + p_dropout, + softmax_scale, + causal, # is_causal + False, # exact_streaming + False, # return_softmax + -1, # window_size_left + -1, # window_size_right + None + ) + out = out.reshape((batch_size, num_tokens, num_qo_heads, head_dim)) + return out, softmax_lse + + +def block_attn_bwd( + grad: torch.Tensor, + q: torch.Tensor, 
# [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + o: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + softmax_lse: torch.Tensor, # [batch_size, num_qo_heads, num_tokens] + softmax_scale: float, + block_mask: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, num_blocks] + granularity: int, + deterministic: bool, + causal: bool, + converted: bool = False, +): + batch_size, num_tokens, num_qo_heads, head_dim = q.shape + num_kv_heads = k.shape[2] + cu_seqlens = torch.arange(0, (batch_size + 1) * num_tokens, step=num_tokens, dtype=torch.int32, device=q.device) + head_mask_type = torch.ones((num_qo_heads, ), dtype=torch.int32, device=q.device) # Block-Sparse + streaming_info = torch.zeros((num_qo_heads * 2), dtype=torch.int32, device=q.device) + if converted: + col_blockmask = block_mask + else: + col_blockmask = convert_blockmask_col_reverse(block_mask, causal=True) + p_dropout = 0.0 + dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) + dq, dk, dv, softmax_d = block_sparse_attn_cuda.bwd_block( + grad.reshape((-1, num_qo_heads, head_dim)), + q.reshape((-1, num_qo_heads, head_dim)), + k.reshape((-1, num_kv_heads, head_dim)), + v.reshape((-1, num_kv_heads, head_dim)), + o.reshape((-1, num_qo_heads, head_dim)), + softmax_lse, + dq.reshape((-1, num_qo_heads, head_dim)), + dk.reshape((-1, num_kv_heads, head_dim)), + dv.reshape((-1, num_kv_heads, head_dim)), + cu_seqlens, cu_seqlens, + granularity, granularity, + head_mask_type, + streaming_info, + col_blockmask, + num_tokens, num_tokens, + p_dropout, + softmax_scale, + True, # zero_tensors + causal, # is_causal + -1, # window_size_left + -1, # window_size_right + deterministic, + None, None + ) + dq = dq.reshape((batch_size, num_tokens, num_qo_heads, head_dim)) + dk = dk.reshape((batch_size, num_tokens, num_kv_heads, head_dim)) + dv = dv.reshape((batch_size, num_tokens, num_kv_heads, head_dim)) + return dq, dk, dv + + +@triton.jit +def _triton_bar_attn_fwd_kernel( + Q, K, V, sm_scale, + bar_cnt, # [BATCH, N_Q_HEADS, NUM_ROWS, WORLD_SIZE + 1] + bar_idx, # [BATCH, N_Q_HEADS, NUM_ROWS, NNZ_V] + Out, # [BATCH, N_Q_HEADS, N_CTX, D_HEAD] + softmax_lse, # [BATCH, N_Q_HEADS, N_CTX] + stride_qz, stride_qh, stride_qm, stride_qd, + stride_kz, stride_kh, stride_kn, stride_kd, + stride_vz, stride_vh, stride_vn, stride_vd, + stride_oz, stride_oh, stride_om, stride_od, + stride_cz, stride_ch, stride_cm, stride_cr, + stride_iz, stride_ih, stride_im, stride_in, + stride_sz, stride_sh, stride_sm, + step, num_qo_heads, num_kv_heads, num_tokens, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + start_m = tl.program_id(0) + qo_head_idx = tl.program_id(1) + batch_idx = tl.program_id(2) + kv_head_idx = qo_head_idx // (num_qo_heads // num_kv_heads) + + if start_m * BLOCK_M >= num_tokens: + return + + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + m_mask = offs_m < num_tokens + + qo_offset = batch_idx * stride_qz + qo_head_idx * stride_qh + kv_offset = batch_idx * stride_kz + kv_head_idx * stride_kh + + q_ptrs = Q + qo_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd + k_ptrs = K + kv_offset + offs_d[:, None] * stride_kd + v_ptrs = V + kv_offset + offs_d[None, :] * stride_vd + o_ptrs = Out + qo_offset + offs_m[:, None] * stride_om + 
offs_d[None, :] * stride_od + + lse_ptrs = softmax_lse + batch_idx * stride_sz + qo_head_idx * stride_sh + offs_m * stride_sm + + bar_l = tl.load(bar_cnt + batch_idx * stride_cz + qo_head_idx * stride_ch + start_m * stride_cm + step * stride_cr) + bar_r = tl.load(bar_cnt + batch_idx * stride_cz + qo_head_idx * stride_ch + start_m * stride_cm + (step + 1) * stride_cr) + bar_idx_ptr = bar_idx + batch_idx * stride_iz + qo_head_idx * stride_ih + start_m * stride_im + + if bar_l >= bar_r: + return + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + # 1/ln2 = lne/ln2 = log2(e) => 2^(x / ln2) = 2^(x * log2(e)) = (2^(log2(e)))^x = e^x + qk_scale = sm_scale * 1.44269504 + + # load q: it will stay in SRAM throughout + q = tl.load(q_ptrs, mask=m_mask[:, None], other=0) + q = (q * qk_scale).to(Q.type.element_ty) + + # loop over k, v and update accumulator + for start_n in range(bar_l, bar_r, BLOCK_N): + n_mask = start_n + offs_n < bar_r + cols = tl.load(bar_idx_ptr + (start_n + offs_n) * stride_in, mask=n_mask, other=0) + + # -- load k, v -- + k = tl.load(k_ptrs + cols[None, :] * stride_kn) + v = tl.load(v_ptrs + cols[:, None] * stride_vn) + + # -- compute qk -- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.where(m_mask[:, None] & n_mask[None, :], qk, float("-inf")) + qk = qk + tl.dot(q, k) + + # -- compute scaling constant -- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.math.exp2(m_i - m_i_new) + p = tl.math.exp2(qk - m_i_new[:, None]) + + # -- scale and update acc -- + acc_scale = l_i * 0 + alpha # workaround some compiler bug + acc = acc * acc_scale[:, None] + acc = acc + tl.dot(p.to(Q.type.element_ty), v) + + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + + # write back O and LSE + acc_1 = acc / l_i[:, None] + s_1 = m_i * 0.69314718 + tl.math.log(l_i) + acc_0 = tl.load(o_ptrs, mask=m_mask[:, None], other=0.).to(tl.float32) + s_0 = tl.load(lse_ptrs, mask=m_mask, other=float("-inf")) + + overflow_mask = (s_0 - s_1) < 88.0 + + theta = tl.math.exp(s_0 - s_1) + alpha_0 = 1 / (1 + 1 / theta) + alpha_1 = 1 / (1 + theta) + acc = alpha_0[:, None] * acc_0 + alpha_1[:, None] * acc_1 + s = s_1 - tl.math.log(alpha_1) + + tl.store(o_ptrs, acc.to(Out.type.element_ty), mask=m_mask[:, None]) + tl.store(lse_ptrs, s, mask=(m_mask & overflow_mask)) + + + +def bar_attn_fwd( + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + o: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + lse: torch.Tensor, # [batch_size, num_qo_heads, num_tokens] + softmax_scale: float, + bar_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, max_v_size] + bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] + granularity: int, + step: int = 0, +): + batch_size, num_tokens, num_qo_heads, head_dim = q.shape + num_kv_heads = k.shape[2] + num_blocks = bar_idx.shape[2] + _triton_bar_attn_fwd_kernel[(num_blocks, num_qo_heads, batch_size)]( + q, k, v, softmax_scale, bar_cnt, bar_idx, o, lse, + q.stride(0), q.stride(2), q.stride(1), q.stride(3), + k.stride(0), k.stride(2), k.stride(1), 
k.stride(3), + v.stride(0), v.stride(2), v.stride(1), v.stride(3), + o.stride(0), o.stride(2), o.stride(1), o.stride(3), + bar_cnt.stride(0), bar_cnt.stride(1), bar_cnt.stride(2), bar_cnt.stride(3), + bar_idx.stride(0), bar_idx.stride(1), bar_idx.stride(2), bar_idx.stride(3), + lse.stride(0), lse.stride(1), lse.stride(2), + step, num_qo_heads, num_kv_heads, num_tokens, + BLOCK_M=granularity, BLOCK_N=64, BLOCK_DMODEL=head_dim, + num_warps=4, num_stages=2, + ) + return o, lse + + +@triton.jit +def _triton_bar_attn_bwd_kernel( + Q, K, V, O, + DQ, DK, DV, DO, + sm_scale, + bar_cnt, # [BATCH, N_Q_HEADS, NUM_ROWS, WORLD_SIZE + 1] + bar_idx, # [BATCH, N_Q_HEADS, NUM_ROWS, NNZ_V] + softmax_lse, # [BATCH, N_HEADS, N_CTX] + stride_qz, stride_qh, stride_qm, stride_qd, + stride_kz, stride_kh, stride_kn, stride_kd, + stride_vz, stride_vh, stride_vn, stride_vd, + stride_oz, stride_oh, stride_om, stride_od, + stride_dqz, stride_dqh, stride_dqm, stride_dqd, + stride_dkz, stride_dkh, stride_dkn, stride_dkd, + stride_dvz, stride_dvh, stride_dvn, stride_dvd, + stride_doz, stride_doh, stride_dom, stride_dod, + stride_cz, stride_ch, stride_cm, stride_cr, + stride_iz, stride_ih, stride_im, stride_in, + stride_sz, stride_sh, stride_sm, + step, num_qo_heads, num_kv_heads, num_tokens, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + start_m = tl.program_id(0) + qo_head_idx = tl.program_id(1) + batch_idx = tl.program_id(2) + kv_head_idx = qo_head_idx // (num_qo_heads // num_kv_heads) + + if start_m * BLOCK_M >= num_tokens: + return + + qk_scale = sm_scale * 1.44269504 + + # offset pointers for batch/head + Q += batch_idx * stride_qz + qo_head_idx * stride_qh + K += batch_idx * stride_kz + kv_head_idx * stride_kh + V += batch_idx * stride_vz + kv_head_idx * stride_vh + O += batch_idx * stride_oz + qo_head_idx * stride_oh + DQ += batch_idx * stride_dqz + qo_head_idx * stride_dqh + DK += batch_idx * stride_dkz + kv_head_idx * stride_dkh + DV += batch_idx * stride_dvz + kv_head_idx * stride_dvh + DO += batch_idx * stride_doz + qo_head_idx * stride_doh + + # loop over rows + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + m_mask = offs_m < num_tokens + + # initialize pointers to value-like data + q_ptrs = Q + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd + k_ptrs = K + offs_d[None, :] * stride_kd + v_ptrs = V + offs_d[None, :] * stride_vd + o_ptrs = O + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od + dq_ptrs = DQ + offs_m[:, None] * stride_dqm + offs_d[None, :] * stride_dqd + dk_ptrs = DK + offs_d[None, :] * stride_dkd + dv_ptrs = DV + offs_d[None, :] * stride_dvd + do_ptrs = DO + offs_m[:, None] * stride_dom + offs_d[None, :] * stride_dod + + l_ptrs = softmax_lse + batch_idx * stride_sz + qo_head_idx * stride_sh + offs_m * stride_sm + + bar_l = tl.load(bar_cnt + batch_idx * stride_cz + qo_head_idx * stride_ch + start_m * stride_cm + step * stride_cr) + bar_r = tl.load(bar_cnt + batch_idx * stride_cz + qo_head_idx * stride_ch + start_m * stride_cm + (step + 1) * stride_cr) + bar_idx_ptr = bar_idx + batch_idx * stride_iz + qo_head_idx * stride_ih + start_m * stride_im + + if bar_l >= bar_r: + return + + o = tl.load(o_ptrs, mask=m_mask[:, None], other=0.).to(tl.float32) + do = tl.load(do_ptrs, mask=m_mask[:, None], other=0.).to(tl.float32) + d_i = tl.sum(o * do, axis=1) + + q = tl.load(q_ptrs, mask=m_mask[:, None], other=0.) 
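+    # Recompute attention probabilities instead of storing them: `do` was loaded in fp32 to
+    # form d_i = rowsum(o * do) (the FlashAttention-style backward correction term) and is
+    # cast back to the input dtype for tl.dot; the saved LSE is scaled by log2(e) so that,
+    # inside the loop, p = exp2(qk * qk_scale - l_i) equals exp(qk * sm_scale - lse),
+    # i.e. the forward softmax probabilities.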
+ do = do.to(DO.dtype.element_ty) + l_i = tl.load(l_ptrs, mask=m_mask, other=0.) * 1.44269504 + + dq = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + for start_n in range(bar_l, bar_r, BLOCK_N): + n_mask = start_n + offs_n < bar_r + cols = tl.load(bar_idx_ptr + (start_n + offs_n) * stride_in, mask=n_mask, other=0) + + # -- load k, v -- + k = tl.load(k_ptrs + cols[:, None] * stride_kn) + v = tl.load(v_ptrs + cols[:, None] * stride_vn) + + # Computer qk + qk = tl.where(m_mask[:, None] & n_mask[None, :], float(0.), float("-inf")) + qk = qk + tl.dot(q, tl.trans(k)) + qk = qk * qk_scale + p = tl.math.exp2(qk - l_i[:, None]) + + # compute dv + dv_vals = tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do).to(tl.float32) + tl.atomic_add(dv_ptrs + cols[:, None] * stride_dvn, dv_vals, mask=n_mask[:, None], sem="relaxed") + + # compute dp = dot(v, do) + dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - d_i[:, None] + dp = dp + tl.dot(do, tl.trans(v)) + + # compute ds = p * (dp - delta[:, None]) + ds = p * dp * sm_scale + + # compute dk = dot(ds.T, q) + dk_vals = tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q).to(tl.float32) + tl.atomic_add(dk_ptrs + cols[:, None] * stride_dkn, dk_vals, mask=n_mask[:, None], sem="relaxed") + + # compute dq + dq = dq + tl.dot(ds.to(Q.dtype.element_ty), k) + + dq_old = tl.load(dq_ptrs, mask=m_mask[:, None], other=0.).to(tl.float32) + tl.store(dq_ptrs, (dq_old + dq).to(DQ.dtype.element_ty), mask=m_mask[:, None]) + + +def bar_attn_bwd( + grad: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + o: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + dq: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + dk: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + dv: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + softmax_lse: torch.Tensor, # [batch_size, num_qo_heads, num_tokens] + softmax_scale: float, + bar_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, max_v_size] + bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] + granularity: int, + deterministic: bool, + step: int = 0, +): + assert not deterministic + batch_size, num_tokens, num_qo_heads, head_dim = q.shape + num_kv_heads = k.shape[2] + num_blocks = bar_idx.shape[2] + dq = torch.zeros_like(q, dtype=torch.float32) if dq is None else dq.to(torch.float32) + dk = torch.zeros_like(k, dtype=torch.float32) if dk is None else dk.to(torch.float32) + dv = torch.zeros_like(v, dtype=torch.float32) if dv is None else dv.to(torch.float32) + _triton_bar_attn_bwd_kernel[(num_blocks, num_qo_heads, batch_size)]( + q, k, v, o, dq, dk, dv, grad, softmax_scale, + bar_cnt, bar_idx, softmax_lse, + q.stride(0), q.stride(2), q.stride(1), q.stride(3), + k.stride(0), k.stride(2), k.stride(1), k.stride(3), + v.stride(0), v.stride(2), v.stride(1), v.stride(3), + o.stride(0), o.stride(2), o.stride(1), o.stride(3), + dq.stride(0), dq.stride(2), dq.stride(1), dq.stride(3), + dk.stride(0), dk.stride(2), dk.stride(1), dk.stride(3), + dv.stride(0), dv.stride(2), dv.stride(1), dv.stride(3), + grad.stride(0), grad.stride(2), grad.stride(1), grad.stride(3), + bar_cnt.stride(0), bar_cnt.stride(1), bar_cnt.stride(2), bar_cnt.stride(3), + bar_idx.stride(0), bar_idx.stride(1), bar_idx.stride(2), bar_idx.stride(3), + 
softmax_lse.stride(0), softmax_lse.stride(1), softmax_lse.stride(2), + step, num_qo_heads, num_kv_heads, num_tokens, + BLOCK_M=granularity, BLOCK_N=64, BLOCK_DMODEL=head_dim, + num_warps=4, num_stages=2, + ) + return dq, dk.to(dq.dtype), dv.to(dq.dtype) + +# ---------------------------------------------------------------------------- +# Purely Triton-based kernels +@triton.jit +def _triton_block_attn_fwd_kernel( + Q, K, V, sm_scale, + block_cnt, # [BATCH, N_Q_HEADS, NUM_ROWS] + block_idx, # [BATCH, N_Q_HEADS, NUM_ROWS, NUM_COLS] + Out, # [BATCH, N_Q_HEADS, N_CTX, D_HEAD] + softmax_lse, # [BATCH, N_Q_HEADS, N_CTX] + stride_qz, stride_qh, stride_qm, stride_qd, + stride_kz, stride_kh, stride_kn, stride_kd, + stride_vz, stride_vh, stride_vn, stride_vd, + stride_oz, stride_oh, stride_om, stride_od, + stride_2cz, stride_2ch, stride_2cm, + stride_2iz, stride_2ih, stride_2im, stride_2in, + stride_sz, stride_sh, stride_sm, + num_qo_heads, num_kv_heads, num_tokens, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + CAUSAL: tl.constexpr, +): + start_m = tl.program_id(0) + qo_head_idx = tl.program_id(1) + batch_idx = tl.program_id(2) + kv_head_idx = qo_head_idx // (num_qo_heads // num_kv_heads) + + if start_m * BLOCK_M >= num_tokens: + return + + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + + qo_offset = batch_idx * stride_qz + qo_head_idx * stride_qh + kv_offset = batch_idx * stride_kz + kv_head_idx * stride_kh + + q_ptrs = Q + qo_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd + k_ptrs = K + kv_offset + offs_d[:, None] * stride_kd + v_ptrs = V + kv_offset + offs_d[None, :] * stride_vd + o_ptrs = Out + qo_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od + lse_ptrs = softmax_lse + batch_idx * stride_sz + qo_head_idx * stride_sh + offs_m * stride_sm + + block_num = tl.load(block_cnt + batch_idx * stride_2cz + qo_head_idx * stride_2ch + start_m * stride_2cm) + if block_num <= 0: + return + + block_idx_ptr = block_idx + batch_idx * stride_2iz + qo_head_idx * stride_2ih + start_m * stride_2im + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + # 1/ln2 = lne/ln2 = log2(e) => 2^(x / ln2) = 2^(x * log2(e)) = (2^(log2(e)))^x = e^x + qk_scale = sm_scale * 1.44269504 + + # load q: it will stay in SRAM throughout + q = tl.load(q_ptrs) + q = (q * qk_scale).to(Q.type.element_ty) + + if CAUSAL: + block_split = block_num - 2 + else: + block_split = block_num + + # Block + for start_n in range(0, block_split): + block_off = tl.load(block_idx_ptr + start_n * stride_2in) * BLOCK_N + + # -- load k, v -- + k = tl.load(k_ptrs + block_off * stride_kn + offs_n[None, :] * stride_kn) + v = tl.load(v_ptrs + block_off * stride_vn + offs_n[:, None] * stride_vn) + + # -- compute qk -- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = qk + tl.dot(q, k) + + # -- compute scaling constant -- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.math.exp2(m_i - m_i_new) + p = tl.math.exp2(qk - m_i_new[:, None]) + + # -- scale and update acc -- + acc_scale = l_i * 0 + alpha # workaround some compiler bug + acc = acc * acc_scale[:, None] + acc = acc + 
tl.dot(p.to(Q.type.element_ty), v) + + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + + # Block (Causal) + for start_n in range(max(block_split, 0), block_num): + block_off = tl.load(block_idx_ptr + start_n * stride_2in) * BLOCK_N + + # -- load k, v -- + k = tl.load(k_ptrs + block_off * stride_kn + offs_n[None, :] * stride_kn) + v = tl.load(v_ptrs + block_off * stride_vn + offs_n[:, None] * stride_vn) + + # -- compute qk -- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.where(offs_m[:, None] >= offs_n[None, :] + block_off, qk, float("-inf")) + qk = qk + tl.dot(q, k) + + # -- compute scaling constant -- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.math.exp2(m_i - m_i_new) + p = tl.math.exp2(qk - m_i_new[:, None]) + + # -- scale and update acc -- + acc_scale = l_i * 0 + alpha # workaround some compiler bug + acc = acc * acc_scale[:, None] + acc = acc + tl.dot(p.to(Q.type.element_ty), v) + + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + + # write back O and LSE + acc_1 = acc / l_i[:, None] + s_1 = m_i * 0.69314718 + tl.math.log(l_i) + acc_0 = tl.load(o_ptrs).to(tl.float32) + s_0 = tl.load(lse_ptrs) + + overflow_mask = (s_0 - s_1) < 88.0 + + theta = tl.math.exp(s_0 - s_1) + alpha_0 = 1 / (1 + 1 / theta) + alpha_1 = 1 / (1 + theta) + acc = alpha_0[:, None] * acc_0 + alpha_1[:, None] * acc_1 + s = s_1 - tl.math.log(alpha_1) + + tl.store(o_ptrs, acc.to(Out.type.element_ty)) + tl.store(lse_ptrs, s, mask=overflow_mask) + +def triton_block_attn_fwd( + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + softmax_scale: float, + block_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, num_blocks] + block_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks] + granularity: int, + step: int = 0, + causal: bool = True, +): + batch_size, num_tokens, num_qo_heads, head_dim = q.shape + num_kv_heads = k.shape[2] + num_blocks = block_idx.shape[2] + + o = torch.zeros_like(q) + lse = torch.zeros((batch_size, num_qo_heads, num_tokens), dtype=torch.float32, device=q.device) - torch.inf + + _triton_block_attn_fwd_kernel[(num_blocks, num_qo_heads, batch_size)]( + q, k, v, softmax_scale, + block_cnt, block_idx, + o, lse, + q.stride(0), q.stride(2), q.stride(1), q.stride(3), + k.stride(0), k.stride(2), k.stride(1), k.stride(3), + v.stride(0), v.stride(2), v.stride(1), v.stride(3), + o.stride(0), o.stride(2), o.stride(1), o.stride(3), + block_cnt.stride(0), block_cnt.stride(1), block_cnt.stride(2), + block_idx.stride(0), block_idx.stride(1), block_idx.stride(2), block_idx.stride(3), + lse.stride(0), lse.stride(1), lse.stride(2), + num_qo_heads, num_kv_heads, num_tokens, + BLOCK_M=granularity, BLOCK_N=64, BLOCK_DMODEL=head_dim, CAUSAL=causal, + num_warps=4, num_stages=2, + ) + return o, lse + +@triton.jit +def _triton_block_attn_bwd_kernel( + Q, K, V, O, + DQ, DK, DV, DO, + sm_scale, + block_cnt, # [BATCH, N_Q_HEADS, NUM_ROWS] + block_idx, # [BATCH, N_Q_HEADS, NUM_ROWS, NUM_COLS] + softmax_lse, # [BATCH, N_HEADS, N_CTX] + stride_qz, stride_qh, stride_qm, stride_qd, + stride_kz, stride_kh, stride_kn, stride_kd, + stride_vz, stride_vh, stride_vn, stride_vd, + stride_oz, stride_oh, stride_om, stride_od, + stride_dqz, stride_dqh, stride_dqm, stride_dqd, + stride_dkz, stride_dkh, stride_dkn, stride_dkd, + stride_dvz, stride_dvh, stride_dvn, stride_dvd, + 
stride_doz, stride_doh, stride_dom, stride_dod, + stride_2cz, stride_2ch, stride_2cm, + stride_2iz, stride_2ih, stride_2im, stride_2in, + stride_sz, stride_sh, stride_sm, + num_qo_heads, num_kv_heads, num_tokens, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + CAUSAL: tl.constexpr, +): + start_m = tl.program_id(0) + qo_head_idx = tl.program_id(1) + batch_idx = tl.program_id(2) + kv_head_idx = qo_head_idx // (num_qo_heads // num_kv_heads) + + if start_m * BLOCK_M >= num_tokens: + return + + qk_scale = sm_scale * 1.44269504 + + # offset pointers for batch/head + Q += batch_idx * stride_qz + qo_head_idx * stride_qh + K += batch_idx * stride_kz + kv_head_idx * stride_kh + V += batch_idx * stride_vz + kv_head_idx * stride_vh + O += batch_idx * stride_oz + qo_head_idx * stride_oh + DQ += batch_idx * stride_dqz + qo_head_idx * stride_dqh + DK += batch_idx * stride_dkz + kv_head_idx * stride_dkh + DV += batch_idx * stride_dvz + kv_head_idx * stride_dvh + DO += batch_idx * stride_doz + qo_head_idx * stride_doh + + # loop over rows + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + + # initialize pointers to value-like data + q_ptrs = Q + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd + k_ptrs = K + offs_d[None, :] * stride_kd + v_ptrs = V + offs_d[None, :] * stride_vd + o_ptrs = O + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od + dq_ptrs = DQ + offs_m[:, None] * stride_dqm + offs_d[None, :] * stride_dqd + dk_ptrs = DK + offs_d[None, :] * stride_dkd + dv_ptrs = DV + offs_d[None, :] * stride_dvd + do_ptrs = DO + offs_m[:, None] * stride_dom + offs_d[None, :] * stride_dod + l_ptrs = softmax_lse + batch_idx * stride_sz + qo_head_idx * stride_sh + offs_m * stride_sm + + block_num = tl.load(block_cnt + batch_idx * stride_2cz + qo_head_idx * stride_2ch + start_m * stride_2cm) + block_idx_ptr = block_idx + batch_idx * stride_2iz + qo_head_idx * stride_2ih + start_m * stride_2im + + o = tl.load(o_ptrs).to(tl.float32) + do = tl.load(do_ptrs).to(tl.float32) + d_i = tl.sum(o * do, axis=1) + + q = tl.load(q_ptrs) + do = do.to(DO.dtype.element_ty) + l_i = tl.load(l_ptrs) * 1.44269504 + + dq = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + if CAUSAL: + block_split = block_num - 2 + else: + block_split = block_num + + # Block + for start_n in range(0, block_split): + block_off = tl.load(block_idx_ptr + start_n * stride_2in) * BLOCK_N + + # -- load k, v -- + k = tl.load(k_ptrs + block_off * stride_kn + offs_n[:, None] * stride_kn) + v = tl.load(v_ptrs + block_off * stride_vn + offs_n[:, None] * stride_vn) + + # Computer qk + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = qk + tl.dot(q, tl.trans(k)) + qk = qk * qk_scale + p = tl.math.exp2(qk - l_i[:, None]) + + # compute dv + dv_vals = tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do).to(tl.float32) + tl.atomic_add(dv_ptrs + block_off * stride_dvn + offs_n[:, None] * stride_dvn, dv_vals, sem="relaxed") + + # compute dp = dot(v, do) + dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - d_i[:, None] + dp = dp + tl.dot(do, tl.trans(v)) + + # compute ds = p * (dp - delta[:, None]) + ds = p * dp * sm_scale + + # compute dk = dot(ds.T, q) + dk_vals = tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q).to(tl.float32) + tl.atomic_add(dk_ptrs + block_off * stride_dkn + offs_n[:, None] * stride_dkn, dk_vals, sem="relaxed") + + # compute dq + dq = dq + tl.dot(ds.to(Q.dtype.element_ty), k) + + # Block (Causal) + for start_n in 
range(max(block_split, 0), block_num): + block_off = tl.load(block_idx_ptr + start_n * stride_2in) * BLOCK_N + + # -- load k, v -- + k = tl.load(k_ptrs + block_off * stride_kn + offs_n[:, None] * stride_kn) + v = tl.load(v_ptrs + block_off * stride_vn + offs_n[:, None] * stride_vn) + + # Computer qk + qk = tl.where(offs_m[:, None] >= offs_n[None, :] + block_off, float(0.), float("-inf")) + qk = qk + tl.dot(q, tl.trans(k)) + qk = qk * qk_scale + p = tl.math.exp2(qk - l_i[:, None]) + + # compute dv + dv_vals = tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do).to(tl.float32) + tl.atomic_add(dv_ptrs + block_off * stride_dvn + offs_n[:, None] * stride_dvn, dv_vals, sem="relaxed") + + # compute dp = dot(v, do) + dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - d_i[:, None] + dp = dp + tl.dot(do, tl.trans(v)) + + # compute ds = p * (dp - delta[:, None]) + ds = p * dp * sm_scale + + # compute dk = dot(ds.T, q) + dk_vals = tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q).to(tl.float32) + tl.atomic_add(dk_ptrs + block_off * stride_dkn + offs_n[:, None] * stride_dkn, dk_vals, sem="relaxed") + + # compute dq + dq = dq + tl.dot(ds.to(Q.dtype.element_ty), k) + + dq_old = tl.load(dq_ptrs).to(tl.float32) + tl.store(dq_ptrs, (dq_old + dq).to(DQ.dtype.element_ty)) + + +def triton_block_attn_bwd( + grad: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + o: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + softmax_lse: torch.Tensor, # [batch_size, num_qo_heads, num_tokens] + softmax_scale: float, + block_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, num_blocks] + block_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks] + granularity: int, + deterministic: bool, + step: int = 0, + causal: bool = True, +): + assert not deterministic + batch_size, num_tokens, num_qo_heads, head_dim = q.shape + num_kv_heads = k.shape[2] + num_blocks = block_idx.shape[2] + + dq = torch.zeros_like(q) + dk = torch.zeros_like(k, dtype=torch.float32) + dv = torch.zeros_like(v, dtype=torch.float32) + + _triton_block_attn_bwd_kernel[(num_blocks, num_qo_heads, batch_size)]( + q, k, v, o, dq, dk, dv, grad, softmax_scale, + block_cnt, block_idx, softmax_lse, + q.stride(0), q.stride(2), q.stride(1), q.stride(3), + k.stride(0), k.stride(2), k.stride(1), k.stride(3), + v.stride(0), v.stride(2), v.stride(1), v.stride(3), + o.stride(0), o.stride(2), o.stride(1), o.stride(3), + dq.stride(0), dq.stride(2), dq.stride(1), dq.stride(3), + dk.stride(0), dk.stride(2), dk.stride(1), dk.stride(3), + dv.stride(0), dv.stride(2), dv.stride(1), dv.stride(3), + grad.stride(0), grad.stride(2), grad.stride(1), grad.stride(3), + block_cnt.stride(0), block_cnt.stride(1), block_cnt.stride(2), + block_idx.stride(0), block_idx.stride(1), block_idx.stride(2), block_idx.stride(3), + softmax_lse.stride(0), softmax_lse.stride(1), softmax_lse.stride(2), + num_qo_heads, num_kv_heads, num_tokens, + BLOCK_M=granularity, BLOCK_N=64, BLOCK_DMODEL=head_dim, CAUSAL=causal, + num_warps=4, num_stages=2, + ) + return dq, dk.to(dq.dtype), dv.to(dq.dtype) + + +@triton.jit +def _triton_block_bar_attn_fwd_kernel( + Q, K, V, sm_scale, + bar_cnt, # [BATCH, N_Q_HEADS, NUM_ROWS, WORLD_SIZE + 1] + bar_idx, # [BATCH, N_Q_HEADS, NUM_ROWS, NNZ_V] + block_cnt, # [BATCH, N_Q_HEADS, NUM_ROWS] + block_idx, # [BATCH, N_Q_HEADS, 
NUM_ROWS, NUM_COLS] + Out, # [BATCH, N_Q_HEADS, N_CTX, D_HEAD] + softmax_lse, # [BATCH, N_Q_HEADS, N_CTX] + stride_qz, stride_qh, stride_qm, stride_qd, + stride_kz, stride_kh, stride_kn, stride_kd, + stride_vz, stride_vh, stride_vn, stride_vd, + stride_oz, stride_oh, stride_om, stride_od, + stride_1cz, stride_1ch, stride_1cm, stride_1cr, + stride_1iz, stride_1ih, stride_1im, stride_1in, + stride_2cz, stride_2ch, stride_2cm, + stride_2iz, stride_2ih, stride_2im, stride_2in, + stride_sz, stride_sh, stride_sm, + step, num_qo_heads, num_kv_heads, num_tokens, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + CAUSAL: tl.constexpr, +): + start_m = tl.program_id(0) + qo_head_idx = tl.program_id(1) + batch_idx = tl.program_id(2) + kv_head_idx = qo_head_idx // (num_qo_heads // num_kv_heads) + + if start_m * BLOCK_M >= num_tokens: + return + + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + + qo_offset = batch_idx * stride_qz + qo_head_idx * stride_qh + kv_offset = batch_idx * stride_kz + kv_head_idx * stride_kh + + q_ptrs = Q + qo_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd + k_ptrs = K + kv_offset + offs_d[:, None] * stride_kd + v_ptrs = V + kv_offset + offs_d[None, :] * stride_vd + o_ptrs = Out + qo_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od + + lse_ptrs = softmax_lse + batch_idx * stride_sz + qo_head_idx * stride_sh + offs_m * stride_sm + + bar_l = tl.load(bar_cnt + batch_idx * stride_1cz + qo_head_idx * stride_1ch + start_m * stride_1cm + step * stride_1cr) + bar_r = tl.load(bar_cnt + batch_idx * stride_1cz + qo_head_idx * stride_1ch + start_m * stride_1cm + (step + 1) * stride_1cr) + bar_idx_ptr = bar_idx + batch_idx * stride_1iz + qo_head_idx * stride_1ih + start_m * stride_1im + + block_num = tl.load(block_cnt + batch_idx * stride_2cz + qo_head_idx * stride_2ch + start_m * stride_2cm) + block_idx_ptr = block_idx + batch_idx * stride_2iz + qo_head_idx * stride_2ih + start_m * stride_2im + + if (bar_l >= bar_r) and (block_num <= 0): + return + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + # 1/ln2 = lne/ln2 = log2(e) => 2^(x / ln2) = 2^(x * log2(e)) = (2^(log2(e)))^x = e^x + qk_scale = sm_scale * 1.44269504 + + # load q: it will stay in SRAM throughout + q = tl.load(q_ptrs) + q = (q * qk_scale).to(Q.type.element_ty) + + if CAUSAL: + block_split = block_num - 2 + else: + block_split = block_num + + # Block + for start_n in range(0, block_split): + block_off = tl.load(block_idx_ptr + start_n * stride_2in) * BLOCK_N + + # -- load k, v -- + k = tl.load(k_ptrs + block_off * stride_kn + offs_n[None, :] * stride_kn) + v = tl.load(v_ptrs + block_off * stride_vn + offs_n[:, None] * stride_vn) + + # -- compute qk -- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = qk + tl.dot(q, k) + + # -- compute scaling constant -- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.math.exp2(m_i - m_i_new) + p = tl.math.exp2(qk - m_i_new[:, None]) + + # -- scale and update acc -- + acc_scale = l_i * 0 + alpha # workaround some compiler bug + acc = acc * acc_scale[:, None] + acc = acc + tl.dot(p.to(Q.type.element_ty), v) + + # -- 
update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + + # Block (Causal) + for start_n in range(max(block_split, 0), block_num): + block_off = tl.load(block_idx_ptr + start_n * stride_2in) * BLOCK_N + + # -- load k, v -- + k = tl.load(k_ptrs + block_off * stride_kn + offs_n[None, :] * stride_kn) + v = tl.load(v_ptrs + block_off * stride_vn + offs_n[:, None] * stride_vn) + + # -- compute qk -- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.where(offs_m[:, None] >= offs_n[None, :] + block_off, qk, float("-inf")) + qk = qk + tl.dot(q, k) + + # -- compute scaling constant -- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.math.exp2(m_i - m_i_new) + p = tl.math.exp2(qk - m_i_new[:, None]) + + # -- scale and update acc -- + acc_scale = l_i * 0 + alpha # workaround some compiler bug + acc = acc * acc_scale[:, None] + acc = acc + tl.dot(p.to(Q.type.element_ty), v) + + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + + # Bar + for start_n in range(bar_l, bar_r, BLOCK_N): + n_mask = start_n + offs_n < bar_r + cols = tl.load(bar_idx_ptr + (start_n + offs_n) * stride_1in, mask=n_mask, other=0) + + # -- load k, v -- + k = tl.load(k_ptrs + cols[None, :] * stride_kn) + v = tl.load(v_ptrs + cols[:, None] * stride_vn) + + # -- compute qk -- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.where(n_mask[None, :], qk, float("-inf")) + qk = qk + tl.dot(q, k) + + # -- compute scaling constant -- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.math.exp2(m_i - m_i_new) + p = tl.math.exp2(qk - m_i_new[:, None]) + + # -- scale and update acc -- + acc_scale = l_i * 0 + alpha # workaround some compiler bug + acc = acc * acc_scale[:, None] + acc = acc + tl.dot(p.to(Q.type.element_ty), v) + + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + + # write back O and LSE + acc_1 = acc / l_i[:, None] + s_1 = m_i * 0.69314718 + tl.math.log(l_i) + acc_0 = tl.load(o_ptrs).to(tl.float32) + s_0 = tl.load(lse_ptrs) + + overflow_mask = (s_0 - s_1) < 88.0 + + theta = tl.math.exp(s_0 - s_1) + alpha_0 = 1 / (1 + 1 / theta) + alpha_1 = 1 / (1 + theta) + acc = alpha_0[:, None] * acc_0 + alpha_1[:, None] * acc_1 + s = s_1 - tl.math.log(alpha_1) + + tl.store(o_ptrs, acc.to(Out.type.element_ty)) + tl.store(lse_ptrs, s, mask=overflow_mask) + + +def block_bar_attn_fwd( + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + o: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + lse: torch.Tensor, # [batch_size, num_qo_heads, num_tokens] + softmax_scale: float, + bar_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, max_v_size] + bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] + block_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, num_blocks] + block_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks] + granularity: int, + step: int = 0, + causal: bool = True, +): + batch_size, num_tokens, num_qo_heads, head_dim = q.shape + num_kv_heads = k.shape[2] + num_blocks = bar_idx.shape[2] + if o is None: + o = torch.zeros_like(q) + lse = torch.zeros((batch_size, num_qo_heads, num_tokens), dtype=torch.float32, device=q.device) - torch.inf + _triton_block_bar_attn_fwd_kernel[(num_blocks, num_qo_heads, batch_size)]( + q, k, v, softmax_scale, bar_cnt, bar_idx, block_cnt, block_idx, o, lse, + 
q.stride(0), q.stride(2), q.stride(1), q.stride(3), + k.stride(0), k.stride(2), k.stride(1), k.stride(3), + v.stride(0), v.stride(2), v.stride(1), v.stride(3), + o.stride(0), o.stride(2), o.stride(1), o.stride(3), + bar_cnt.stride(0), bar_cnt.stride(1), bar_cnt.stride(2), bar_cnt.stride(3), + bar_idx.stride(0), bar_idx.stride(1), bar_idx.stride(2), bar_idx.stride(3), + block_cnt.stride(0), block_cnt.stride(1), block_cnt.stride(2), + block_idx.stride(0), block_idx.stride(1), block_idx.stride(2), block_idx.stride(3), + lse.stride(0), lse.stride(1), lse.stride(2), + step, num_qo_heads, num_kv_heads, num_tokens, + BLOCK_M=granularity, BLOCK_N=64, BLOCK_DMODEL=head_dim, CAUSAL=causal, + num_warps=4, num_stages=2, + ) + return o, lse + + +@triton.jit +def _triton_block_bar_attn_bwd_kernel( + Q, K, V, O, + DQ, DK, DV, DO, + sm_scale, + bar_cnt, # [BATCH, N_Q_HEADS, NUM_ROWS, WORLD_SIZE + 1] + bar_idx, # [BATCH, N_Q_HEADS, NUM_ROWS, NNZ_V] + block_cnt, # [BATCH, N_Q_HEADS, NUM_ROWS] + block_idx, # [BATCH, N_Q_HEADS, NUM_ROWS, NUM_COLS] + softmax_lse, # [BATCH, N_HEADS, N_CTX] + stride_qz, stride_qh, stride_qm, stride_qd, + stride_kz, stride_kh, stride_kn, stride_kd, + stride_vz, stride_vh, stride_vn, stride_vd, + stride_oz, stride_oh, stride_om, stride_od, + stride_dqz, stride_dqh, stride_dqm, stride_dqd, + stride_dkz, stride_dkh, stride_dkn, stride_dkd, + stride_dvz, stride_dvh, stride_dvn, stride_dvd, + stride_doz, stride_doh, stride_dom, stride_dod, + stride_1cz, stride_1ch, stride_1cm, stride_1cr, + stride_1iz, stride_1ih, stride_1im, stride_1in, + stride_2cz, stride_2ch, stride_2cm, + stride_2iz, stride_2ih, stride_2im, stride_2in, + stride_sz, stride_sh, stride_sm, + step, num_qo_heads, num_kv_heads, num_tokens, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + CAUSAL: tl.constexpr, +): + start_m = tl.program_id(0) + qo_head_idx = tl.program_id(1) + batch_idx = tl.program_id(2) + kv_head_idx = qo_head_idx // (num_qo_heads // num_kv_heads) + + if start_m * BLOCK_M >= num_tokens: + return + + qk_scale = sm_scale * 1.44269504 + + # offset pointers for batch/head + Q += batch_idx * stride_qz + qo_head_idx * stride_qh + K += batch_idx * stride_kz + kv_head_idx * stride_kh + V += batch_idx * stride_vz + kv_head_idx * stride_vh + O += batch_idx * stride_oz + qo_head_idx * stride_oh + DQ += batch_idx * stride_dqz + qo_head_idx * stride_dqh + DK += batch_idx * stride_dkz + kv_head_idx * stride_dkh + DV += batch_idx * stride_dvz + kv_head_idx * stride_dvh + DO += batch_idx * stride_doz + qo_head_idx * stride_doh + + # loop over rows + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + + # initialize pointers to value-like data + q_ptrs = Q + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd + k_ptrs = K + offs_d[None, :] * stride_kd + v_ptrs = V + offs_d[None, :] * stride_vd + o_ptrs = O + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od + dq_ptrs = DQ + offs_m[:, None] * stride_dqm + offs_d[None, :] * stride_dqd + dk_ptrs = DK + offs_d[None, :] * stride_dkd + dv_ptrs = DV + offs_d[None, :] * stride_dvd + do_ptrs = DO + offs_m[:, None] * stride_dom + offs_d[None, :] * stride_dod + + l_ptrs = softmax_lse + batch_idx * stride_sz + qo_head_idx * stride_sh + offs_m * stride_sm + + bar_l = tl.load(bar_cnt + batch_idx * stride_1cz + qo_head_idx * stride_1ch + start_m * stride_1cm + step * stride_1cr) + bar_r = tl.load(bar_cnt + batch_idx * stride_1cz + qo_head_idx * stride_1ch + start_m 
* stride_1cm + (step + 1) * stride_1cr) + bar_idx_ptr = bar_idx + batch_idx * stride_1iz + qo_head_idx * stride_1ih + start_m * stride_1im + + block_num = tl.load(block_cnt + batch_idx * stride_2cz + qo_head_idx * stride_2ch + start_m * stride_2cm) + block_idx_ptr = block_idx + batch_idx * stride_2iz + qo_head_idx * stride_2ih + start_m * stride_2im + + if (bar_l >= bar_r) and (block_num <= 0): + return + + o = tl.load(o_ptrs).to(tl.float32) + do = tl.load(do_ptrs).to(tl.float32) + d_i = tl.sum(o * do, axis=1) + + q = tl.load(q_ptrs) + do = do.to(DO.dtype.element_ty) + l_i = tl.load(l_ptrs) * 1.44269504 + + dq = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + if CAUSAL: + block_split = block_num - 2 + else: + block_split = block_num + + # Block + for start_n in range(0, block_split): + block_off = tl.load(block_idx_ptr + start_n * stride_2in) * BLOCK_N + + # -- load k, v -- + k = tl.load(k_ptrs + block_off * stride_kn + offs_n[:, None] * stride_kn) + v = tl.load(v_ptrs + block_off * stride_vn + offs_n[:, None] * stride_vn) + + # Computer qk + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = qk + tl.dot(q, tl.trans(k)) + qk = qk * qk_scale + p = tl.math.exp2(qk - l_i[:, None]) + + # compute dv + dv_vals = tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do).to(tl.float32) + tl.atomic_add(dv_ptrs + block_off * stride_dvn + offs_n[:, None] * stride_dvn, dv_vals, sem="relaxed") + + # compute dp = dot(v, do) + dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - d_i[:, None] + dp = dp + tl.dot(do, tl.trans(v)) + + # compute ds = p * (dp - delta[:, None]) + ds = p * dp * sm_scale + + # compute dk = dot(ds.T, q) + dk_vals = tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q).to(tl.float32) + tl.atomic_add(dk_ptrs + block_off * stride_dkn + offs_n[:, None] * stride_dkn, dk_vals, sem="relaxed") + + # compute dq + dq = dq + tl.dot(ds.to(Q.dtype.element_ty), k) + + # Block (Causal) + for start_n in range(max(block_split, 0), block_num): + block_off = tl.load(block_idx_ptr + start_n * stride_2in) * BLOCK_N + + # -- load k, v -- + k = tl.load(k_ptrs + block_off * stride_kn + offs_n[:, None] * stride_kn) + v = tl.load(v_ptrs + block_off * stride_vn + offs_n[:, None] * stride_vn) + + # Computer qk + qk = tl.where(offs_m[:, None] >= offs_n[None, :] + block_off, float(0.), float("-inf")) + qk = qk + tl.dot(q, tl.trans(k)) + qk = qk * qk_scale + p = tl.math.exp2(qk - l_i[:, None]) + + # compute dv + dv_vals = tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do).to(tl.float32) + tl.atomic_add(dv_ptrs + block_off * stride_dvn + offs_n[:, None] * stride_dvn, dv_vals, sem="relaxed") + + # compute dp = dot(v, do) + dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - d_i[:, None] + dp = dp + tl.dot(do, tl.trans(v)) + + # compute ds = p * (dp - delta[:, None]) + ds = p * dp * sm_scale + + # compute dk = dot(ds.T, q) + dk_vals = tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q).to(tl.float32) + tl.atomic_add(dk_ptrs + block_off * stride_dkn + offs_n[:, None] * stride_dkn, dk_vals, sem="relaxed") + + # compute dq + dq = dq + tl.dot(ds.to(Q.dtype.element_ty), k) + + # Bar + for start_n in range(bar_l, bar_r, BLOCK_N): + n_mask = start_n + offs_n < bar_r + cols = tl.load(bar_idx_ptr + (start_n + offs_n) * stride_1in, mask=n_mask, other=0) + + # -- load k, v -- + k = tl.load(k_ptrs + cols[:, None] * stride_kn) + v = tl.load(v_ptrs + cols[:, None] * stride_vn) + + # Computer qk + qk = tl.where(n_mask[None, :], float(0.), float("-inf")) + qk = qk + tl.dot(q, tl.trans(k)) + qk = qk * qk_scale + p = tl.math.exp2(qk - 
l_i[:, None]) + + # compute dv + dv_vals = tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do).to(tl.float32) + tl.atomic_add(dv_ptrs + cols[:, None] * stride_dvn, dv_vals, mask=n_mask[:, None], sem="relaxed") + + # compute dp = dot(v, do) + dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - d_i[:, None] + dp = dp + tl.dot(do, tl.trans(v)) + + # compute ds = p * (dp - delta[:, None]) + ds = p * dp * sm_scale + + # compute dk = dot(ds.T, q) + dk_vals = tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q).to(tl.float32) + tl.atomic_add(dk_ptrs + cols[:, None] * stride_dkn, dk_vals, mask=n_mask[:, None], sem="relaxed") + + # compute dq + dq = dq + tl.dot(ds.to(Q.dtype.element_ty), k) + + dq_old = tl.load(dq_ptrs).to(tl.float32) + tl.store(dq_ptrs, (dq_old + dq).to(DQ.dtype.element_ty)) + + +def block_bar_attn_bwd( + grad: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + o: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + dq: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + dk: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + dv: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + softmax_lse: torch.Tensor, # [batch_size, num_qo_heads, num_tokens] + softmax_scale: float, + bar_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, max_v_size] + bar_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, world_size + 1] + block_idx: torch.Tensor, # [batch_size, num_qo_heads, num_blocks, num_blocks] + block_cnt: torch.Tensor, # [batch_size, num_qo_heads, num_blocks] + granularity: int, + deterministic: bool, + step: int = 0, + causal: bool = True, +): + assert not deterministic + batch_size, num_tokens, num_qo_heads, head_dim = q.shape + num_kv_heads = k.shape[2] + num_blocks = bar_idx.shape[2] + dq = torch.zeros_like(q) if dq is None else dq + dk = torch.zeros_like(k, dtype=torch.float32) if dk is None else dk.to(torch.float32) + dv = torch.zeros_like(v, dtype=torch.float32) if dv is None else dv.to(torch.float32) + _triton_block_bar_attn_bwd_kernel[(num_blocks, num_qo_heads, batch_size)]( + q, k, v, o, dq, dk, dv, grad, softmax_scale, + bar_cnt, bar_idx, block_cnt, block_idx, softmax_lse, + q.stride(0), q.stride(2), q.stride(1), q.stride(3), + k.stride(0), k.stride(2), k.stride(1), k.stride(3), + v.stride(0), v.stride(2), v.stride(1), v.stride(3), + o.stride(0), o.stride(2), o.stride(1), o.stride(3), + dq.stride(0), dq.stride(2), dq.stride(1), dq.stride(3), + dk.stride(0), dk.stride(2), dk.stride(1), dk.stride(3), + dv.stride(0), dv.stride(2), dv.stride(1), dv.stride(3), + grad.stride(0), grad.stride(2), grad.stride(1), grad.stride(3), + bar_cnt.stride(0), bar_cnt.stride(1), bar_cnt.stride(2), bar_cnt.stride(3), + bar_idx.stride(0), bar_idx.stride(1), bar_idx.stride(2), bar_idx.stride(3), + block_cnt.stride(0), block_cnt.stride(1), block_cnt.stride(2), + block_idx.stride(0), block_idx.stride(1), block_idx.stride(2), block_idx.stride(3), + softmax_lse.stride(0), softmax_lse.stride(1), softmax_lse.stride(2), + step, num_qo_heads, num_kv_heads, num_tokens, + BLOCK_M=granularity, BLOCK_N=64, BLOCK_DMODEL=head_dim, CAUSAL=causal, + num_warps=4, num_stages=2, + ) + return dq, dk.to(dq.dtype), dv.to(dq.dtype) + + +# --------------------------------------------------------------------------------- +# 
Attention Classes +class MInferenceAttnFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + v_size, + s_size, + softmax_scale, + granularity, + return_softmax, + deterministic, + ): + batch_size, num_tokens, num_qo_heads, head_dim = q.shape + if softmax_scale is None: + softmax_scale = head_dim ** (-0.5) + + block_mask, bar_idx, bar_cnt = build_index_local(q, k, v_size, s_size, num_tokens, granularity) + + # Block Mask + out, softmax_lse = block_attn_fwd( + q, k, v, softmax_scale, + block_mask, + granularity=granularity, + causal=True, + ) + # Bar Mask + out, softmax_lse = bar_attn_fwd( + q, k, v, out, softmax_lse, softmax_scale, + bar_idx, bar_cnt, + granularity=granularity, + step=0, + ) + + ctx.save_for_backward(q, k, v, out, softmax_lse, block_mask, bar_idx, bar_cnt) + ctx.granularity = granularity + ctx.deterministic = deterministic + ctx.softmax_scale = softmax_scale + return (out, softmax_lse, None) if return_softmax else out + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse, block_mask, bar_idx, bar_cnt = ctx.saved_tensors + # Block Mask + dq, dk, dv = block_attn_bwd( + dout, q, k, v, out, + softmax_lse, ctx.softmax_scale, + block_mask, + granularity=ctx.granularity, + deterministic=ctx.deterministic, + causal=True, + ) + + # Bar Mask + dq, dk, dv = bar_attn_bwd( + dout, q, k, v, out, dq, dk, dv, + softmax_lse, ctx.softmax_scale, + bar_idx, bar_cnt, + granularity=ctx.granularity, + deterministic=ctx.deterministic, + step=0, + ) + return dq, dk, dv, None, None, None, None, None, None + + + +class MInferenceAttnTritonFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + v_size, + s_size, + softmax_scale, + granularity, + return_softmax, + deterministic, + ): + batch_size, num_tokens, num_qo_heads, head_dim = q.shape + if softmax_scale is None: + softmax_scale = head_dim ** (-0.5) + + block_mask, bar_idx, bar_cnt = build_index_local(q, k, v_size, s_size, num_tokens, granularity) + block_idx, block_cnt = convert_blockmask(block_mask, block_size_M=granularity, block_size_N=64) + + out, softmax_lse = block_bar_attn_fwd( + q, k, v, None, None, softmax_scale, + bar_idx, bar_cnt, block_idx, block_cnt, + granularity=granularity, + step=0, + ) + + ctx.save_for_backward(q, k, v, out, softmax_lse, block_idx, block_cnt, bar_idx, bar_cnt) + ctx.granularity = granularity + ctx.deterministic = deterministic + ctx.softmax_scale = softmax_scale + return (out, softmax_lse, None) if return_softmax else out + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse, block_idx, block_cnt, bar_idx, bar_cnt = ctx.saved_tensors + + # Bar Mask + dq, dk, dv = block_bar_attn_bwd( + dout, q, k, v, out, None, None, None, + softmax_lse, ctx.softmax_scale, + bar_idx, bar_cnt, block_idx, block_cnt, + granularity=ctx.granularity, + deterministic=ctx.deterministic, + step=0, + ) + + return dq, dk, dv, None, None, None, None, None, None + +# --------------------------------------------------------------------------------- +# Wrapped Attention Functions +# -------------------------------------------- +# CUDA-Based +def minference_flash_attn_func( + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v_size: List[int], # [num_qo_heads] + s_size: List[int], # [num_qo_heads] + dropout_p: int = 0.0, + softmax_scale: float = None, + granularity: int = 
128, + causal: bool = True, + window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window + alibi_slopes: Tuple[float, float] = None, + deterministic: bool = False, + return_attn_probs: bool = False, +): + assert dropout_p == 0 + assert causal + assert window_size == (-1, -1) + assert alibi_slopes is None + return MInferenceAttnFunc.apply( + q, + k, + v, + v_size, + s_size, + softmax_scale, + granularity, + return_attn_probs, + deterministic, + ) + + +def minference_flash_attn_qkvpacked_func( + qkv: torch.Tensor, # [batch_size, num_tokens, 3, num_heads, head_dim] + *args, **kwargs +): + return minference_flash_attn_func( + qkv[:, :, 0], # q + qkv[:, :, 1], # k + qkv[:, :, 2], # v + *args, **kwargs + ) + + +def minference_flash_attn_kvpacked_func( + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + kv: torch.Tensor, # [batch_size, num_tokens, 2, num_kv_heads, head_dim]\ + *args, **kwargs +): + return minference_flash_attn_func( + q, + kv[:, :, 0], # k + kv[:, :, 1], # v + *args, **kwargs + ) + +# -------------------------------------------- +# Triton-Based +def minference_flash_attn_triton_func( + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v_size: List[int], # [num_qo_heads] + s_size: List[int], # [num_qo_heads] + dropout_p: int = 0.0, + softmax_scale: float = None, + granularity: int = 128, + causal: bool = True, + window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window + alibi_slopes: Tuple[float, float] = None, + deterministic: bool = False, + return_attn_probs: bool = False, +): + assert dropout_p == 0 + assert causal + assert window_size == (-1, -1) + assert alibi_slopes is None + return MInferenceAttnTritonFunc.apply( + q, + k, + v, + v_size, + s_size, + softmax_scale, + granularity, + return_attn_probs, + deterministic, + ) + +def minference_flash_attn_triton_qkvpacked_func( + qkv: torch.Tensor, # [batch_size, num_tokens, 3, num_heads, head_dim] + *args, **kwargs +): + return minference_flash_attn_triton_func( + qkv[:, :, 0], # q + qkv[:, :, 1], # k + qkv[:, :, 2], # v + *args, **kwargs + ) + + +def minference_flash_attn_triton_kvpacked_func( + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + kv: torch.Tensor, # [batch_size, num_tokens, 2, num_kv_heads, head_dim] + *args, **kwargs +): + return minference_flash_attn_triton_func( + q, + kv[:, :, 0], # k + kv[:, :, 1], # v + *args, **kwargs + ) diff --git a/minference/ops/utils.py b/minference/ops/utils.py new file mode 100644 index 0000000..2c43f20 --- /dev/null +++ b/minference/ops/utils.py @@ -0,0 +1,75 @@ +import os +import torch + +def set_seed(seed=42): + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + +def use_triton(): + return torch.version.hip is not None or os.getenv("FORCE_TRITON", "0") == "1" + + +def check_correctness_by_row( + seq_len, + tensor_var, + ref_tensor_var, + tensor_name, + ATOL=1e-2, + RTOL=1e-2, +): + if not torch.allclose(tensor_var, ref_tensor_var, atol=ATOL, rtol=RTOL): + for h in range(tensor_var.shape[2]): + for i in range(seq_len): + tensor_var_row = tensor_var[:, i, h] + ref_tensor_var_row = ref_tensor_var[:, i, h] + + if not torch.allclose(tensor_var_row, ref_tensor_var_row, atol=ATOL, rtol=RTOL): + print('-' * 60 + '\n') + print(f"Mismatched {tensor_name} at Head {h}, Row {i}:\n") + 
print(f"Computed:\n{tensor_var_row}\n") + print(f"Ref:\n{ref_tensor_var_row}\n") + + max_diff = torch.max(torch.abs(tensor_var_row - ref_tensor_var_row)) + print(f"Maximal difference: {max_diff.item()}\n") + return False + else: + print(f"All {tensor_name} values are correct within the specified tolerance.") + return True + +def check_correct_rate( + tensor_var, + ref_tensor_var, + ATOL=1e-2, + RTOL=1e-2, +): + assert len(tensor_var.shape) == 4, "Input tensor must be 3D (B, N, H, D)" + assert tensor_var.shape == ref_tensor_var.shape, ( + "Input and reference tensors must have the same shape" + ) + bsz, seq_len, num_heads, _ = tensor_var.shape + + # Boolean mask of element-wise closeness + elem_close = torch.isclose(tensor_var, ref_tensor_var, atol=ATOL, rtol=RTOL) + + # A row “matches” only if *all* its D elements are close + row_matches = elem_close.all(dim=-1) # shape (B, N, H) + + # Count rows that do *not* match + num_mismatching = (~row_matches).sum().item() + num_mismatching_prop = num_mismatching / (bsz * seq_len * num_heads) + return 1 - num_mismatching_prop + +def check_by_correct_rate( + tensor_var, + ref_tensor_var, + ATOL=1e-2, + RTOL=1e-2, + threshold=0.99 +): + """ + Check if the tensor_var is correct by comparing it with ref_tensor_var. + Returns True if the correctness rate is above 0.99, otherwise False. + """ + correctness_rate = check_correct_rate(tensor_var, ref_tensor_var, ATOL, RTOL) + return correctness_rate >= threshold \ No newline at end of file diff --git a/minference/ops/xattention_fa.py b/minference/ops/xattention_fa.py index 549d3fd..2383263 100644 --- a/minference/ops/xattention_fa.py +++ b/minference/ops/xattention_fa.py @@ -1,368 +1,369 @@ # Copyright (c) 2025 Microsoft # Licensed under The MIT License [see LICENSE for details] # Refer to the code in https://github.com/mit-han-lab/x-attention - +import math import torch -import triton -import triton.language as tl - - -@triton.jit -def softmax_fuse_block_sum_kernel_causal( - In, - Out, - scale, - input_stride_0, - input_stride_1, - input_stride_2, - output_stride_0, - output_stride_1, - output_stride_2, - real_q_len, - k_len, # we assume k_len is divisible by chunk size - chunk_start, - chunk_end, - segment_size: tl.constexpr, - block_size: tl.constexpr, -): - block_id = tl.program_id(0) - head_id = tl.program_id(1) - batch_id = tl.program_id(2) - - offs_q = tl.arange(0, block_size) + chunk_start + block_id * block_size - offs_k = tl.arange(0, segment_size) - - num_iters = k_len // segment_size - num_iters_before_causal = (chunk_start + (block_id + 1) * block_size - 1) // segment_size - - m_i = tl.zeros([block_size], dtype=tl.float32) - float("inf") - l_i = tl.zeros([block_size], dtype=tl.float32) + 1.0 - - input_ptr = In + batch_id * input_stride_0 + head_id * input_stride_1 + block_id * block_size * input_stride_2 - input_ptr = input_ptr + tl.arange(0, segment_size) + tl.arange(0, block_size)[:, None] * input_stride_2 - - output_ptr = Out + batch_id * output_stride_0 + head_id * output_stride_1 + block_id * output_stride_2 - output_ptr = output_ptr + tl.arange(0, segment_size // block_size) - - for iter in range(0, num_iters_before_causal): - X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale - m_local = tl.max(X, 1) - m_new = tl.maximum(m_i, m_local) - alpha = tl.math.exp2(m_i - m_new) - - X = X - m_new[:, None] - l_local = tl.sum(tl.math.exp2(X), 1) - l_i = l_i * alpha + l_local - - m_i = m_new - - for iter in range(num_iters_before_causal, num_iters_before_causal + 1): - X = 
tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale - mask = offs_q[:, None] >= (offs_k[None, :] + iter * segment_size) - X = tl.where(mask, X, -1.0e6) - m_local = tl.max(X, 1) - m_new = tl.maximum(m_i, m_local) - alpha = tl.math.exp2(m_i - m_new) - - X = X - m_new[:, None] - l_local = tl.sum(tl.math.exp2(X), 1) - l_i = l_i * alpha + l_local - - m_i = m_new - - l_i_inv = 1.0 / l_i - - sum_mask = offs_q[:, None] < real_q_len - - for iter in range(0, num_iters_before_causal): - X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale - X = tl.exp2(X - m_i[:, None]) * l_i_inv[:, None] - X = tl.where(sum_mask, X, 0) - X = tl.reshape(X, (block_size, segment_size // block_size, block_size)) - X = tl.sum(X, 2) - X = tl.sum(X, 0) - tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty)) - - for iter in range(num_iters_before_causal, num_iters_before_causal + 1): - X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale - mask = offs_q[:, None] >= (offs_k[None, :] + iter * segment_size) - X = tl.where(mask, X, -1.0e6) - X = tl.exp2(X - m_i[:, None]) * l_i_inv[:, None] - X = tl.where(sum_mask, X, 0) - X = tl.reshape(X, (block_size, segment_size // block_size, block_size)) - X = tl.sum(X, 2) - X = tl.sum(X, 0) - tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty)) - - for iter in range(num_iters_before_causal + 1, num_iters): - X = tl.zeros([segment_size // block_size], dtype=tl.float32) - tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty)) - - -@triton.jit -def softmax_fuse_block_sum_kernel_non_causal( - In, - Out, - scale, - input_stride_0, - input_stride_1, - input_stride_2, - output_stride_0, - output_stride_1, - output_stride_2, - real_q_len, - k_len, # we assume k_len is divisible by chunk size - chunk_start, - chunk_end, - segment_size: tl.constexpr, - block_size: tl.constexpr, -): - block_id = tl.program_id(0) - head_id = tl.program_id(1) - batch_id = tl.program_id(2) - - offs_q = tl.arange(0, block_size) + chunk_start + block_id * block_size - offs_k = tl.arange(0, segment_size) - - num_iters = k_len // segment_size - - m_i = tl.zeros([block_size], dtype=tl.float32) - float("inf") - l_i = tl.zeros([block_size], dtype=tl.float32) + 1.0 - - input_ptr = In + batch_id * input_stride_0 + head_id * input_stride_1 + block_id * block_size * input_stride_2 - input_ptr = input_ptr + tl.arange(0, segment_size) + tl.arange(0, block_size)[:, None] * input_stride_2 - - output_ptr = Out + batch_id * output_stride_0 + head_id * output_stride_1 + block_id * output_stride_2 - output_ptr = output_ptr + tl.arange(0, segment_size // block_size) - - for iter in range(0, num_iters): - X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale - m_local = tl.max(X, 1) - m_new = tl.maximum(m_i, m_local) - alpha = tl.math.exp2(m_i - m_new) - - X = X - m_new[:, None] - l_local = tl.sum(tl.math.exp2(X), 1) - l_i = l_i * alpha + l_local - - m_i = m_new - - l_i_inv = 1.0 / l_i - - sum_mask = offs_q[:, None] < real_q_len - - for iter in range(0, num_iters): - X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale - X = tl.exp2(X - m_i[:, None]) * l_i_inv[:, None] - X = tl.where(sum_mask, X, 0) - X = tl.reshape(X, (block_size, segment_size // block_size, block_size)) - X = tl.sum(X, 2) - X = tl.sum(X, 0) - tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty)) - -@triton.jit -def flat_group_gemm_kernel(Q, K, Out, - stride_qz, stride_qh, stride_qn, - 
stride_kz, stride_kh, stride_kn, - stride_oz, stride_oh, stride_on, - chunk_start, chunk_end, - H: tl.constexpr, - HEAD_DIM: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - ): - block_m = tl.program_id(0).to(tl.int64) - block_n = tl.program_id(1).to(tl.int64) - batch_id = tl.program_id(2).to(tl.int64) // H - head_id = tl.program_id(2).to(tl.int64) % H - - if chunk_start + (block_m + 1) * BLOCK_M <= block_n * BLOCK_N: - return - - Q_ptrs = Q + batch_id * stride_qz + head_id * stride_qh + block_m * BLOCK_M * stride_qn - K_ptrs = K + batch_id * stride_kz + head_id * stride_kh + block_n * BLOCK_N * stride_kn - - Q_ptrs = Q_ptrs + tl.arange(0, BLOCK_M)[:, None] * stride_qn + tl.arange(0, BLOCK_K)[None, :] - K_ptrs = K_ptrs + tl.arange(0, BLOCK_N)[None, :] * stride_kn + tl.arange(0, BLOCK_K)[:, None] - - num_iters = HEAD_DIM // BLOCK_K - o = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - - for iter in range(num_iters): - q = tl.load(Q_ptrs + iter * BLOCK_K) - k = tl.load(K_ptrs + iter * BLOCK_K) - o += tl.dot(q, k) - - O_ptrs = Out + batch_id * stride_oz + head_id * stride_oh + block_m * BLOCK_M * stride_on + block_n * BLOCK_N - O_ptrs = O_ptrs + tl.arange(0, BLOCK_M)[:, None] * stride_on + tl.arange(0, BLOCK_N)[None, :] - - tl.store(O_ptrs, o.to(Out.type.element_ty)) - -@triton.jit -def flat_group_gemm_fuse_reshape_kernel(Q, K, Out, - stride_qz, stride_qh, stride_qn, - stride_kz, stride_kh, stride_kn, - stride_oz, stride_oh, stride_on, - chunk_start, chunk_end, - H: tl.constexpr, - STRIDE: tl.constexpr, - HEAD_DIM: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - is_caual: tl.constexpr, - ): - block_m = tl.program_id(0).to(tl.int64) - block_n = tl.program_id(1).to(tl.int64) - batch_id = tl.program_id(2).to(tl.int64) // H - head_id = tl.program_id(2).to(tl.int64) % H - - if is_caual: - if chunk_start + (block_m + 1) * BLOCK_M <= block_n * BLOCK_N: - return - - Q_ptrs = Q + batch_id * stride_qz + head_id * stride_qh + block_m * BLOCK_M * STRIDE * stride_qn - K_ptrs = K + batch_id * stride_kz + head_id * stride_kh + block_n * BLOCK_N * STRIDE * stride_kn - - Q_ptrs = Q_ptrs + tl.arange(0, BLOCK_M)[:, None] * (stride_qn * STRIDE) + tl.arange(0, HEAD_DIM)[None, :] + stride_qn * (STRIDE - 1) - K_ptrs = K_ptrs + tl.arange(0, BLOCK_N)[None, :] * (stride_kn * STRIDE) + tl.arange(0, HEAD_DIM)[:, None] - - o = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - - for iter in range(STRIDE): - q = tl.load(Q_ptrs - iter * stride_qn) - k = tl.load(K_ptrs + iter * stride_kn) - o += tl.dot(q, k) - - O_ptrs = Out + batch_id * stride_oz + head_id * stride_oh + block_m * BLOCK_M * stride_on + block_n * BLOCK_N - O_ptrs = O_ptrs + tl.arange(0, BLOCK_M)[:, None] * stride_on + tl.arange(0, BLOCK_N)[None, :] - - tl.store(O_ptrs, o.to(Out.type.element_ty)) - - -def softmax_fuse_block_sum(attn_weights_slice, reshaped_block_size, segment_size, chunk_start, chunk_end, real_q_len, scale, is_causal=True): - batch_size, num_heads, q_len, k_len = attn_weights_slice.shape - assert q_len % reshaped_block_size == 0 - try: - assert k_len % segment_size == 0 - except: - assert False, f"xAttention error, k_len: {k_len}, segment size: {segment_size}" - assert segment_size % reshaped_block_size == 0 - assert attn_weights_slice.stride(-1) == 1 - - output = torch.empty((batch_size, num_heads, q_len // reshaped_block_size, k_len // reshaped_block_size), dtype=attn_weights_slice.dtype, device=attn_weights_slice.device) - - grid = (q_len // reshaped_block_size, num_heads, 
batch_size) - - if is_causal: - softmax_fuse_block_sum_kernel_causal[grid]( - attn_weights_slice, - output, - scale, - attn_weights_slice.stride(0), - attn_weights_slice.stride(1), - attn_weights_slice.stride(2), - output.stride(0), - output.stride(1), - output.stride(2), - real_q_len, - k_len, - chunk_start, - chunk_end, - segment_size, - reshaped_block_size, +from typing import List, Tuple, Dict, Any + +from minference.ops.pit_sparse_flash_attention_v3 import block_attn_fwd, block_attn_bwd +from minference.ops.op_utils.xattn_utils import ( + LN2, find_blocks_chunked, flat_group_gemm_fuse_reshape, softmax_fuse_block_sum +) + +def xattn_estimate( + query_states: torch.Tensor, # (batch_size, num_q_head, q_len, head_dim) + key_states: torch.Tensor, # (batch_size, num_kv_head, k_len, head_dim) + block_size, + stride, + norm=1, + softmax=True, + threshold=0.9, + chunk_size=16384, + select_mode="inverse", + use_triton=True, + causal=True, + kdb: int = 1, + keep_sink=False, + keep_recent=False, +) -> torch.Tensor: + batch_size, num_kv_head, k_len, head_dim = key_states.shape + batch_size, num_q_head, q_len, head_dim = query_states.shape + if num_q_head > num_kv_head: + key_states = torch.repeat_interleave(key_states.contiguous(), num_q_head // num_kv_head, dim=1) + + assert q_len % chunk_size == 0 + assert k_len % chunk_size == 0 + + q_chunk_num = q_len // chunk_size + q_block_num = q_len // block_size + + # assert num_kv_head == num_q_head + attn_sum_list = [] + simple_mask_list = [] + + if use_triton and ( + "100" not in torch.cuda.get_device_properties(torch.cuda.current_device()).name + ): + use_triton = False + print( + "setting use triton to false. Triton kernel not surpported on this device" ) - else: - softmax_fuse_block_sum_kernel_non_causal[grid]( - attn_weights_slice, - output, - scale, - attn_weights_slice.stride(0), - attn_weights_slice.stride(1), - attn_weights_slice.stride(2), - output.stride(0), - output.stride(1), - output.stride(2), - real_q_len, - k_len, - chunk_start, - chunk_end, - segment_size, - reshaped_block_size, + + num_strides_in_k = k_len // stride + + num_strides_per_chunk = chunk_size // stride + num_strides_per_block = block_size // stride + num_blocks_per_chunk = num_strides_per_chunk // num_strides_per_block + + for chunk_idx in range(q_chunk_num): + if kdb != 1: + raise ValueError("use_triton and kdb cannot be used together") + + q_chunk_start = chunk_idx * num_strides_per_chunk * stride + q_chunk_end = (chunk_idx + 1) * num_strides_per_chunk * stride + + q_chunk_start_stride = chunk_idx * num_strides_per_chunk + q_chunk_end_stride = (chunk_idx + 1) * num_strides_per_chunk + + # attn_weights_slice: (batch_size, num_heads, chunk_size // stride, kv_len // stride) + # (i.e. 
the attention sum of each SxS stride block) + # This step is agnostic to block size and just computes the attention sum in each stride block + attn_weights_slice = flat_group_gemm_fuse_reshape( + # query_states, key_states, stride, chunk_start, chunk_end, is_causal=True + query_states[:, :, q_chunk_start : q_chunk_end, :,], + key_states, + stride, + q_chunk_start_stride, + q_chunk_end_stride, + is_causal=causal, + ) + + # (batch_size, num_heads, q_block_num, k_block_num), + attn_sum = softmax_fuse_block_sum( + attn_weights_slice, # (batch_size, num_heads, chunk_size // stride, kv_len // stride) + num_strides_per_block, + min(4096, num_strides_per_block), + q_chunk_start_stride, q_chunk_end_stride, + num_strides_in_k, + 1 / LN2 / math.sqrt(head_dim) / stride / norm, + is_causal=causal, + ) + + + # (batch_size, head_num, num_blocks_per_chunk, block_num) + simple_mask = find_blocks_chunked( + attn_sum, + chunk_idx * num_blocks_per_chunk, + threshold, + None, + decoding=False, + mode="prefill", + causal=causal, + ) + + attn_sum_list.append(attn_sum) + simple_mask_list.append(simple_mask) + + del attn_weights_slice + + attn_sums = torch.cat(attn_sum_list, dim=-2) + + # (batch_size, head_num, num_blocks_per_chunk * q_chunk_num, block_num) + # i.e. (batch_size, head_num, q_block_num, q_block_num) + simple_masks = torch.cat(simple_mask_list, dim=-2) + + if causal: + simple_masks[:, :, -q_block_num:, -q_block_num:] = torch.where( + torch.tril( + torch.ones( + q_block_num, q_block_num, dtype=bool, device=key_states.device + ), + diagonal=0, + ), + simple_masks[:, :, -q_block_num:, -q_block_num:], + False, + ) + # print(f"{__name__} | simple_masks[:, :, -q_block_num:, -q_block_num:].shape {simple_masks[:, :, -q_block_num:, -q_block_num:].shape} after torch.where") + + + if keep_sink: + simple_masks[:, :, 0, :] = True + if keep_recent: + eye_matrix = torch.eye(q_block_num, device=simple_masks.device, dtype=bool) + eye_matrix_expanded = ( + eye_matrix.unsqueeze(0) + .unsqueeze(0) + .expand(1, num_kv_head, q_block_num, q_block_num) + ) + simple_masks[:, :, -q_block_num:, -q_block_num:] = torch.where( + eye_matrix_expanded, True, simple_masks[:, :, -q_block_num:, -q_block_num:] + ) + + # simple_masks -> (batch_size, head_num, q_block_num, q_block_num) + return attn_sums, simple_masks + +class XAttnFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + head_indices, + xattn_params, # Dict[str, Any] + granularity, + causal, + softmax_scale, + return_softmax, + deterministic, + ): + batch_size, num_tokens, num_qo_heads, head_dim = q.shape + if softmax_scale is None: + softmax_scale = head_dim ** (-0.5) + + q_block_num = (q.shape[1] + granularity - 1) // granularity + # (batch_size, head_num, q_block_num, q_block_num) + _, block_mask = xattn_estimate( + q.transpose(1, 2), k.transpose(1, 2), + granularity, + **xattn_params + ) + block_mask = block_mask[:, :, -q_block_num:, -q_block_num:].contiguous() + + # Block Mask + out, softmax_lse = block_attn_fwd( + q, k, v, softmax_scale, + block_mask, + granularity=granularity, + causal=causal, ) - return output - -def flat_group_gemm(query_states, key_states, chunk_start, chunk_end): - batch_size, num_heads, q_len, head_dim = query_states.shape - kv_len = key_states.shape[2] - - output = torch.empty((batch_size, num_heads, q_len, kv_len), dtype=query_states.dtype, device=query_states.device) - BLOCK_M = 128 - BLOCK_N = 128 - BLOCK_K = 64 - - grid = (q_len // BLOCK_M, kv_len // BLOCK_N, 
batch_size * num_heads) - flat_group_gemm_kernel[grid]( - query_states, - key_states, - output, - query_states.stride(0), - query_states.stride(1), - query_states.stride(2), - key_states.stride(0), - key_states.stride(1), - key_states.stride(2), - output.stride(0), - output.stride(1), - output.stride(2), - chunk_start, - chunk_end, - num_heads, - head_dim, - BLOCK_M, - BLOCK_N, - BLOCK_K, + ctx.save_for_backward(q, k, v, out, softmax_lse, block_mask) + ctx.granularity = granularity + ctx.deterministic = deterministic + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.head_indices = head_indices + + # print(f"{__name__} | out shape: {out.shape}") + return (out, softmax_lse, None) if return_softmax else out + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse, block_mask = ctx.saved_tensors + causal = ctx.causal + + # Block Mask + dq, dk, dv = block_attn_bwd( + dout, q, k, v, out, + softmax_lse, ctx.softmax_scale, + block_mask, + granularity=ctx.granularity, + deterministic=ctx.deterministic, + causal=causal, + ) + return dq, dk, dv, None, None, None, None, None, None, None + +def xattn_flash_attn_func( + q: torch.Tensor, # [batch_size, num_tokens, num_qo_heads, head_dim] + k: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [batch_size, num_tokens, num_kv_heads, head_dim] + head_indices: List[int], # [num_qo_heads] + xattn_params: Dict[str, Any], + granularity: int = 128, + dropout_p: int = 0.0, + softmax_scale: float = None, + causal: bool = True, + window_size: Tuple[int, int] = (-1, -1), # -1 means infinite context window + alibi_slopes: Tuple[float, float] = None, + deterministic: bool = False, + return_attn_probs: bool = False, +): + assert dropout_p == 0 + assert causal + assert window_size == (-1, -1) + assert alibi_slopes is None + + return XAttnFunc.apply( + q, k, v, + head_indices, + xattn_params, + granularity, + causal, + softmax_scale, + return_attn_probs, + deterministic, ) - return output - -def flat_group_gemm_fuse_reshape(query_states, key_states, stride, chunk_start, chunk_end, is_causal=True): - batch_size, num_heads, q_len, head_dim = query_states.shape - kv_len = key_states.shape[2] - - assert (key_states.shape[0] == batch_size) - assert (key_states.shape[1] == num_heads) - assert (key_states.shape[3] == head_dim) - - output = torch.empty((batch_size, num_heads, q_len // stride, kv_len // stride), dtype=query_states.dtype, device=query_states.device) - BLOCK_M = 128 - BLOCK_N = 128 - assert (q_len % (stride * BLOCK_M) == 0) - assert (kv_len % (stride * BLOCK_N) == 0) - - grid = (q_len // stride // BLOCK_M, kv_len // stride // BLOCK_N, batch_size * num_heads) - flat_group_gemm_fuse_reshape_kernel[grid]( - query_states, - key_states, - output, - query_states.stride(0), - query_states.stride(1), - query_states.stride(2), - key_states.stride(0), - key_states.stride(1), - key_states.stride(2), - output.stride(0), - output.stride(1), - output.stride(2), - chunk_start, - chunk_end, - num_heads, - stride, - head_dim, - BLOCK_M, - BLOCK_N, - is_causal, + + +if __name__ == "__main__": + import argparse + from flash_attn import flash_attn_func + from minference.ops.utils import set_seed + + + parser = argparse.ArgumentParser(description="XAttn Test") + parser.add_argument("--use_ones", action="store_true", help="Use ones for q, k, v") + parser.add_argument("--enable_sparse", action="store_true", help="Enable Sparse XAttenion") + parser.add_argument("--test_backward", action="store_true", help="Test 
backward pass") + parser.add_argument("--seq_len", type=int, default=16384, help="Sequence length") + args = parser.parse_args() + + ATOL, RTOL = 1e-2, 1e-2 + # dtype = torch.bfloat16 + dtype = torch.float16 + device = torch.device(f"cuda:0") + torch.cuda.set_device(device) + set_seed(2025) + + batch_size, seq_len, num_q_heads, head_dim = 1, args.seq_len, 8, 128 + num_kv_heads = 4 + head_indices = list(range(num_q_heads)) + + granularity = 128 + xattn_params = { + "stride": 16, + "norm": 1, + "softmax": True, + "threshold": 0.9 if args.enable_sparse else 1, + "chunk_size": 16384, + "select_mode": "inverse", + "use_triton": True, + "causal": True, + "kdb": 1, + "keep_sink": False, + "keep_recent": False + } + + if args.use_ones: + q = torch.ones((batch_size, seq_len, num_q_heads, head_dim), device=device, dtype=dtype, requires_grad=args.test_backward) + k = torch.ones((batch_size, seq_len, num_kv_heads, head_dim), device=device, dtype=dtype, requires_grad=args.test_backward) + v = torch.ones((batch_size, seq_len, num_kv_heads, head_dim), device=device, dtype=dtype, requires_grad=args.test_backward) + else: + q = torch.randn((batch_size, seq_len, num_q_heads, head_dim), device=device, dtype=dtype, requires_grad=args.test_backward) + k = torch.randn((batch_size, seq_len, num_kv_heads, head_dim), device=device, dtype=dtype, requires_grad=args.test_backward) + v = torch.randn((batch_size, seq_len, num_kv_heads, head_dim), device=device, dtype=dtype, requires_grad=args.test_backward) + + # Clone inputs for reference implementation to ensure separate gradient computation + if args.test_backward: + q_ref = q.clone().detach().requires_grad_(True) + k_ref = k.clone().detach().requires_grad_(True) + v_ref = v.clone().detach().requires_grad_(True) + else: + q_ref, k_ref, v_ref = q, k, v + + out = xattn_flash_attn_func( + q, k, v, + head_indices, + xattn_params, + granularity=granularity, ) + print(f"out shape: {out.shape}") - return output + ref_out = flash_attn_func( + q_ref, k_ref, v_ref, + causal=True, + softmax_scale=head_dim ** (-0.5) + ) + + + # Compare out and ref_out + if not torch.allclose(out, ref_out, atol=ATOL, rtol=RTOL): + num_blocks = seq_len // granularity + for i in range(num_blocks): + start = i * granularity + end = (i + 1) * granularity + out_chunk = out[:, start:end, :, :] + ref_out_chunk = ref_out[:, start:end, :, :] + + print('-' * 60) + if not torch.allclose(out_chunk, ref_out_chunk, atol=ATOL, rtol=RTOL): + print(f"Forward Output mismatch at chunk {i}:") + print(f"Forward out_chunk: {out_chunk}") + print(f"Forward ref_out_chunk: {ref_out_chunk}") + else: + print(f"Forward Output match at chunk {i}") + else: + print("Forward Output match") + + + # Backward pass testing + if args.test_backward: + print("\nTesting backward pass...") + + # Create gradient for backward pass + grad_output = torch.randn_like(out) + grad_output_ref = grad_output.clone() + + # Backward pass for custom implementation + out.backward(grad_output) + + # Backward pass for reference implementation + ref_out.backward(grad_output_ref) + + # Compare gradients + print("\nGradient comparison:") + + # Compare q gradients + q_grad_match = torch.allclose(q.grad, q_ref.grad, atol=ATOL, rtol=RTOL) + print(f"q grad match: {q_grad_match}") + if not q_grad_match: + q_diff = (q.grad - q_ref.grad).abs() + print(f"q grad max diff: {q_diff.max().item()}, mean diff: {q_diff.mean().item()}") + + # Compare k gradients + k_grad_match = torch.allclose(k.grad, k_ref.grad, atol=ATOL, rtol=RTOL) + print(f"k grad match: 
{k_grad_match}") + if not k_grad_match: + k_diff = (k.grad - k_ref.grad).abs() + print(f"k grad max diff: {k_diff.max().item()}, mean diff: {k_diff.mean().item()}") + + # Compare v gradients + v_grad_match = torch.allclose(v.grad, v_ref.grad, atol=ATOL, rtol=RTOL) + print(f"v grad match: {v_grad_match}") + if not v_grad_match: + v_diff = (v.grad - v_ref.grad).abs() + print(f"v grad max diff: {v_diff.max().item()}, mean diff: {v_diff.mean().item()}") + + print(f"\nOverall gradient match: {q_grad_match and k_grad_match and v_grad_match}") diff --git a/mtraining/.gitignore b/mtraining/.gitignore new file mode 100644 index 0000000..736265e --- /dev/null +++ b/mtraining/.gitignore @@ -0,0 +1,13 @@ +**/__pycache__/ +MTraining.egg-info/ +**.ipynb +**/prof_logs/ +.vscode/ +.pytest_cache/ +*.log +**/draft/ +output/ +**/ring_attn_comp_data/ +**/ring_attn_comp_data/ +**/ring_attn_pt_logs/ +expr_data_store/ \ No newline at end of file diff --git a/mtraining/README.md b/mtraining/README.md new file mode 100644 index 0000000..6f08299 --- /dev/null +++ b/mtraining/README.md @@ -0,0 +1 @@ +# MTraining diff --git a/mtraining/__init__.py b/mtraining/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mtraining/attn_funcs/__init__.py b/mtraining/attn_funcs/__init__.py new file mode 100644 index 0000000..60ce388 --- /dev/null +++ b/mtraining/attn_funcs/__init__.py @@ -0,0 +1,36 @@ +from typing import Dict, Callable + +from .dense_func import fa_attn_forward, stripe_ring_attention_forward, zigzag_ring_attention_forward +from .minfer_func import minfer_attention_forward +from .moba_func import moba_attention_forward +from .xattn_func import xattn_attention_forward + + +class AttnType: + BASELINE: str = "baseline" + ZIGZAG_RING: str = "zigzag_ring" + STRIPE_RING: str = "stripe_ring" + + MINFER: str = "minfer" + MOBA: str = "moba" + XATTN: str = "xattn" + +ATTN_TO_FUNC = { + AttnType.BASELINE: fa_attn_forward, + AttnType.ZIGZAG_RING: zigzag_ring_attention_forward, + AttnType.STRIPE_RING: stripe_ring_attention_forward, + + AttnType.MINFER: minfer_attention_forward, + AttnType.MOBA: moba_attention_forward, + AttnType.XATTN: xattn_attention_forward, +} + +def overwrite_attn_implementation( + attn_dict: Dict[str, Callable], + attn_type: AttnType, +): + attn_func: Callable = ATTN_TO_FUNC[attn_type] + print(f"Overwriting attention implementation to {attn_type} ({attn_func.__name__})") + + for attn_name in attn_dict: + attn_dict[attn_name] = attn_func \ No newline at end of file diff --git a/mtraining/attn_funcs/dense_func.py b/mtraining/attn_funcs/dense_func.py new file mode 100644 index 0000000..9c8df79 --- /dev/null +++ b/mtraining/attn_funcs/dense_func.py @@ -0,0 +1,278 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# This file modifies the official modeling_llama.py file at runtime to +# 1. register the flash attention function to nnscaler and update related code +# 2. 
replace the un-fused RMSNorm with apex's fused version +import torch +from torch import Tensor +from transformers.utils import is_flash_attn_2_available +from typing import List, Optional, Tuple, Union, Any, Dict +if is_flash_attn_2_available(): + from flash_attn.bert_padding import pad_input + from flash_attn import flash_attn_func, flash_attn_varlen_func + +from nnscaler.ir import IRTensor +from nnscaler.ir.operator import IRFwOperation +from nnscaler.runtime.device import DeviceGroup +from nnscaler.graph.parser.register import register_op + +from minference.dist_ops.zigzag_attention import zigzag_ring_flash_attn_func +from minference.dist_ops.striped_attention import stripe_flash_attn_func + +from .utils import nnscaler_upad_input + + +def fa_attn_forward( + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None, + causal=True, + ): + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = nnscaler_upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + return attn_output + +# --------------------------------------------------------------------------- +def zigzag_ring_attention_forward( + module: torch.nn.Module, + query: Tensor, # [B, H, N, D] + key: Tensor, + value: Tensor, + attention_mask: Optional[Tensor], + dropout: float = 0.0, + scaling: Optional[float] = None, + sliding_window: Optional[int] = None, + softcap: Optional[float] = None, + **kwargs, +) -> Tuple[Tensor, None]: + return wrap_zigzag_attn_func( + query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), + layer_idx=module.layer_idx, + softmax_scale=scaling, + dropout_p=dropout, + causal=True, + ), None + + +def wrap_zigzag_attn_func( + q: Tensor, k: Tensor, v: Tensor, + layer_idx: int, + softmax_scale: Tensor=None, + dropout_p: float=0.0, + causal: bool=True, + window_size: Tuple[int]=(-1, -1), + alibi_slopes: Tensor=None, deterministic: bool=False, + return_attn_probs: bool=False, + process_group: Tuple[int]=None +) -> Tensor: + if process_group is None or len(process_group) == 1: + # there is an additional checker for the `softmax_scale`, which is equivalent + # to the behavior of the original flash_attn_func. 
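(Editorial aside: the comment above refers to flash-attn's convention of defaulting the softmax scale to 1/sqrt(head_dim) when `softmax_scale` is None, which is exactly what the fallback below reproduces. A minimal, self-contained check of that equivalence; the head_dim value is only an assumed example, not taken from this diff:)

```python
import math

head_dim = 128                       # assumed example value
softmax_scale = head_dim ** (-0.5)   # fallback used by this wrapper
assert math.isclose(softmax_scale, 1.0 / math.sqrt(head_dim))
```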
+ if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + output = flash_attn_func(q, k, v, 0.0, softmax_scale, causal) + return output + + assert causal == True, "zigzag_ring is meaningless for causal=False" + assert len(q.shape) == 4, "q must have shape [bs, ql, qh, dim]" + assert len(k.shape) == 4, "k must have shape [bs, kl, kh, dim]" + assert len(v.shape) == 4, "v must have shape [bs, vl, vh, dim]" + qbsz, qlen, qheads, qdim = q.shape + kbsz, klen, kheads, kdim = k.shape + vbsz, vlen, vheads, vdim = v.shape + assert qbsz == kbsz == vbsz, "batch size must be the same" + assert qlen == klen == vlen, "sequence length must be the same" + assert kheads == vheads, "number of k and v heads must be the same" + assert qheads % kheads == 0, "number of q heads must be a multiple of k heads" + assert qdim == kdim == vdim, "dimension must be the same" + + local_process_group = DeviceGroup().get_group(process_group) + output = zigzag_ring_flash_attn_func( + q, k, v, + layer_idx, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + local_process_group, + ).contiguous() + return output + +# --------------------------------------------------------------------------- +def stripe_ring_attention_forward( + module: torch.nn.Module, + query: Tensor, # [B, H, N, D] + key: Tensor, + value: Tensor, + attention_mask: Optional[Tensor], + dropout: float = 0.0, + scaling: Optional[float] = None, + sliding_window: Optional[int] = None, + softcap: Optional[float] = None, + **kwargs, +) -> Tuple[Tensor, None]: + return wrap_striped_attn_func( + query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), + layer_idx=module.layer_idx, + softmax_scale=scaling, + dropout_p=dropout, + causal=True, + ), None + + +def wrap_striped_attn_func( + q: Tensor, k: Tensor, v: Tensor, layer_idx: int, + granularity: int=1, + softmax_scale: Tensor=None, + dropout_p: float=0.0, causal: bool=True, window_size: Tuple[int]=(-1, -1), + alibi_slopes: Tensor=None, deterministic: bool=False, + return_attn_probs: bool=False, + process_group: Tuple[int]=None + ) -> Tensor: + if process_group is None or len(process_group) == 1: + # there is an additional checker for the `softmax_scale`, which is equivalent + # to the behavior of the original flash_attn_func. 
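(Editorial aside: the shape assertions a few lines below encode the GQA contract these wrappers rely on, namely that the number of query heads is an integer multiple of the number of kv heads. A small illustrative sketch of that grouping, using one common way of expanding kv heads; the head counts and shapes here are assumptions for illustration only:)

```python
import torch

num_q_heads, num_kv_heads = 8, 4            # assumed example head counts
assert num_q_heads % num_kv_heads == 0
group_size = num_q_heads // num_kv_heads    # query heads sharing one kv head

k = torch.randn(1, 16, num_kv_heads, 64)    # [batch, seq, kv_heads, head_dim]
# Expand kv heads so every query head has a matching kv head to attend with.
k_expanded = torch.repeat_interleave(k, group_size, dim=2)
assert k_expanded.shape[2] == num_q_heads
```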
+ if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + output = flash_attn_func(q, k, v, 0.0, softmax_scale, causal) + return output + + assert len(q.shape) == 4, "q must have shape [bs, ql, qh, dim]" + assert len(k.shape) == 4, "k must have shape [bs, kl, kh, dim]" + assert len(v.shape) == 4, "v must have shape [bs, vl, vh, dim]" + qbsz, qlen, qheads, qdim = q.shape + kbsz, klen, kheads, kdim = k.shape + vbsz, vlen, vheads, vdim = v.shape + assert qbsz == kbsz == vbsz, "batch size must be the same" + assert qlen == klen == vlen, "sequence length must be the same" + assert kheads == vheads, "number of k and v heads must be the same" + assert qheads % kheads == 0, "number of q heads must be a multiple of k heads" + assert qdim == kdim == vdim, "dimension must be the same" + + local_process_group = DeviceGroup().get_group(process_group) + output = stripe_flash_attn_func( + q, k, v, + layer_idx, + dropout_p, + softmax_scale, + granularity, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + local_process_group, + ).contiguous() + return output + + + +# --------------------------------------------------------------------------- +def flash_attention_anno(query_states, key_states, value_states, attention_mask, *args, **kwargs) -> str: + if query_states.shape[2] != key_states.shape[2]: + assert query_states.shape[2] % key_states.shape[2] == 0 + group_size = query_states.shape[2] // key_states.shape[2] + assert query_states.shape[2] == value_states.shape[2] * group_size + q_anno = f'(group_num {group_size})' + kv_anno = 'group_num' + else: + q_anno = kv_anno = 'num_heads' + + if isinstance(attention_mask, IRTensor): + return f'b l^ {q_anno} hd^, b s^ {kv_anno} hd^, b s^ {kv_anno} vd^, b l^ -> b l^ {q_anno} vd^' + else: + return f'b l^ {q_anno} hd^, b s^ {kv_anno} hd^, b s^ {kv_anno} vd^ -> b l^ {q_anno} vd^' + + +def emit_ring(node: IRFwOperation, args: List[str], kwargs: Dict[str, str], runtime_devid: int, plan_ndevs: int, runtime_ndevs: int) -> str: + """Special rule to generate zigzag_attn node""" + + signature = node.signature + + offset = (runtime_devid // plan_ndevs) * plan_ndevs + scale_unit_dev_ids = [local_rank + offset for local_rank in range(plan_ndevs)] + + kw_pairs = list() + for key, val in kwargs.items(): + code = f'{key}={val}' + kw_pairs.append(code) + + sub_input = node.inputs()[0] + full_input = sub_input.parent + partition_dims = [i for i, (s, f) in enumerate(zip(sub_input.shape, full_input.shape)) if s != f] + assert len(partition_dims) <= 1, f"support no more than one partition dim, but got {partition_dims}" + if not partition_dims: + kw_pairs.append("process_group=None") + else: + # if the 'process_group' is None, we will use the local attention (flash_attn_func) + if partition_dims[0] == 0: # partition on batch dim + # partition the bsz dim, use local flash_attn_func + kw_pairs.append("process_group=None") + elif partition_dims[0] == 1: # partition on sequence dim + # the synchronization should occur across scaleunits + kw_pairs.append(f"process_group={scale_unit_dev_ids}") + elif partition_dims[0] == 2: + # partition on num_head dim + kw_pairs.append("process_group=None") + else: + raise ValueError(f'unsupported partition dim: {partition_dims[0]}') + + args = ", ".join(list(args) + kw_pairs) + return f"{signature}({args})" + +def ring_attn_anno(query_states, key_states, value_states, *args, **kwargs) -> str: + if query_states.shape[2] != key_states.shape[2]: + assert query_states.shape[2] % key_states.shape[2] == 0 + group_size = 
query_states.shape[2] // key_states.shape[2] + assert query_states.shape[2] == value_states.shape[2] * group_size + q_anno = f'(group_num {group_size})' + kv_anno = 'group_num' + else: + q_anno = kv_anno = 'num_heads' + + return f'b l {q_anno} hd^, b l {kv_anno} hd^, b l {kv_anno} vd^ -> b l {q_anno} vd^' + + +register_op(flash_attention_anno)(fa_attn_forward) +register_op(ring_attn_anno, emit_fn=emit_ring)(wrap_zigzag_attn_func) +register_op(ring_attn_anno, emit_fn=emit_ring)(wrap_striped_attn_func) diff --git a/mtraining/attn_funcs/minfer_func.py b/mtraining/attn_funcs/minfer_func.py new file mode 100644 index 0000000..c7025c6 --- /dev/null +++ b/mtraining/attn_funcs/minfer_func.py @@ -0,0 +1,377 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# This file modifies the official modeling_llama.py file at runtime to +# 1. register the flash attention function to nnscaler and update related code +# 2. replace the un-fused RMSNorm with apex's fused version +import json +import torch +import logging +logger = logging.getLogger(__name__) + +from typing import List, Optional, Tuple, Dict, Callable +from transformers.utils import logging, is_flash_attn_2_available +if is_flash_attn_2_available(): from flash_attn import flash_attn_func + +from nnscaler.runtime.device import DeviceGroup +from nnscaler.graph.parser.register import register_op +from nnscaler.ir import IRTensor +from nnscaler.ir.operator import IRFwOperation + +from minference.ops.utils import use_triton +from minference.ops.pit_sparse_flash_attention_v3 import ( + minference_flash_attn_func, minference_flash_attn_triton_func +) +from minference.dist_ops import ( + minfer_stripe_func, minfer_zigzag_func, minfer_dr_stripe_func, +) + + +# ======================================================= +def minfer_op( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + head_indices: torch.Tensor, + + bsz: int, + q_len: int, + head_dim: int, + layer_idx: int, + + pattern_dict: Dict[int, Tuple[str, int, int, int]], + attn_dropout: float=0., + granularity: int = 128, + group: Optional[torch.distributed.ProcessGroup] = None, +): + v_sizes = [pattern_dict[head_indices[idx].item()][1] for idx in range(query_states.size(1))] + s_sizes = [pattern_dict[head_indices[idx].item()][2] for idx in range(query_states.size(1))] + if not use_triton(): + attn_output = minference_flash_attn_func( + query_states.transpose(1, 2).contiguous(), + key_states.transpose(1, 2).contiguous(), + value_states.transpose(1, 2).contiguous(), + v_sizes, s_sizes, + attn_dropout, + softmax_scale=None, + granularity=granularity, + causal=True, + window_size=(-1, -1), + deterministic=False, + return_attn_probs=False, + group=group, + ) + else: + attn_output = minference_flash_attn_triton_func( + query_states.transpose(1, 2).contiguous(), + key_states.transpose(1, 2).contiguous(), + value_states.transpose(1, 2).contiguous(), + v_sizes, s_sizes, + attn_dropout, + softmax_scale=None, + granularity=granularity, + causal=True, + window_size=(-1, -1), + deterministic=False, + return_attn_probs=False, + group=group, + ) + return attn_output.contiguous() + + +def minfer_stripe_op( + query_states: torch.Tensor, # [batch_size, num_heads, num_tokens, head_dim] + key_states: torch.Tensor, + value_states: torch.Tensor, + head_indices: torch.Tensor, + bsz: int, + q_len: int, + head_dim: int, + layer_idx: int, + pattern_dict: Dict[int, Tuple[str, int, int, int]], + attn_dropout: float=0., + granularity: int = 128, + 
process_group: Optional[torch.distributed.ProcessGroup] = None, +): + if (process_group is None or len(process_group) == 1): + softmax_scale = query_states.shape[-1] ** (-0.5) + + output = flash_attn_func( + query_states.transpose(1, 2), + key_states.transpose(1, 2), + value_states.transpose(1, 2), + attn_dropout, softmax_scale, causal=True) + return output + group = DeviceGroup().get_group(process_group) + + v_sizes = [pattern_dict[head_indices[idx].item()][1] for idx in range(query_states.size(1))] + s_sizes = [pattern_dict[head_indices[idx].item()][2] for idx in range(query_states.size(1))] + attn_output = minfer_stripe_func( + query_states.transpose(1, 2).contiguous(), + key_states.transpose(1, 2).contiguous(), + value_states.transpose(1, 2).contiguous(), + v_sizes, s_sizes, + layer_idx, + attn_dropout, + softmax_scale=None, + granularity=granularity, + causal=True, + window_size=(-1, -1), + deterministic=False, + return_attn_probs=False, + group=group, + ) # expect: b {q_anno} l^ vd^' + + return attn_output.contiguous() + + +def minfer_zigzag_op( + query_states: torch.Tensor, # [batch_size, num_heads, num_tokens, head_dim] + key_states: torch.Tensor, + value_states: torch.Tensor, + head_indices: torch.Tensor, + + bsz: int, + q_len: int, + head_dim: int, + layer_idx: int, + + pattern_dict: Dict[int, Tuple[str, int, int, int]], + attn_dropout: float=0., + granularity: int = 128, + process_group: Optional[torch.distributed.ProcessGroup] = None, +): + if process_group is None or len(process_group) == 1: + # there is an additional checker for the `softmax_scale`, which is equivalent + # to the behavior of the original flash_attn_func. + softmax_scale = query_states.shape[-1] ** (-0.5) + output = flash_attn_func( + query_states.transpose(1, 2), + key_states.transpose(1, 2), + value_states.transpose(1, 2), + attn_dropout, softmax_scale, causal=True) + return output + group = DeviceGroup().get_group(process_group) + + v_sizes = [pattern_dict[head_indices[idx].item()][1] for idx in range(query_states.size(1))] + s_sizes = [pattern_dict[head_indices[idx].item()][2] for idx in range(query_states.size(1))] + if not use_triton(): + attn_output = minfer_zigzag_func( + query_states.transpose(1, 2).contiguous(), + key_states.transpose(1, 2).contiguous(), + value_states.transpose(1, 2).contiguous(), + v_sizes, s_sizes, + layer_idx, + attn_dropout, + softmax_scale=None, + granularity=granularity, + causal=True, + window_size=(-1, -1), + deterministic=False, + return_attn_probs=False, + group=group, + ) # expect: b {q_anno} l^ vd^' + else: + raise NotImplementedError("Triton-only version is not implemented for MInfer w. zigzag") + return attn_output.contiguous() + +def minfer_dr_stripe_op( + query_states: torch.Tensor, # [batch_size, num_heads, num_tokens, head_dim] + key_states: torch.Tensor, + value_states: torch.Tensor, + head_indices: torch.Tensor, + bsz: int, + q_len: int, + head_dim: int, + layer_idx: int, + + pattern_dict: Dict[int, Tuple[str, int, int, int]], + attn_dropout: float=0., + granularity: int = 128, + process_group: Optional[torch.distributed.ProcessGroup] = None, +): + if (process_group is None or len(process_group) == 1): + # there is an additional checker for the `softmax_scale`, which is equivalent + # to the behavior of the original flash_attn_func. 
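(Editorial aside: a few lines below, each head's vertical and slash budgets are read out of `pattern_dict`, whose values follow the `Dict[int, Tuple[str, int, int, int]]` hint; indices [1] and [2] are used as the per-head v_size and s_size. A minimal sketch of that lookup with hypothetical, made-up entries:)

```python
import torch

# Hypothetical entries: head_idx -> (pattern_name, v_size, s_size, score).
pattern_dict = {
    0: ("vertical_and_slash", 1000, 6096, 0.33),
    1: ("vertical_and_slash", 3500, 100, 0.47),
}
head_indices = torch.arange(2, dtype=torch.int32)

v_sizes = [pattern_dict[head_indices[i].item()][1] for i in range(len(head_indices))]
s_sizes = [pattern_dict[head_indices[i].item()][2] for i in range(len(head_indices))]
assert v_sizes == [1000, 3500] and s_sizes == [6096, 100]
```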
+ softmax_scale = query_states.shape[-1] ** (-0.5) + + output = flash_attn_func( + query_states.transpose(1, 2), + key_states.transpose(1, 2), + value_states.transpose(1, 2), + attn_dropout, softmax_scale, causal=True) + return output + + group = DeviceGroup().get_group(process_group) + v_sizes = [pattern_dict[head_indices[idx].item()][1] for idx in range(query_states.size(1))] + s_sizes = [pattern_dict[head_indices[idx].item()][2] for idx in range(query_states.size(1))] + + attn_output = minfer_dr_stripe_func( + query_states.transpose(1, 2).contiguous(), + key_states.transpose(1, 2).contiguous(), + value_states.transpose(1, 2).contiguous(), + v_sizes, s_sizes, + layer_idx, + attn_dropout, + softmax_scale=None, + granularity=granularity, + causal=True, + window_size=(-1, -1), + deterministic=False, + return_attn_probs=False, + group=group, + ) # expect: b {q_anno} l^ vd^' + return attn_output.contiguous() + + + +MINFER_IMPLEMENTATIONS: Dict[str, Callable] = { + "default": minfer_op, + "stripe": minfer_stripe_op, + "zigzag": minfer_zigzag_op, + "dr_stripe": minfer_dr_stripe_op, +} + +def emit_minfer_ring(node: IRFwOperation, args: List[str], kwargs: Dict[str, str], runtime_devid: int, plan_ndevs: int, runtime_ndevs: int) -> str: + """Special rule to generate zigzag_attn node""" + + signature = node.signature + + offset = (runtime_devid // plan_ndevs) * plan_ndevs + scale_unit_dev_ids = [local_rank + offset for local_rank in range(plan_ndevs)] + + kw_pairs = list() + for key, val in kwargs.items(): + code = f'{key}={val}' + kw_pairs.append(code) + + sub_input = node.inputs()[0] + full_input = sub_input.parent + partition_dims = [i for i, (s, f) in enumerate(zip(sub_input.shape, full_input.shape)) if s != f] + assert len(partition_dims) <= 1, f"support no more than one partition dim, but got {partition_dims}" + if not partition_dims: + kw_pairs.append("process_group=None") + else: + # if the 'process_group' is None, we will use the local attention (flash_attn_func) + if partition_dims[0] == 0: # partition on batch dim + # partition the bsz dim, use local flash_attn_func + kw_pairs.append("process_group=None") + elif partition_dims[0] == 1: + # partition on num_head dim + kw_pairs.append("process_group=None") + elif partition_dims[0] == 2: # partition on sequence dim + # the synchronization should occur across scaleunits + kw_pairs.append(f"process_group={scale_unit_dev_ids}") + else: + raise ValueError(f'unsupported partition dim: {partition_dims[0]}') + + args = ", ".join(list(args) + kw_pairs) + return f"{signature}({args})" + +def minfer_attn_anno(query_states, key_states, value_states, *args, **kwargs) -> str: + if query_states.shape[1] != key_states.shape[1]: + assert query_states.shape[1] % key_states.shape[1] == 0 + group_size = query_states.shape[1] // key_states.shape[1] + assert query_states.shape[1] == value_states.shape[1] * group_size + q_anno = f'(group_num {group_size})' + kv_anno = 'group_num' + else: + q_anno = kv_anno = 'num_heads' + + return f'b {q_anno} l^ hd^, b {kv_anno} s^ hd^, b {kv_anno} s^ vd^, {q_anno} -> b l^ {q_anno} vd^' + +def minfer_attn_ring_anno(query_states, key_states, value_states, *args, **kwargs) -> str: + if query_states.shape[1] != key_states.shape[1]: + assert query_states.shape[1] % key_states.shape[1] == 0 + group_size = query_states.shape[1] // key_states.shape[1] + assert query_states.shape[1] == value_states.shape[1] * group_size + q_anno = f'(group_num {group_size})' + kv_anno = 'group_num' + else: + q_anno = kv_anno = 'num_heads' + + return f'b 
{q_anno} l hd^, b {kv_anno} l hd^, b {kv_anno} l vd^, {q_anno} -> b l {q_anno} vd^' + +if __name__ != "__main__": + register_op(minfer_attn_anno)(minfer_op) + register_op(minfer_attn_ring_anno, emit_fn=emit_minfer_ring)(minfer_stripe_op) + register_op(minfer_attn_ring_anno, emit_fn=emit_minfer_ring)(minfer_zigzag_op) + register_op(minfer_attn_ring_anno, emit_fn=emit_minfer_ring)(minfer_dr_stripe_op) + +class MInferAttnFunc: + def __init__(self): + self.initialized = False + + def init_minfer_params( + self, + config_path: str, + minfer_implementation: str, # one of MINFER_IMPLEMENTATIONS: "default", "stripe", "zigzag", "dr_stripe" + granularity: int = 128, + ): + assert minfer_implementation in MINFER_IMPLEMENTATIONS, f"minfer_implementation should be one of {list(MINFER_IMPLEMENTATIONS.keys())}, but got {minfer_implementation}" + self.minfer_implementation: str = minfer_implementation + + self.config_path = config_path + with open(self.config_path) as config_file: + self.all_pattern_dict = json.load(config_file) + self.granularity = granularity + + self.initialized = True + + def get_pattern_dict(self, layer_idx): + return {int(ii): jj for ii, jj in self.all_pattern_dict[layer_idx].items()} + + def forward( + self, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + head_indices: torch.Tensor, + attn_module_config: Dict[str, int], + attn_dropout: float=0.0, + ): + bsz, q_len = query_states.shape[0], query_states.shape[2] + head_dim, layer_idx = attn_module_config["head_dim"], attn_module_config["layer_idx"] + + pattern_dict = self.get_pattern_dict(layer_idx) + minfer_args = ( + query_states, key_states, value_states, + head_indices, + bsz, q_len, head_dim, layer_idx, + pattern_dict, attn_dropout, self.granularity, + ) + + if self.minfer_implementation == "default": + return minfer_op(*minfer_args) + elif self.minfer_implementation == "stripe": + return minfer_stripe_op(*minfer_args) + elif self.minfer_implementation == "zigzag": + return minfer_zigzag_op(*minfer_args) + elif self.minfer_implementation == "dr_stripe": + return minfer_dr_stripe_op(*minfer_args) + else: + raise ValueError(f"Unsupported minfer_implementation: {self.minfer_implementation}") + +def minfer_attention_forward( + module: torch.nn.Module, + query: torch.Tensor, # [B, H, N, D] + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + dropout: float = 0.0, + scaling: Optional[float] = None, + sliding_window: Optional[int] = None, + softcap: Optional[float] = None, + **kwargs, +) -> Tuple[torch.Tensor, None]: + attn_module_config = { + "num_heads": module.config.num_attention_heads, + "head_dim": module.head_dim, + "layer_idx": module.layer_idx, + } + head_indices = torch.arange(attn_module_config["num_heads"], device=query.device, dtype=torch.int32) + + return module.minfer_attn_func.forward( + query, key, value, head_indices, + attn_module_config, + dropout, + ), None diff --git a/mtraining/attn_funcs/moba_func.py b/mtraining/attn_funcs/moba_func.py new file mode 100644 index 0000000..b604dbb --- /dev/null +++ b/mtraining/attn_funcs/moba_func.py @@ -0,0 +1,178 @@ +import os +import yaml +import torch +import torch.distributed as dist + +from torch import Tensor +from typing import Tuple, Optional, List, Dict, Any + +from nnscaler.ir.operator import IRFwOperation +from nnscaler.runtime.device import DeviceGroup +from nnscaler.graph.parser.register import register_op + +from minference.ops.moba import moba_attn_func +from minference.ops.op_utils.moba_utils import MoBAConfig +from minference.dist_ops.moba_zigzag import moba_zigzag_func + +def
load_moba_config(moba_config_dict: Dict[str, Any]): + moba_config = MoBAConfig(**moba_config_dict) + return moba_config + +def moba_attention_forward( + module: torch.nn.Module, + query: torch.Tensor, # [B, H, N, D] + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + dropout: float = 0.0, + scaling: Optional[float] = None, + sliding_window: Optional[int] = None, + softcap: Optional[float] = None, + **kwargs, +) -> Tuple[torch.Tensor, None]: + seq_len = query.shape[2] + moba_topk, moba_chunk_size = module.moba_topk, module.moba_chunk_size + implementation = module.implementation + + if implementation == "default": + return wrapped_moba_func( + query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), + seq_len, + moba_topk, moba_chunk_size, + ), None + elif implementation == "zigzag": + layer_idx = module.layer_idx + return wrapped_moba_zigzag_func( + query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), + seq_len, + moba_topk, moba_chunk_size, + layer_idx, + attention_mask, dropout, scaling, sliding_window, softcap, + ), None + else: + raise ValueError(f"Unsupported MoBA implementation: {implementation}. " + f"Supported implementations are 'default' and 'zigzag'.") + + +# ------------------------------------------ +def wrapped_moba_func( + q: Tensor, k: Tensor, v: Tensor, + seq_len: int, + moba_topk: int, moba_chunk_size: int, +): + return moba_attn_func( + q, k, v, + seq_len, + moba_chunk_size, moba_topk, + ) + +def wrapped_moba_zigzag_func( + query: Tensor, # [B, N, H, D] + key: Tensor, + value: Tensor, + seq_len: int, + moba_topk: int, moba_chunk_size: int, + layer_idx: int, + attention_mask: Optional[torch.Tensor], + dropout: float = 0.0, + softmax_scale: Optional[float] = None, + sliding_window: Optional[int] = None, + softcap: Optional[float] = None, + process_group: Tuple[int]=None, +): + if process_group is None or len(process_group) == 1: + # there is an additional checker for the `softmax_scale`, which is equivalent + # to the behavior of the original flash_attn_func. 
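+ # i.e. when softmax_scale is None we default it below to 1 / sqrt(head_dim), which is the + # same scale flash_attn_func applies internally when no softmax_scale is passed.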
+ from flash_attn import flash_attn_func + if softmax_scale is None: + softmax_scale = query.shape[-1] ** (-0.5) + output = flash_attn_func(query, key, value, 0.0, softmax_scale, True) + return output + + batch_size, block_seq_len, q_heads, head_dim = query.shape + assert batch_size == 1, "Current implementation only supports batch size = 1" + + local_process_group = DeviceGroup().get_group(process_group) + output = moba_zigzag_func( + query, key, value, + layer_idx, + seq_len, + moba_chunk_size, + moba_topk, + dropout, softmax_scale, + True, # causal, + (-1, -1), # window_size, + None, # alibi_slopes, + False, # deterministic, + False, # return_softmax, + local_process_group, # group + ).contiguous() + return output.view(batch_size, block_seq_len, q_heads, head_dim) + +# -------------------------------------------------- +def moba_attn_anno(query_states, key_states, value_states, *args, **kwargs) -> str: + if query_states.shape[2] != key_states.shape[2]: + assert query_states.shape[2] % key_states.shape[2] == 0 + group_size = query_states.shape[2] // key_states.shape[2] + assert query_states.shape[2] == value_states.shape[2] * group_size + q_anno = f'(group_num {group_size})' + kv_anno = 'group_num' + else: + q_anno = kv_anno = 'num_heads' + + return f'b l^ {q_anno} hd^, b s^ {kv_anno} hd^, b s^ {kv_anno} vd^ -> b l^ {q_anno} vd^' + +def moba_zigzag_attn_anno(query_states, key_states, value_states, *args, **kwargs) -> str: + num_q_heads, num_kv_heads = query_states.shape[2], key_states.shape[2] + if num_q_heads != num_kv_heads: + assert num_q_heads % num_kv_heads == 0 + group_size = num_q_heads // num_kv_heads + assert num_q_heads == value_states.shape[2] * group_size + q_anno = f'(group_num {group_size})' + kv_anno = 'group_num' + else: + q_anno = kv_anno = 'num_heads' + + attn_anno = f'b l {q_anno} hd^, b l {kv_anno} hd^, b l {kv_anno} vd^ -> b l {q_anno} vd^' + return attn_anno + + +def emit_moba_zigzag(node: IRFwOperation, args: List[str], kwargs: Dict[str, str], runtime_devid: int, plan_ndevs: int, runtime_ndevs: int) -> str: + """Special rule to generate zigzag_attn node""" + + signature = node.signature + + offset = (runtime_devid // plan_ndevs) * plan_ndevs + scale_unit_dev_ids = [local_rank + offset for local_rank in range(plan_ndevs)] + + kw_pairs = list() + for key, val in kwargs.items(): + code = f'{key}={val}' + kw_pairs.append(code) + + sub_input = node.inputs()[0] + full_input = sub_input.parent + partition_dims = [i for i, (s, f) in enumerate(zip(sub_input.shape, full_input.shape)) if s != f] + assert len(partition_dims) <= 1, f"support no more than one partition dim, but got {partition_dims}" + if not partition_dims: + kw_pairs.append("process_group=None") + else: + # if the 'process_group' is None, we will use the local attention (flash_attn_func) + if partition_dims[0] == 0: # partition on batch dim + # partition the bsz dim, use local flash_attn_func + kw_pairs.append("process_group=None") + elif partition_dims[0] == 1: # partition on sequence dim + # the synchronization should occur across scaleunits + kw_pairs.append(f"process_group={scale_unit_dev_ids}") + elif partition_dims[0] == 2: + # partition on num_head dim + kw_pairs.append("process_group=None") + else: + raise ValueError(f'unsupported partition dim: {partition_dims[0]}') + + args = ", ".join(list(args) + kw_pairs) + return f"{signature}({args})" + +if __name__ != "__main__": + register_op(moba_attn_anno)(wrapped_moba_func) + register_op(moba_zigzag_attn_anno, 
emit_fn=emit_moba_zigzag)(wrapped_moba_zigzag_func) diff --git a/mtraining/attn_funcs/utils.py b/mtraining/attn_funcs/utils.py new file mode 100644 index 0000000..f01fc63 --- /dev/null +++ b/mtraining/attn_funcs/utils.py @@ -0,0 +1,49 @@ +import torch +import logging +logger = logging.getLogger(__name__) + +from transformers.utils import is_flash_attn_2_available, logging +from transformers.modeling_flash_attention_utils import _get_unpad_data +if is_flash_attn_2_available(): + from flash_attn.bert_padding import index_first_axis, unpad_input # noqa + + + +def nnscaler_upad_input(query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + _, _, num_heads, _ = query_layer.shape + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) \ No newline at end of file diff --git a/mtraining/attn_funcs/xattn_func.py b/mtraining/attn_funcs/xattn_func.py new file mode 100644 index 0000000..abbf854 --- /dev/null +++ b/mtraining/attn_funcs/xattn_func.py @@ -0,0 +1,217 @@ +import os +import yaml +import copy +import torch +from torch import Tensor +import torch.distributed as dist + +from functools import partial +from flash_attn import flash_attn_func +from typing import Tuple, Optional, Dict, Any, List + +from nnscaler.runtime.device import DeviceGroup +from nnscaler.graph.parser.register import register_op +from nnscaler.ir import IRTensor +from nnscaler.ir.operator import IRFwOperation + +from minference.ops.xattention_fa import xattn_flash_attn_func +from minference.dist_ops.xattn_zigzag import xattn_zigzag_func + +def xattn_attention_forward( + module: torch.nn.Module, + query: torch.Tensor, # [B, H, N, D] + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + dropout: float = 0.0, + scaling: Optional[float] = None, + sliding_window: Optional[int] = None, + softcap: Optional[float] = None, + **kwargs, +) -> Tuple[torch.Tensor, None]: + granularity, xattn_params = module.granularity, module.xattn_params + implementation = module.implementation + layer_idx = module.layer_idx + + if implementation == "default": + head_indices = torch.arange(module.config.num_attention_heads, device=query.device) + return wrapped_xattn_func_( + query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), + head_indices, + granularity, xattn_params, + 
dropout, scaling, sliding_window + ) + elif implementation == "zigzag": + return wrapped_xattn_zigzag_func_( + query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), + layer_idx, + granularity, xattn_params, + dropout=dropout, scaling=scaling, sliding_window=sliding_window, + ) + else: + raise NotImplementedError(f"Unsupported implementation for xattn_attention_forward: {implementation}") + +# ------------------------------------------ +# Non-CP version +def wrapped_xattn_func_( + q: Tensor, k: Tensor, v: Tensor, # [B, N, H, D] + head_indices: torch.Tensor, + granularity: int, + xattn_params: Dict[str, Any], + dropout: float = 0.0, + scaling: Optional[float] = None, + sliding_window: Optional[int] = None, +): + return wrapped_xattn_func( + q, k, v, + head_indices, + granularity, xattn_params, + dropout, scaling, sliding_window + ), None + +def wrapped_xattn_func( + q: Tensor, k: Tensor, v: Tensor, + head_indices: torch.Tensor, + granularity: int, + xattn_params: Dict[str, Any], + dropout: float = 0.0, + scaling: Optional[float] = None, + sliding_window: Optional[int] = None, +): + sliding_window = -1 if sliding_window is None else sliding_window + return xattn_flash_attn_func( + q, k, v, + head_indices.detach().cpu().numpy().tolist(), + xattn_params, + granularity, + dropout_p=dropout, + softmax_scale=scaling, + causal=True, + window_size=(sliding_window, sliding_window), + alibi_slopes=None, + deterministic=False, + ) + + + + +# ------------------------------------------ +# Zigzag Version +def wrapped_xattn_zigzag_func_( + q: Tensor, k: Tensor, v: Tensor, # [B, N, H, D] + layer_idx: int, + granularity: int, + xattn_params: Dict[str, Any], + causal: bool=True, + dropout: float = 0.0, + scaling: Optional[float] = None, + sliding_window: Optional[int] = None, + process_group: Optional[dist.ProcessGroup] = None, +): + return wrapped_xattn_zigzag_func( + q, k, v, + layer_idx, + granularity, xattn_params, + causal=causal, + dropout=dropout, scaling=scaling, sliding_window=sliding_window, + ), None + + +def wrapped_xattn_zigzag_func( + q: Tensor, k: Tensor, v: Tensor, # [B, N, H, D] + layer_idx: int, + granularity: int, + xattn_params: Dict[str, Any], + causal: bool=True, + dropout: float = 0.0, + scaling: Optional[float] = None, + sliding_window: Optional[int] = None, + process_group: Tuple[int]=None, +): + if process_group is None or len(process_group) == 1: + # there is an additional checker for the `scaling`, which is equivalent + # to the behavior of the original flash_attn_func. 
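+ # With no process group (or a single-rank group) there is nothing to synchronize across + # devices, so we simply run local flash attention instead of the zigzag ring variant.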
+ if scaling is None: + scaling = q.shape[-1] ** (-0.5) + output = flash_attn_func(q, k, v, 0.0, scaling, causal) + return output + + group = DeviceGroup().get_group(process_group) + + xattn_params = copy.copy(xattn_params) + xattn_params.pop("chunk_size", None) + return xattn_zigzag_func( + q, k, v, + layer_idx, + xattn_params, + granularity, + dropout_p=dropout, + softmax_scale=scaling, + causal=causal, + group=group, + ).contiguous() + +def xattn_attn_anno(query_states, key_states, value_states, *args, **kwargs) -> str: + if query_states.shape[2] != key_states.shape[2]: + assert query_states.shape[2] % key_states.shape[2] == 0 + group_size = query_states.shape[2] // key_states.shape[2] + assert query_states.shape[2] == value_states.shape[2] * group_size + q_anno = f'(group_num {group_size})' + kv_anno = 'group_num' + else: + q_anno = kv_anno = 'num_heads' + + return f'b l^ {q_anno} hd^, b s^ {kv_anno} hd^, b s^ {kv_anno} vd^, {q_anno} -> b l^ {q_anno} vd^' + +def emit_xattn_zigzag(node: IRFwOperation, args: List[str], kwargs: Dict[str, str], runtime_devid: int, plan_ndevs: int, runtime_ndevs: int) -> str: + """Special rule to generate zigzag_attn node""" + + signature = node.signature + + offset = (runtime_devid // plan_ndevs) * plan_ndevs + scale_unit_dev_ids = [local_rank + offset for local_rank in range(plan_ndevs)] + + kw_pairs = list() + for key, val in kwargs.items(): + code = f'{key}={val}' + kw_pairs.append(code) + + sub_input = node.inputs()[0] + full_input = sub_input.parent + partition_dims = [i for i, (s, f) in enumerate(zip(sub_input.shape, full_input.shape)) if s != f] + assert len(partition_dims) <= 1, f"support no more than one partition dim, but got {partition_dims}" + if not partition_dims: + kw_pairs.append("process_group=None") + else: + # if the 'process_group' is None, we will use the local attention (flash_attn_func) + if partition_dims[0] == 0: # partition on batch dim + # partition the bsz dim, use local flash_attn_func + kw_pairs.append("process_group=None") + elif partition_dims[0] == 1: # partition on sequence dim + # the synchronization should occur across scaleunits + kw_pairs.append(f"process_group={scale_unit_dev_ids}") + elif partition_dims[0] == 2: + # partition on num_head dim + kw_pairs.append("process_group=None") + else: + raise ValueError(f'unsupported partition dim: {partition_dims[0]}') + + args = ", ".join(list(args) + kw_pairs) + return f"{signature}({args})" + +def xattn_zigzag_anno(query_states, key_states, value_states, *args, **kwargs) -> str: + if query_states.shape[2] != key_states.shape[2]: + assert query_states.shape[2] % key_states.shape[2] == 0 + group_size = query_states.shape[2] // key_states.shape[2] + + assert query_states.shape[2] == value_states.shape[2] * group_size + q_anno = f'(group_num {group_size})' + kv_anno = 'group_num' + else: + q_anno = kv_anno = 'num_heads' + + return f'b l {q_anno} hd^, b l {kv_anno} hd^, b l {kv_anno} vd^ -> b l {q_anno} vd^' + +if __name__ != "__main__": + register_op(xattn_attn_anno)(wrapped_xattn_func) + register_op(xattn_zigzag_anno, emit_fn=emit_xattn_zigzag)(wrapped_xattn_zigzag_func) diff --git a/mtraining/experiments/active_param_configs/attn_only.txt b/mtraining/experiments/active_param_configs/attn_only.txt new file mode 100644 index 0000000..52aef4b --- /dev/null +++ b/mtraining/experiments/active_param_configs/attn_only.txt @@ -0,0 +1 @@ +self_attn \ No newline at end of file diff --git a/mtraining/experiments/active_param_configs/qk_proj_only.txt 
b/mtraining/experiments/active_param_configs/qk_proj_only.txt new file mode 100644 index 0000000..1a23205 --- /dev/null +++ b/mtraining/experiments/active_param_configs/qk_proj_only.txt @@ -0,0 +1,2 @@ +self_attn.q_proj +self_attn.k_proj diff --git a/mtraining/experiments/scripts/prolong_data_prepare.sh b/mtraining/experiments/scripts/prolong_data_prepare.sh new file mode 100644 index 0000000..2c4ec7a --- /dev/null +++ b/mtraining/experiments/scripts/prolong_data_prepare.sh @@ -0,0 +1,23 @@ +#!/usr/bin/bash + +# ------------------------------------------ +# Download data +# Prerequisite: sudo apt-get install git-lfs && git lfs install +RAW_DATASET_DIR="/path/to/datasets" +git clone https://huggingface.co/datasets/princeton-nlp/prolong-data-512K $RAW_DATASET_DIR/long-context-524288 +cd $RAW_DATASET_DIR/long-context-524288 +git lfs fetch +git lfs checkout + + +# ------------------------------------------ +# Data Processing +cd /path/to/mtraining +MODEL_ID="Qwen/Qwen2.5-3B" +PROCESSED_DATA_DIR="/path/to/processed_dataset" +torchrun --nproc_per_node=4\ + utils/data_utils/prolong.py \ + --model_id $MODEL_ID \ + --dataset_mix fixed_524288 \ + --dataset_path $RAW_DATASET_DIR/long-context-524288 \ + --save_path $PROCESSED_DATA_DIR diff --git a/mtraining/experiments/scripts/train_qwen_mini_ProLong512K.sh b/mtraining/experiments/scripts/train_qwen_mini_ProLong512K.sh new file mode 100755 index 0000000..6943697 --- /dev/null +++ b/mtraining/experiments/scripts/train_qwen_mini_ProLong512K.sh @@ -0,0 +1,105 @@ +#!/usr/bin/bash +echo $(which pip) +i=$(hostname | awk -F'-' '{print $2}') +NODE_RANK=$i +export NUM_NODES=1 +export REUSE_TYPE="match" +export FORCE_TRITON=1 + +export HF_TRUST_REMOTE_CODE=true +export HF_DATASETS_TRUST_REMOTE_CODE=true + +export MASTER_ADDR="node-0" +export MASTER_PORT="12345" + +export NNSCALER_HOME="/home/aiscuser/.conda/envs/mtrain/lib/python3.10/site-packages/nnscaler/autodist/" +export PYTHONPATH="${NNSCALER_HOME}:${PYTHONPATH}" + +# ----------------------------------------------- +# TODO: Basic Environment Settings +SEQUENCE_LENGTH=524288 +export GPU_NAME=A100 +export GPU_PER_NODE=4 +export WORLD_SIZE=4 +export GPU_SET="${GPU_NAME}_${WORLD_SIZE}" +export EXPR_HOME="/scratch/MInference/mtraining" +export NNSCALER_STORE="${EXPR_HOME}/experiments/expr_data_store/${GPU_SET}" +mkdir -p $NNSCALER_STORE +cd $EXPR_HOME + +# ------------------------------------------ +# /blob/nnscaler_store/MI300_8/minfer_qwen/qwen_fp090_512K_mini +export EXPR_DIR="minfer_qwen" # Name for the experiment set +export EXPR_NAME="qwen_fp090_512K_mini" # Name for the single experiment run +export MODEL_ID="Qwen/Qwen2.5-3B" +export DATASET_PATH="/scratch/nnscaler_store/prolong_fixed_filter_qwen_524288" +export MODEL_CONFIG_PATH="${EXPR_HOME}/model_configs/qwen2/lc_config_mini" +echo "Using model config path: $MODEL_CONFIG_PATH" +TRANSFER_CONFIG_DIR="none" +export TRAIN_ATTN_CONFIG_PATH="/scratch/MInference/mtraining/experiments/train_attn_configs/qwen_flex_090.yaml" +export ATTN_TYPE="minfer" + +# ------------------------------------------ +# Training Path settings +export TF_LOG_PATH="$NNSCALER_STORE/$EXPR_DIR/tf_logs" +export CKPT_PATH="$NNSCALER_STORE/$EXPR_DIR/$EXPR_NAME/checkpoints" +export COMPILE_PATH="$NNSCALER_STORE/compile_config/rank_${NODE_RANK}" +mkdir -p $TF_LOG_PATH +mkdir -p $CKPT_PATH +mkdir -p $COMPILE_PATH + +# ------------------------------------------- +# Training Settings +export GLOBAL_BATCH_SIZE=4 # TODO +export MICRO_BATCH_SIZE=1 +export MEM_CONSTRAINT=72 + +export NUM_ITER=10 
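+# NUM_ITER/NUM_EPOCH feed --n_iter/--n_epochs below; with NUM_EPOCH=0 the run length is presumably governed by NUM_ITER alone (10 iterations for this mini config).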
+export NUM_EPOCH=0 + +export CKPT_SAVE_STEP=5 +export CKPT_SAVE_EPOCH=0 + +export CHECK_RESUME=0 +if [ "$CHECK_RESUME" -eq 1 ]; then + CHECK_RESUME="--check_resume" +else + CHECK_RESUME="" +fi + +# ------------------------------------------- +# Logging Path +export LOG_PATH="${NNSCALER_STORE}/${EXPR_DIR}/${EXPR_NAME}/rank_${NODE_RANK}" +mkdir -p $LOG_PATH +echo "Logging directed to $LOG_PATH/train.log" + +export TRACE_STRATEGY="reuse_cache" +echo $(which torchrun) +torchrun --nproc_per_node=$GPU_PER_NODE \ + --nnodes=$NUM_NODES \ + --node_rank=$NODE_RANK \ + --master_addr=$MASTER_ADDR \ + --master_port=$MASTER_PORT \ + train.py --plan_ngpus $WORLD_SIZE \ + --runtime_ngpus $WORLD_SIZE \ + --name $EXPR_NAME \ + --seq_len $SEQUENCE_LENGTH \ + --attn_type $ATTN_TYPE \ + --train_attn_config_path $TRAIN_ATTN_CONFIG_PATH \ + --reuse_type $REUSE_TYPE \ + --model_id $MODEL_ID \ + --n_iter $NUM_ITER \ + --n_epochs $NUM_EPOCH \ + --global_batch_size $GLOBAL_BATCH_SIZE \ + --micro_batch_size $MICRO_BATCH_SIZE \ + --dataset_path $DATASET_PATH \ + --compile_save_path $COMPILE_PATH \ + --tf_log_dir $TF_LOG_PATH \ + --model_config_path $MODEL_CONFIG_PATH \ + --ckpt_save_dir $CKPT_PATH \ + --ckpt_n_step $CKPT_SAVE_STEP \ + --ckpt_n_epoch $CKPT_SAVE_EPOCH \ + --trace_strategy $TRACE_STRATEGY \ + --transfer_config_dir $TRANSFER_CONFIG_DIR \ + --mem_constraint $MEM_CONSTRAINT \ + $CHECK_RESUME > $LOG_PATH/train.log 2>&1 \ No newline at end of file diff --git a/mtraining/experiments/train_attn_configs/moba_256k_s95.yaml b/mtraining/experiments/train_attn_configs/moba_256k_s95.yaml new file mode 100644 index 0000000..712cae5 --- /dev/null +++ b/mtraining/experiments/train_attn_configs/moba_256k_s95.yaml @@ -0,0 +1,2 @@ +moba_chunk_size: 4096 +moba_topk: 6 \ No newline at end of file diff --git a/mtraining/experiments/train_attn_configs/moba_512k_s95.yaml b/mtraining/experiments/train_attn_configs/moba_512k_s95.yaml new file mode 100644 index 0000000..348f394 --- /dev/null +++ b/mtraining/experiments/train_attn_configs/moba_512k_s95.yaml @@ -0,0 +1,2 @@ +moba_chunk_size: 4096 +moba_topk: 12 \ No newline at end of file diff --git a/mtraining/experiments/train_attn_configs/qwen_flex_090.yaml b/mtraining/experiments/train_attn_configs/qwen_flex_090.yaml new file mode 100644 index 0000000..f584796 --- /dev/null +++ b/mtraining/experiments/train_attn_configs/qwen_flex_090.yaml @@ -0,0 +1,2 @@ +pattern_config_name: Qwen2.5_3B_flex_0.90 +implementation: stripe \ No newline at end of file diff --git a/mtraining/experiments/train_attn_configs/qwen_flex_095.yaml b/mtraining/experiments/train_attn_configs/qwen_flex_095.yaml new file mode 100644 index 0000000..b3685e6 --- /dev/null +++ b/mtraining/experiments/train_attn_configs/qwen_flex_095.yaml @@ -0,0 +1,2 @@ +pattern_config_name: Qwen2.5_3B_flex_0.95 +implementation: stripe \ No newline at end of file diff --git a/mtraining/experiments/train_attn_configs/qwen_mf_dr_stripe.yaml b/mtraining/experiments/train_attn_configs/qwen_mf_dr_stripe.yaml new file mode 100644 index 0000000..b0c2051 --- /dev/null +++ b/mtraining/experiments/train_attn_configs/qwen_mf_dr_stripe.yaml @@ -0,0 +1,2 @@ +pattern_config_name: Qwen2.5_3B_kv_out_v32_fit_o_best_pattern +implementation: dr_stripe \ No newline at end of file diff --git a/mtraining/experiments/train_attn_configs/qwen_mf_stripe.yaml b/mtraining/experiments/train_attn_configs/qwen_mf_stripe.yaml new file mode 100644 index 0000000..3601afd --- /dev/null +++ b/mtraining/experiments/train_attn_configs/qwen_mf_stripe.yaml @@ -0,0 +1,2 
@@ +pattern_config_name: Qwen2.5_3B_kv_out_v32_fit_o_best_pattern +implementation: stripe \ No newline at end of file diff --git a/mtraining/experiments/train_attn_configs/qwen_mf_zigzag.yaml b/mtraining/experiments/train_attn_configs/qwen_mf_zigzag.yaml new file mode 100644 index 0000000..a6afdaa --- /dev/null +++ b/mtraining/experiments/train_attn_configs/qwen_mf_zigzag.yaml @@ -0,0 +1,2 @@ +pattern_config_name: Qwen2.5_3B_kv_out_v32_fit_o_best_pattern +implementation: zigzag \ No newline at end of file diff --git a/mtraining/experiments/train_attn_configs/xattn_default.yaml b/mtraining/experiments/train_attn_configs/xattn_default.yaml new file mode 100644 index 0000000..6cf6980 --- /dev/null +++ b/mtraining/experiments/train_attn_configs/xattn_default.yaml @@ -0,0 +1,12 @@ +granularity: 128 +stride: 16 +norm: 1 +softmax: true +threshold: 0.9 +chunk_size: 16384 +select_mode: inverse +use_triton: true +causal: true +kdb: 1 +keep_sink: false +keep_recent: false diff --git a/mtraining/experiments/train_attn_configs/xattn_zigzag_s16.yaml b/mtraining/experiments/train_attn_configs/xattn_zigzag_s16.yaml new file mode 100644 index 0000000..52df90b --- /dev/null +++ b/mtraining/experiments/train_attn_configs/xattn_zigzag_s16.yaml @@ -0,0 +1,12 @@ +implementation: zigzag +granularity: 128 +stride: 16 +norm: 1 +softmax: true +threshold: 0.9 +select_mode: inverse +use_triton: true +causal: true +kdb: 1 +keep_sink: false +keep_recent: false diff --git a/mtraining/experiments/train_attn_configs/xattn_zigzag_s16_t85.yaml b/mtraining/experiments/train_attn_configs/xattn_zigzag_s16_t85.yaml new file mode 100644 index 0000000..d420943 --- /dev/null +++ b/mtraining/experiments/train_attn_configs/xattn_zigzag_s16_t85.yaml @@ -0,0 +1,12 @@ +implementation: zigzag +granularity: 128 +stride: 16 +norm: 1 +softmax: true +threshold: 0.85 +select_mode: inverse +use_triton: true +causal: true +kdb: 1 +keep_sink: false +keep_recent: false diff --git a/mtraining/model_configs/__init__.py b/mtraining/model_configs/__init__.py new file mode 100644 index 0000000..e65a160 --- /dev/null +++ b/mtraining/model_configs/__init__.py @@ -0,0 +1,29 @@ +from .phi3 import ( + Phi3Config, Phi3ForCausalLM, Phi3Attention, + apply_rotary_pos_emb, repeat_kv, + PHI_ATTN_FUNCS +) + +from .qwen2 import ( + Qwen2Config, Qwen2ForCausalLM, Qwen2Attention, + apply_rotary_pos_emb, repeat_kv, + QWEN_ATTN_FUNCS +) + + +MODEL_TO_ATTN_FUNC = { + "microsoft/Phi-3-mini-4k-instruct": PHI_ATTN_FUNCS, + "Qwen/Qwen2.5-3B": QWEN_ATTN_FUNCS +} + + +MODEL_ID_TO_MODEL_CLS = { + "microsoft/Phi-3-mini-4k-instruct": Phi3ForCausalLM, + "Qwen/Qwen2.5-3B": Qwen2ForCausalLM +} + +MODEL_ID_TO_PREFIX = { + "microsoft/Phi-3-mini-4k-instruct": "Phi3", + "Qwen/Qwen2.5-3B": "Qwen2", +} + diff --git a/mtraining/model_configs/phi3/__init__.py b/mtraining/model_configs/phi3/__init__.py new file mode 100644 index 0000000..85b6f8f --- /dev/null +++ b/mtraining/model_configs/phi3/__init__.py @@ -0,0 +1,5 @@ +from .modelling_phi import ( + Phi3Config, Phi3ForCausalLM, Phi3Attention, + apply_rotary_pos_emb, repeat_kv, + PHI_ATTN_FUNCS +) \ No newline at end of file diff --git a/mtraining/model_configs/phi3/configuration_phi3.py b/mtraining/model_configs/phi3/configuration_phi3.py new file mode 100644 index 0000000..7804010 --- /dev/null +++ b/mtraining/model_configs/phi3/configuration_phi3.py @@ -0,0 +1,227 @@ +# coding=utf-8 +# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" Phi-3 model configuration""" + + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + +PHI3_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "microsoft/Phi-3-mini-4k-instruct": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/resolve/main/config.json", + "microsoft/Phi-3-mini-128k-instruct": "https://huggingface.co/microsoft/Phi-3-mini-128k-instruct/resolve/main/config.json", +} + + +class Phi3Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Phi3Model`]. It is used to instantiate a Phi-3 + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the + [microsoft/Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 32064): + Vocabulary size of the Phi-3 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Phi3Model`]. + hidden_size (`int`, *optional*, defaults to 3072): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 8192): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + resid_pdrop (`float`, *optional*, defaults to 0.0): + Dropout probability for mlp outputs. + embd_pdrop (`int`, *optional*, defaults to 0.0): + The dropout ratio for the embeddings. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio after computing the attention scores. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 4096): + The maximum sequence length that this model might ever be used with. 
+ original_max_position_embeddings (`int`, *optional*, defaults to 4096): + The maximum sequence length that this model was trained with. This is used to determine the size of the + original RoPE embeddings when using long scaling. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon value used for the RMSNorm. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. Whether to tie weight embeddings or not. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`dict`, *optional*): + The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must + contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be `longrope` and + the `short_factor` and `long_factor` must be lists of numbers with the same length as the hidden size + divided by the number of attention heads divided by 2. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 32000): + The id of the "end-of-sequence" token. + pad_token_id (`int`, *optional*, defaults to 32000): + The id of the padding token. + sliding_window (`int`, *optional*): + Sliding window attention window size. If `None`, no sliding window is applied. + + Example: + + ```python + >>> from transformers import Phi3Model, Phi3Config + + >>> # Initializing a Phi-3 style configuration + >>> configuration = Phi3Config.from_pretrained("microsoft/Phi-3-mini-4k-instruct") + + >>> # Initializing a model from the configuration + >>> model = Phi3Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "phi3" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32064, + hidden_size=3072, + intermediate_size=8192, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + resid_pdrop=0.0, + embd_pdrop=0.0, + attention_dropout=0.0, + hidden_act="silu", + max_position_embeddings=4096, + original_max_position_embeddings=4096, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + bos_token_id=1, + eos_token_id=32000, + pad_token_id=32000, + sliding_window=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.resid_pdrop = resid_pdrop + self.embd_pdrop = embd_pdrop + self.attention_dropout = attention_dropout + self.hidden_act = hidden_act + self.max_position_embeddings = max_position_embeddings + self.original_max_position_embeddings = original_max_position_embeddings + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = 
rope_scaling + self._rope_scaling_adjustment() + self._rope_scaling_validation() + self.sliding_window = sliding_window + + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + def _rope_scaling_adjustment(self): + """ + Adjust the `type` of the `rope_scaling` configuration for backward compatibility. + """ + if self.rope_scaling is None: + return + + rope_scaling_type = self.rope_scaling.get("type", None) + + # For backward compatibility if previous version used "su" or "yarn" + if rope_scaling_type is not None and rope_scaling_type in ["su", "yarn"]: + self.rope_scaling["type"] = "longrope" + + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. + """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 3: + raise ValueError( + "`rope_scaling` must be a dictionary with three fields, `type`, `short_factor` and `long_factor`, " + f"got {self.rope_scaling}" + ) + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_short_factor = self.rope_scaling.get("short_factor", None) + rope_scaling_long_factor = self.rope_scaling.get("long_factor", None) + if rope_scaling_type is None or rope_scaling_type not in ["longrope"]: + raise ValueError(f"`rope_scaling`'s type field must be one of ['longrope'], got {rope_scaling_type}") + if not ( + isinstance(rope_scaling_short_factor, list) + and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor) + ): + raise ValueError( + f"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}" + ) + if not len(rope_scaling_short_factor) == self.hidden_size // self.num_attention_heads // 2: + raise ValueError( + f"`rope_scaling`'s short_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_short_factor)}" + ) + if not ( + isinstance(rope_scaling_long_factor, list) + and all(isinstance(x, (int, float)) for x in rope_scaling_long_factor) + ): + raise ValueError( + f"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}" + ) + if not len(rope_scaling_long_factor) == self.hidden_size // self.num_attention_heads // 2: + raise ValueError( + f"`rope_scaling`'s long_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_long_factor)}" + ) diff --git a/mtraining/model_configs/phi3/lc_config/configuration_phi3.py b/mtraining/model_configs/phi3/lc_config/configuration_phi3.py new file mode 100644 index 0000000..7804010 --- /dev/null +++ b/mtraining/model_configs/phi3/lc_config/configuration_phi3.py @@ -0,0 +1,227 @@ +# coding=utf-8 +# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" Phi-3 model configuration""" + + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + +PHI3_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "microsoft/Phi-3-mini-4k-instruct": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/resolve/main/config.json", + "microsoft/Phi-3-mini-128k-instruct": "https://huggingface.co/microsoft/Phi-3-mini-128k-instruct/resolve/main/config.json", +} + + +class Phi3Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Phi3Model`]. It is used to instantiate a Phi-3 + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the + [microsoft/Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 32064): + Vocabulary size of the Phi-3 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Phi3Model`]. + hidden_size (`int`, *optional*, defaults to 3072): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 8192): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + resid_pdrop (`float`, *optional*, defaults to 0.0): + Dropout probability for mlp outputs. + embd_pdrop (`int`, *optional*, defaults to 0.0): + The dropout ratio for the embeddings. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio after computing the attention scores. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 4096): + The maximum sequence length that this model might ever be used with. + original_max_position_embeddings (`int`, *optional*, defaults to 4096): + The maximum sequence length that this model was trained with. This is used to determine the size of the + original RoPE embeddings when using long scaling. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon value used for the RMSNorm. 
+ use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. Whether to tie weight embeddings or not. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`dict`, *optional*): + The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must + contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be `longrope` and + the `short_factor` and `long_factor` must be lists of numbers with the same length as the hidden size + divided by the number of attention heads divided by 2. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 32000): + The id of the "end-of-sequence" token. + pad_token_id (`int`, *optional*, defaults to 32000): + The id of the padding token. + sliding_window (`int`, *optional*): + Sliding window attention window size. If `None`, no sliding window is applied. + + Example: + + ```python + >>> from transformers import Phi3Model, Phi3Config + + >>> # Initializing a Phi-3 style configuration + >>> configuration = Phi3Config.from_pretrained("microsoft/Phi-3-mini-4k-instruct") + + >>> # Initializing a model from the configuration + >>> model = Phi3Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "phi3" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32064, + hidden_size=3072, + intermediate_size=8192, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + resid_pdrop=0.0, + embd_pdrop=0.0, + attention_dropout=0.0, + hidden_act="silu", + max_position_embeddings=4096, + original_max_position_embeddings=4096, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + bos_token_id=1, + eos_token_id=32000, + pad_token_id=32000, + sliding_window=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.resid_pdrop = resid_pdrop + self.embd_pdrop = embd_pdrop + self.attention_dropout = attention_dropout + self.hidden_act = hidden_act + self.max_position_embeddings = max_position_embeddings + self.original_max_position_embeddings = original_max_position_embeddings + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self._rope_scaling_adjustment() + self._rope_scaling_validation() + self.sliding_window = sliding_window + + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + def _rope_scaling_adjustment(self): + """ + Adjust the `type` of the `rope_scaling` configuration for backward compatibility. 
+ """ + if self.rope_scaling is None: + return + + rope_scaling_type = self.rope_scaling.get("type", None) + + # For backward compatibility if previous version used "su" or "yarn" + if rope_scaling_type is not None and rope_scaling_type in ["su", "yarn"]: + self.rope_scaling["type"] = "longrope" + + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. + """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 3: + raise ValueError( + "`rope_scaling` must be a dictionary with three fields, `type`, `short_factor` and `long_factor`, " + f"got {self.rope_scaling}" + ) + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_short_factor = self.rope_scaling.get("short_factor", None) + rope_scaling_long_factor = self.rope_scaling.get("long_factor", None) + if rope_scaling_type is None or rope_scaling_type not in ["longrope"]: + raise ValueError(f"`rope_scaling`'s type field must be one of ['longrope'], got {rope_scaling_type}") + if not ( + isinstance(rope_scaling_short_factor, list) + and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor) + ): + raise ValueError( + f"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}" + ) + if not len(rope_scaling_short_factor) == self.hidden_size // self.num_attention_heads // 2: + raise ValueError( + f"`rope_scaling`'s short_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_short_factor)}" + ) + if not ( + isinstance(rope_scaling_long_factor, list) + and all(isinstance(x, (int, float)) for x in rope_scaling_long_factor) + ): + raise ValueError( + f"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}" + ) + if not len(rope_scaling_long_factor) == self.hidden_size // self.num_attention_heads // 2: + raise ValueError( + f"`rope_scaling`'s long_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_long_factor)}" + ) diff --git a/mtraining/model_configs/phi3/lc_config_mini/configuration_phi3.py b/mtraining/model_configs/phi3/lc_config_mini/configuration_phi3.py new file mode 100644 index 0000000..7804010 --- /dev/null +++ b/mtraining/model_configs/phi3/lc_config_mini/configuration_phi3.py @@ -0,0 +1,227 @@ +# coding=utf-8 +# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" Phi-3 model configuration""" + + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + +PHI3_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "microsoft/Phi-3-mini-4k-instruct": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/resolve/main/config.json", + "microsoft/Phi-3-mini-128k-instruct": "https://huggingface.co/microsoft/Phi-3-mini-128k-instruct/resolve/main/config.json", +} + + +class Phi3Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Phi3Model`]. It is used to instantiate a Phi-3 + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the + [microsoft/Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 32064): + Vocabulary size of the Phi-3 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Phi3Model`]. + hidden_size (`int`, *optional*, defaults to 3072): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 8192): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + resid_pdrop (`float`, *optional*, defaults to 0.0): + Dropout probability for mlp outputs. + embd_pdrop (`int`, *optional*, defaults to 0.0): + The dropout ratio for the embeddings. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio after computing the attention scores. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 4096): + The maximum sequence length that this model might ever be used with. + original_max_position_embeddings (`int`, *optional*, defaults to 4096): + The maximum sequence length that this model was trained with. This is used to determine the size of the + original RoPE embeddings when using long scaling. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon value used for the RMSNorm. 
+ use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. Whether to tie weight embeddings or not. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`dict`, *optional*): + The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must + contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be `longrope` and + the `short_factor` and `long_factor` must be lists of numbers with the same length as the hidden size + divided by the number of attention heads divided by 2. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 32000): + The id of the "end-of-sequence" token. + pad_token_id (`int`, *optional*, defaults to 32000): + The id of the padding token. + sliding_window (`int`, *optional*): + Sliding window attention window size. If `None`, no sliding window is applied. + + Example: + + ```python + >>> from transformers import Phi3Model, Phi3Config + + >>> # Initializing a Phi-3 style configuration + >>> configuration = Phi3Config.from_pretrained("microsoft/Phi-3-mini-4k-instruct") + + >>> # Initializing a model from the configuration + >>> model = Phi3Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "phi3" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32064, + hidden_size=3072, + intermediate_size=8192, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + resid_pdrop=0.0, + embd_pdrop=0.0, + attention_dropout=0.0, + hidden_act="silu", + max_position_embeddings=4096, + original_max_position_embeddings=4096, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + bos_token_id=1, + eos_token_id=32000, + pad_token_id=32000, + sliding_window=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.resid_pdrop = resid_pdrop + self.embd_pdrop = embd_pdrop + self.attention_dropout = attention_dropout + self.hidden_act = hidden_act + self.max_position_embeddings = max_position_embeddings + self.original_max_position_embeddings = original_max_position_embeddings + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self._rope_scaling_adjustment() + self._rope_scaling_validation() + self.sliding_window = sliding_window + + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + def _rope_scaling_adjustment(self): + """ + Adjust the `type` of the `rope_scaling` configuration for backward compatibility. 
+ """ + if self.rope_scaling is None: + return + + rope_scaling_type = self.rope_scaling.get("type", None) + + # For backward compatibility if previous version used "su" or "yarn" + if rope_scaling_type is not None and rope_scaling_type in ["su", "yarn"]: + self.rope_scaling["type"] = "longrope" + + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. + """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 3: + raise ValueError( + "`rope_scaling` must be a dictionary with three fields, `type`, `short_factor` and `long_factor`, " + f"got {self.rope_scaling}" + ) + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_short_factor = self.rope_scaling.get("short_factor", None) + rope_scaling_long_factor = self.rope_scaling.get("long_factor", None) + if rope_scaling_type is None or rope_scaling_type not in ["longrope"]: + raise ValueError(f"`rope_scaling`'s type field must be one of ['longrope'], got {rope_scaling_type}") + if not ( + isinstance(rope_scaling_short_factor, list) + and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor) + ): + raise ValueError( + f"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}" + ) + if not len(rope_scaling_short_factor) == self.hidden_size // self.num_attention_heads // 2: + raise ValueError( + f"`rope_scaling`'s short_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_short_factor)}" + ) + if not ( + isinstance(rope_scaling_long_factor, list) + and all(isinstance(x, (int, float)) for x in rope_scaling_long_factor) + ): + raise ValueError( + f"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}" + ) + if not len(rope_scaling_long_factor) == self.hidden_size // self.num_attention_heads // 2: + raise ValueError( + f"`rope_scaling`'s long_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_long_factor)}" + ) diff --git a/mtraining/model_configs/phi3/modelling_phi.py b/mtraining/model_configs/phi3/modelling_phi.py new file mode 100644 index 0000000..0965eaf --- /dev/null +++ b/mtraining/model_configs/phi3/modelling_phi.py @@ -0,0 +1,1185 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/phi3/modular_phi3.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_phi3.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from typing import Callable, List, Optional, Tuple, Union + +import torch +from torch import nn + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from transformers.generation import GenerationMixin +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from transformers.processing_utils import Unpack +from transformers.utils import ( + LossKwargs, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from transformers.utils.deprecation import deprecate_kwarg + +from .configuration_phi3 import Phi3Config + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "microsoft/Phi-3-mini-4k-instruct" +_CONFIG_FOR_DOC = "Phi3Config" + +PHI_ATTN_FUNCS = ALL_ATTENTION_FUNCTIONS.copy() + +class Phi3MLP(nn.Module): + def __init__(self, config): + super().__init__() + + self.config = config + self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False) + self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) + self.activation_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + up_states = self.gate_up_proj(hidden_states) + + gate, up_states = up_states.chunk(2, dim=-1) + up_states = up_states * self.activation_fn(gate) + + return self.down_proj(up_states) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class Phi3Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Phi3Config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.num_key_value_heads = config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + op_size = config.num_attention_heads * self.head_dim + 2 * (config.num_key_value_heads * self.head_dim) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + self.qkv_proj = nn.Linear(config.hidden_size, op_size, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + qkv = self.qkv_proj(hidden_states) + query_pos = self.config.num_attention_heads * self.head_dim + query_states = qkv[..., :query_pos] + key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim] + value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :] + + query_states = query_states.view(hidden_shape).transpose(1, 2) + key_states = key_states.view(hidden_shape).transpose(1, 2) + value_states = value_states.view(hidden_shape).transpose(1, 2) + 
+ cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = PHI_ATTN_FUNCS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=getattr(self.config, "sliding_window", None), + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Phi3RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Phi3RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Phi3DecoderLayer(nn.Module): + def __init__(self, config: Phi3Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Phi3Attention(config=config, layer_idx=layer_idx) + self.mlp = Phi3MLP(config) + self.input_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.config = config + self.resid_attn_dropout = nn.Dropout(config.resid_pdrop) + self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): + input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range + `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + past_key_value (`Cache`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + self.resid_attn_dropout(hidden_states) # main diff with Llama + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + self.resid_mlp_dropout(hidden_states) # main diff with Llama + + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +class Phi3RotaryEmbedding(nn.Module): + def __init__(self, config: Phi3Config, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + # This .to() is needed if the model has been moved to a device after being initialized (because + # the buffer is automatically moved, but not the original copy) + self.original_inv_freq = 
self.original_inv_freq.to(device) + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + elif self.rope_type == "longrope": + self._longrope_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + def _longrope_frequency_update(self, position_ids, device): + """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise.""" + seq_len = torch.max(position_ids) + 1 + if hasattr(self.config, "original_max_position_embeddings"): + original_max_position_embeddings = self.config.original_max_position_embeddings + else: + original_max_position_embeddings = self.config.max_position_embeddings + if seq_len > original_max_position_embeddings: + if not hasattr(self, "long_inv_freq"): + self.long_inv_freq, _ = self.rope_init_fn( + self.config, device, seq_len=original_max_position_embeddings + 1 + ) + self.register_buffer("inv_freq", self.long_inv_freq, persistent=False) + else: + # This .to() is needed if the model has been moved to a device after being initialized (because + # the buffer is automatically moved, but not the original copy) + self.original_inv_freq = self.original_inv_freq.to(device) + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + + +PHI3_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Phi3Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
+""" + + +@add_start_docstrings( + "The bare Phi3 Model outputting raw hidden-states without any specific head on top.", + PHI3_START_DOCSTRING, +) +class Phi3PreTrainedModel(PreTrainedModel): + config_class = Phi3Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Phi3DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_attention_backend = True + _version = "0.0.5" + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +PHI3_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. 
If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare Phi3 Model outputting raw hidden-states without any specific head on top.", + PHI3_START_DOCSTRING, +) +class Phi3Model(Phi3PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`Phi3DecoderLayer`] + + Args: + config: Phi3Config + """ + + def __init__(self, config: Phi3Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Phi3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Phi3RotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
+ ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + output = BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + return output if return_dict else output.to_tuple() + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and past_key_values is not None: + is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Phi3. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
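+ # The remainder of this method either returns `None` (when the mask can be handled implicitly by the + # attention backend) or materializes a 4D float mask in which 0.0 marks positions that may be attended + # to and `torch.finfo(dtype).min` marks masked ones, optionally tightened by `config.sliding_window`.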
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: + target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + config: Phi3Config, + past_key_values: Cache, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. 
+ config (`Phi3Config`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + if config.sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: + sliding_attend_mask = torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - config.sliding_window + ) + diagonal_attend_mask.bitwise_or_(sliding_attend_mask) + causal_mask *= diagonal_attend_mask + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + return causal_mask + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
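The inverted-mask convention used by `_prepare_4d_causal_attention_mask_with_cache_position` is easiest to see on a toy example. The following sketch is editorial (not part of the patch): it replays the same construction for a 3-token query step against a 5-slot key/value length, with no sliding window and an all-ones padding mask; the concrete sizes are illustrative only.

```python
import torch

# Editorial sketch of the 4D mask construction above, with toy sizes.
dtype = torch.float32
min_dtype = torch.finfo(dtype).min
sequence_length, target_length, batch_size = 3, 5, 1
cache_position = torch.arange(2, 2 + sequence_length)   # two tokens are already in the cache
attention_mask = torch.ones(batch_size, target_length)  # 2D padding mask: no padding here

# Start fully masked, then zero out the causally visible positions.
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
diagonal_attend_mask = torch.arange(target_length) > cache_position.reshape(-1, 1)
causal_mask *= diagonal_attend_mask                      # 0.0 where attending is allowed
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1).clone()

# Fold the 2D padding mask in exactly as the method does.
padding_mask = causal_mask[:, :, :, :target_length] + attention_mask[:, None, None, :]
causal_mask[:, :, :, :target_length] = causal_mask[:, :, :, :target_length].masked_fill(
    padding_mask == 0, min_dtype
)
print(causal_mask[0, 0])  # rows = query positions 2..4; 0.0 = visible, min_dtype = masked
```

With zeros in `attention_mask`, the final `masked_fill` additionally blanks out the padded key positions, which is the only difference from the purely causal case.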
+ + +class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + + def __init__(self, config): + super().__init__(config) + self.model = Phi3Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which can be significant for long sequences or a large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, Phi3ForCausalLM + + >>> model = Phi3ForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct") + >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + logits_to_keep=None, + **kwargs, + ): + # Overwritten -- this model may need to switch between short and long rope, invalidating the cache in the + # process + + # When the first time input length reached long and short factor switching point, enforce re-compute cache + # It will cause downside of slower at this single token position, however, better than current failure. + if ( + past_key_values + and self.config.rope_scaling + and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1 + ): + past_length = cache_position[0] + if past_length <= self.config.original_max_position_embeddings: + past_key_values = None + + model_inputs = super().prepare_inputs_for_generation( + input_ids=input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + use_cache=use_cache, + logits_to_keep=logits_to_keep, + **kwargs, + ) + return model_inputs + + +@add_start_docstrings( + """ + The Phi3 Model transformer with a sequence classification head on top (linear layer). + + [`Phi3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). 
+ """, + PHI3_START_DOCSTRING, +) +class Phi3ForSequenceClassification(Phi3PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = Phi3Model(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) + + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + The Phi3 Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. 
+ """, + PHI3_START_DOCSTRING, +) +class Phi3ForTokenClassification(Phi3PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = Phi3Model(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.config) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "Phi3PreTrainedModel", + "Phi3Model", + "Phi3ForCausalLM", + "Phi3ForSequenceClassification", + "Phi3ForTokenClassification", +] \ No newline at end of file diff --git a/mtraining/model_configs/qwen2/__init__.py b/mtraining/model_configs/qwen2/__init__.py new file mode 100644 index 0000000..f9c07a0 --- /dev/null +++ b/mtraining/model_configs/qwen2/__init__.py @@ -0,0 +1,7 @@ +from .modeling_qwen2 import ( + Qwen2ForCausalLM, Qwen2Attention, + apply_rotary_pos_emb, repeat_kv, + QWEN_ATTN_FUNCS +) + +from .configuration_qwen2 import Qwen2Config \ No newline at end of file diff --git a/mtraining/model_configs/qwen2/configuration_qwen2.py b/mtraining/model_configs/qwen2/configuration_qwen2.py new file mode 100644 index 0000000..1c85806 --- /dev/null +++ b/mtraining/model_configs/qwen2/configuration_qwen2.py @@ -0,0 +1,196 @@ +# coding=utf-8 +# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Qwen2 model configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_rope_utils import rope_config_validation +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +class Qwen2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a + Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of + Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 151936): + Vocabulary size of the Qwen2 model. 
Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Qwen2Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 22016): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 32): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 32768): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. 
If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + use_sliding_window (`bool`, *optional*, defaults to `False`): + Whether to use sliding window attention. + sliding_window (`int`, *optional*, defaults to 4096): + Sliding window attention (SWA) window size. If not specified, will default to `4096`. + max_window_layers (`int`, *optional*, defaults to 28): + The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + + ```python + >>> from transformers import Qwen2Model, Qwen2Config + + >>> # Initializing a Qwen2 style configuration + >>> configuration = Qwen2Config() + + >>> # Initializing a model from the Qwen2-7B style configuration + >>> model = Qwen2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen2" + keys_to_ignore_at_inference = ["past_key_values"] + + # Default tensor parallel plan for base model `Qwen2` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + + def __init__( + self, + vocab_size=151936, + hidden_size=4096, + intermediate_size=22016, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=32, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=28, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.use_sliding_window = use_sliding_window + self.sliding_window = sliding_window if use_sliding_window else None + self.max_window_layers = max_window_layers + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + 
self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_dropout = attention_dropout + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) \ No newline at end of file diff --git a/mtraining/model_configs/qwen2/lc_config/configuration_qwen2.py b/mtraining/model_configs/qwen2/lc_config/configuration_qwen2.py new file mode 100644 index 0000000..1c85806 --- /dev/null +++ b/mtraining/model_configs/qwen2/lc_config/configuration_qwen2.py @@ -0,0 +1,196 @@ +# coding=utf-8 +# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Qwen2 model configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_rope_utils import rope_config_validation +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +class Qwen2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a + Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of + Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 151936): + Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Qwen2Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 22016): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 32): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. 
When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 32768): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. 
Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + use_sliding_window (`bool`, *optional*, defaults to `False`): + Whether to use sliding window attention. + sliding_window (`int`, *optional*, defaults to 4096): + Sliding window attention (SWA) window size. If not specified, will default to `4096`. + max_window_layers (`int`, *optional*, defaults to 28): + The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + + ```python + >>> from transformers import Qwen2Model, Qwen2Config + + >>> # Initializing a Qwen2 style configuration + >>> configuration = Qwen2Config() + + >>> # Initializing a model from the Qwen2-7B style configuration + >>> model = Qwen2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen2" + keys_to_ignore_at_inference = ["past_key_values"] + + # Default tensor parallel plan for base model `Qwen2` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + + def __init__( + self, + vocab_size=151936, + hidden_size=4096, + intermediate_size=22016, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=32, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=28, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.use_sliding_window = use_sliding_window + self.sliding_window = sliding_window if use_sliding_window else None + self.max_window_layers = max_window_layers + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_dropout = attention_dropout + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. 
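+        # Note: the legacy "type" key is kept as-is; only a "rope_type" alias is added next, so downstream
+        # code can read rope_scaling["rope_type"] regardless of which spelling was saved in the config.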
+ if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) \ No newline at end of file diff --git a/mtraining/model_configs/qwen2/lc_config_mini/configuration_qwen2.py b/mtraining/model_configs/qwen2/lc_config_mini/configuration_qwen2.py new file mode 100644 index 0000000..1c85806 --- /dev/null +++ b/mtraining/model_configs/qwen2/lc_config_mini/configuration_qwen2.py @@ -0,0 +1,196 @@ +# coding=utf-8 +# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Qwen2 model configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_rope_utils import rope_config_validation +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +class Qwen2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a + Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of + Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 151936): + Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Qwen2Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 22016): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 32): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. 
+ max_position_embeddings (`int`, *optional*, defaults to 32768): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + use_sliding_window (`bool`, *optional*, defaults to `False`): + Whether to use sliding window attention. + sliding_window (`int`, *optional*, defaults to 4096): + Sliding window attention (SWA) window size. If not specified, will default to `4096`. 
+ max_window_layers (`int`, *optional*, defaults to 28): + The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + + ```python + >>> from transformers import Qwen2Model, Qwen2Config + + >>> # Initializing a Qwen2 style configuration + >>> configuration = Qwen2Config() + + >>> # Initializing a model from the Qwen2-7B style configuration + >>> model = Qwen2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen2" + keys_to_ignore_at_inference = ["past_key_values"] + + # Default tensor parallel plan for base model `Qwen2` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + + def __init__( + self, + vocab_size=151936, + hidden_size=4096, + intermediate_size=22016, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=32, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=28, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.use_sliding_window = use_sliding_window + self.sliding_window = sliding_window if use_sliding_window else None + self.max_window_layers = max_window_layers + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_dropout = attention_dropout + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) \ No newline at end of file diff --git a/mtraining/model_configs/qwen2/modeling_qwen2.py b/mtraining/model_configs/qwen2/modeling_qwen2.py new file mode 100644 index 0000000..0b7175a --- /dev/null +++ b/mtraining/model_configs/qwen2/modeling_qwen2.py @@ -0,0 +1,1136 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/qwen2/modular_qwen2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_qwen2.py file directly. One of our CI enforces this. 
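A minimal usage sketch for the `Qwen2Config` defined above, assuming `configuration_qwen2.py` is importable from its directory; the `rope_scaling` and context-length values are illustrative, not settings shipped in this repo. It shows how `__init__` normalizes the legacy `type` key and the `use_sliding_window` flag.

```python
# Minimal sketch, assuming configuration_qwen2.py is importable; all values are illustrative.
from configuration_qwen2 import Qwen2Config

config = Qwen2Config(
    max_position_embeddings=131072,                # hypothetical long-context target
    rope_scaling={"type": "yarn", "factor": 4.0},  # legacy "type" key used deliberately
    use_sliding_window=False,
)

# __init__ adds a "rope_type" alias for the legacy "type" key before rope_config_validation runs.
assert config.rope_scaling["rope_type"] == "yarn"

# sliding_window is stored as None whenever use_sliding_window is False.
assert config.sliding_window is None
```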
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +from typing import Callable, List, Optional, Tuple, Union + +import torch +from torch import nn + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from transformers.generation import GenerationMixin +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, + QuestionAnsweringModelOutput, + TokenClassifierOutput, +) +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from transformers.processing_utils import Unpack +from transformers.utils import ( + LossKwargs, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from transformers.utils.deprecation import deprecate_kwarg +from .configuration_qwen2 import Qwen2Config + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "meta-qwen2/Qwen2-2-7b-hf" +_CONFIG_FOR_DOC = "Qwen2Config" + +QWEN_ATTN_FUNCS = ALL_ATTENTION_FUNCTIONS.copy() + +class Qwen2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. # (B, H, S, D) + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
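+    Example (shapes only; values illustrative):
+        With `q` and `k` of shape `(2, 32, 128, 64)` (batch, heads, seq_len, head_dim) and `cos`/`sin` of
+        shape `(2, 128, 64)` from the rotary embedding, `apply_rotary_pos_emb(q, k, cos, sin)` returns two
+        tensors of the same `(2, 32, 128, 64)` shape; only the last (head_dim) dimension is mixed by
+        `rotate_half`.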
+ """ + cos = cos.unsqueeze(unsqueeze_dim) # (B, S, D) -> (B, 1, S, D) + sin = sin.unsqueeze(unsqueeze_dim) # (B, S, D) -> (B, 1, S, D) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class Qwen2Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Qwen2Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + 
# sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + sliding_window = None + if ( + self.config.use_sliding_window + and getattr(self.config, "sliding_window", None) is not None + and self.layer_idx >= self.config.max_window_layers + ): + sliding_window = self.config.sliding_window + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = QWEN_ATTN_FUNCS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=sliding_window, # main diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Qwen2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Qwen2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Qwen2DecoderLayer(nn.Module): + def __init__(self, config: Qwen2Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx) + self.mlp = Qwen2MLP(config) + self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + if config.sliding_window and config._attn_implementation != "flash_attention_2": + logger.warning_once( + f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " + "unexpected results may be encountered." 
+ ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +class Qwen2RotaryEmbedding(nn.Module): + def __init__(self, config: Qwen2Config, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + # This .to() is needed if the model has been moved to a device after being initialized (because + # the buffer is automatically moved, but not the original copy) + self.original_inv_freq = self.original_inv_freq.to(device) + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) # (B, D // 2, 1) + position_ids_expanded = position_ids[:, 
None, :].float() # (B, 1, S) + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) # (B, S, D // 2) + emb = torch.cat((freqs, freqs), dim=-1) # (B, S, D) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +QWEN2_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Qwen2Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.", + QWEN2_START_DOCSTRING, +) +class Qwen2PreTrainedModel(PreTrainedModel): + config_class = Qwen2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen2DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_attention_backend = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +QWEN2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. 
+ + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.", + QWEN2_START_DOCSTRING, +) +class Qwen2Model(Qwen2PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`Qwen2DecoderLayer`] + + Args: + config: Qwen2Config + """ + + def __init__(self, config: Qwen2Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen2RotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, # example: torch.arange(0, 1024).unsqueeze(0) (shape: [B, S]) + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
+ ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + output = BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + return output if return_dict else output.to_tuple() + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
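+        # The paths below: under SDPA (no static cache, no output_attentions) the mask may be dropped
+        # entirely so F.scaled_dot_product_attention can take its `is_causal` fast path; otherwise an
+        # explicit 4-D additive mask is built, with dtype-min at disallowed positions.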
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
+ causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + +class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + + def __init__(self, config): + super().__init__(config) + self.model = Qwen2Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. 
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, Qwen2ForCausalLM + + >>> model = Qwen2ForCausalLM.from_pretrained("meta-qwen2/Qwen2-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-qwen2/Qwen2-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + The Qwen2 Model transformer with a sequence classification head on top (linear layer). + + [`Qwen2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). 
+ """, + QWEN2_START_DOCSTRING, +) +class Qwen2ForSequenceClassification(Qwen2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = Qwen2Model(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) + + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + The Qwen2 Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. 
+ """, + QWEN2_START_DOCSTRING, +) +class Qwen2ForTokenClassification(Qwen2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = Qwen2Model(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.config) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ +The Qwen2 Model transformer with a span classification head on top for extractive question-answering tasks like +SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). 
+ """, + QWEN2_START_DOCSTRING, +) +class Qwen2ForQuestionAnswering(Qwen2PreTrainedModel): + base_model_prefix = "transformer" + + def __init__(self, config): + super().__init__(config) + self.transformer = Qwen2Model(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.transformer.embed_tokens + + def set_input_embeddings(self, value): + self.transformer.embed_tokens = value + + @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. 
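+             Example (illustrative): for an answer span covering tokens 17 through 21 of the input, pass `start_positions=torch.tensor([17])` and `end_positions=torch.tensor([21])`.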
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + loss = None + if start_positions is not None and end_positions is not None: + loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return QuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) \ No newline at end of file diff --git a/mtraining/requirements.txt b/mtraining/requirements.txt new file mode 100644 index 0000000..f76ab56 --- /dev/null +++ b/mtraining/requirements.txt @@ -0,0 +1,6 @@ +transformers==4.48.0 +datasets==2.20.0 +tensorboard + +# For Data Preparation +mosaicml-streaming==0.8.1 \ No newline at end of file diff --git a/mtraining/setup.py b/mtraining/setup.py new file mode 100644 index 0000000..b1c86ad --- /dev/null +++ b/mtraining/setup.py @@ -0,0 +1,14 @@ +from setuptools import setup, find_packages + +setup( + name="mtraining", # Name of your project + version="0.1.0", + packages=find_packages(), # Automatically discover all packages + install_requires=[], # List dependencies if any (or use requirements.txt) + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + ], + python_requires=">=3.10", # Specify the Python version +) diff --git a/mtraining/setup.sh b/mtraining/setup.sh new file mode 100755 index 0000000..923fb32 --- /dev/null +++ b/mtraining/setup.sh @@ -0,0 +1,36 @@ +#!/usr/bin/bash +set -e # Exit on first error + +BASE_DIR="$(cd "$(dirname "$0")" && pwd)" +echo $BASE_DIR +PIP="$(which pip)" + +if command -v nvidia-smi +then + $PIP install ninja cmake wheel pybind11 + $PIP install -r "${BASE_DIR}/requirements.txt" + $PIP install git+https://github.com/microsoft/nnscaler.git@2368540417bc3b77b7e714d3f1a0de8a51bb66e8 + $PIP install "rotary-emb @ git+https://github.com/Dao-AILab/flash-attention.git@9356a1c0389660d7e231ff3163c1ac17d9e3824a#subdirectory=csrc/rotary" + $PIP install "block_sparse_attn @ git+https://github.com/HalberdOfPineapple/flash-attention.git@block-sparse" + $PIP install git+https://github.com/Dao-AILab/flash-attention.git@v2.7.4.post1 + $PIP install torch==2.6.0 torchvision==0.21.0 + $PIP install triton==3.0.0 +elif command -v rocm-smi # TODO: to verify the correctness of dependencies in ROCm environment +then + $PIP install ninja cmake wheel pybind11 + $PIP install --pre torch==2.3.1+rocm6.0 --index-url https://download.pytorch.org/whl/rocm6.0 + $PIP install git+https://github.com/OpenAI/triton.git@e192dba#subdirectory=python + $PIP install git+https://github.com/Dao-AILab/flash-attention.git@v2.7.4.post1 + $PIP install -r "${BASE_DIR}/requirements.txt" + $PIP install git+https://github.com/microsoft/nnscaler.git@2368540417bc3b77b7e714d3f1a0de8a51bb66e8 
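+    # As in the CUDA branch above, flash-attention is built from source here, which is why the
+    # ninja/cmake/pybind11 toolchain is installed first; the compile step can take a while.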
+else + echo "ERROR: both nvidia-smi and rocm-smi not found" + exit 1 +fi + +# Get the path to nnscaler and write its path to PYTHONPATH in ~/.profile +NNSCALER_HOME=$(python -c "import nnscaler; print(nnscaler.__path__[0])") +echo "export NNSCALER_HOME=${NNSCALER_HOME}" >> ~/.profile +echo "export PYTHONPATH=${NNSCALER_HOME}:\${PYTHONPATH}" >> ~/.profile +source ~/.profile +$PIP install -e $BASE_DIR \ No newline at end of file diff --git a/mtraining/train.py b/mtraining/train.py new file mode 100644 index 0000000..dfa4577 --- /dev/null +++ b/mtraining/train.py @@ -0,0 +1,550 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import os +import yaml +import torch +import logging +import argparse +import numpy as np + +from datasets import load_from_disk +from typing import Dict, List, Optional +from transformers.modeling_utils import PreTrainedModel +from transformers import AutoConfig, DataCollatorForLanguageModeling + +from nnscaler.cli.trainer_args import ( + CheckpointConfig, + DatasetConfig, + HookMapConfig, + ModelConfig, + OptimizerConfig, + DataloaderConfig, + LogConfig, + DatasetSamplerConfig, +) +from nnscaler.parallel import ComputeConfig +from nnscaler.utils import set_default_logger_level +from nnscaler.runtime.f16_optimizer import MixedPrecisionAdamW +from nnscaler.cli.loggers.tensorboard import TensorBoardLogger + +from minference.models_patch import MInference +from minference.minference_configuration import MInferenceConfig +from minference.configs.model2path import BASE_DIR as SPARSE_PATTERN_CONFIG_DIR + +from mtraining.attn_funcs import AttnType, overwrite_attn_implementation +from mtraining.trainer import CustomTrainer as Trainer, CustomTrainerArgs as TrainerArgs +from mtraining.model_configs import MODEL_TO_ATTN_FUNC, MODEL_ID_TO_MODEL_CLS, MODEL_ID_TO_PREFIX + +from mtraining.utils.expr_data import update_expr_data +from mtraining.utils.paths import update_expr_data_save_path +from mtraining.utils.general import freeze_model_params, load_comm_profile_data +from mtraining.utils import chunk_linear_cross_entropy, get_tokenizer, aggregate_outputs_fn, get_resume_path + +IGNORE_IDX = -100 +logger = logging.getLogger(__name__) +set_default_logger_level('INFO') + +def init_by_attn_type(model_id: str, attn_type: AttnType): + attn_dict = MODEL_TO_ATTN_FUNC[model_id] + + if attn_type == AttnType.BASELINE: + print(f"{__name__} | Using Baseline Model...") + elif attn_type == AttnType.ZIGZAG_RING: + print(f"{__name__} | Using Ring Zigzag Attention-equipped Model ...") + elif attn_type == AttnType.STRIPE_RING: + print(f"{__name__} | Using Ring Stripe Attention-equipped Model ...") + elif attn_type == AttnType.MINFER: + print(f"{__name__} | Using MInference-equipped Model ...") + elif attn_type == AttnType.MOBA: + print(f"{__name__} | Using MoBA-equipped Model ...") + elif attn_type == AttnType.XATTN: + print(f"{__name__} | Using XAttention-equipped Model ...") + else: + raise ValueError(f"Invalid attn_type: {attn_type}") + + overwrite_attn_implementation(attn_dict, attn_type) + + +class BaselineModel(torch.nn.Module): + def __init__( + self, + model_id, + config_path: str=None, + # merged_ckpt_path: str=None, + active_param_config_path: str=None + ): + super().__init__() + model_cls: PreTrainedModel = MODEL_ID_TO_MODEL_CLS[model_id] + + if not config_path: + self.model = model_cls.from_pretrained( + model_id, + attn_implementation='flash_attention_2' + ) + else: + model_config = AutoConfig.from_pretrained(config_path, trust_remote_code=True) + 
model_config._attn_implementation = 'flash_attention_2' + self.model = model_cls.from_pretrained( + model_id, + config=model_config, + ) + + if active_param_config_path: + freeze_model_params(self.model, active_param_config_path) + + print(f'{__class__.__name__} Self-Attention Class: {self.model.model.layers[0].self_attn.__class__.__name__}') + + def forward(self, samples): + with torch.autocast(device_type="cuda", dtype=self.model.config.torch_dtype): + outputs = self.model.model( + input_ids=samples['net_input']['src_tokens'], + use_cache=False, + return_dict=False, + ) + hidden_states = outputs[0] + losses = chunk_linear_cross_entropy(hidden_states, self.model.lm_head.weight, samples['target'], IGNORE_IDX, 1024) + loss = torch.sum(losses) + + return loss, loss.data, samples['ntokens'], samples['nsentences'] + +class MInferModel(BaselineModel): + def __init__( + self, + model_id, + config_path: str=None, + minfer_config: Dict={}, + **kwargs, + ): + super().__init__( + model_id=model_id, + config_path=config_path, + **kwargs, + ) + + # ---------------------------------------------- + # Ring Attention specific + granularity: int = minfer_config.pop('granularity', 128) + + # -------------------------------------------- + # MInference Setup + minfer_implementation: str = minfer_config.pop('implementation', 'default') + minfer_attn_type = minfer_config.pop('attn_type', 'minference') + minfer_config['config_path'] = os.path.join( + SPARSE_PATTERN_CONFIG_DIR, + f'{minfer_config.pop("pattern_config_name")}.json', + ) + print(f"{__name__} | MInference Pattern Config Path: {minfer_config['config_path']}") + minfer = MInference( + attn_type=minfer_attn_type, + model_name=model_id, + **minfer_config, + ) + minfer_config: MInferenceConfig = minfer.config + + # -------------------------------------------- + # We still need to attach the function object to the model + # otherwise the states of the function will be lost as nnscaler will only load the model from file + # but not call this procedure again + from mtraining.attn_funcs.minfer_func import MInferAttnFunc + Attention = self.model.model.layers[0].self_attn.__class__ + def update_module(m): + if isinstance(m, Attention): + m.minfer_attn_func = MInferAttnFunc() + m.minfer_attn_func.init_minfer_params( + config_path=minfer_config.config_path, + minfer_implementation=minfer_implementation, + granularity=granularity, + ) + self.model.apply(update_module) + +class XAttnModel(BaselineModel): + def __init__( + self, + model_id, + config_path: str=None, + xattn_params: Dict={}, + **kwargs, + ): + super().__init__( + model_id=model_id, + config_path=config_path, + **kwargs, + ) + + # -------------------------------------------- + implementation: str = xattn_params.pop('implementation', 'fa') + granularity: int = xattn_params.pop('granularity', 128) + + # -------------------------------------------- + Attention = self.model.model.layers[0].self_attn.__class__ + def update_module(m): + if isinstance(m, Attention): + m.granularity = granularity + m.xattn_params = xattn_params + m.implementation = implementation + self.model.apply(update_module) + + +class MoBAModel(BaselineModel): + def __init__( + self, + model_id, + config_path: str=None, + moba_config_dict: Dict={}, + **kwargs, + ): + super().__init__( + model_id=model_id, + config_path=config_path, + **kwargs, + ) + # -------------------------------------------- + moba_topk, moba_chunk_size = moba_config_dict["moba_topk"], moba_config_dict["moba_chunk_size"] + moba_implementation = 
moba_config_dict.get("implementation", 'default') + + # -------------------------------------------- + # We still need to attach the function object to the model + # otherwise the states of the function will be lost as nnscaler will only load the model from file + # but not call this procedure again + Attention = self.model.model.layers[0].self_attn.__class__ + def update_module(m): + if isinstance(m, Attention): + m.moba_topk = moba_topk + m.moba_chunk_size = moba_chunk_size + m.implementation = moba_implementation + self.model.apply(update_module) + +ATTN_TO_MODEL = { + AttnType.BASELINE: BaselineModel, + AttnType.STRIPE_RING: BaselineModel, + AttnType.ZIGZAG_RING: BaselineModel, + + AttnType.MINFER: MInferModel, + AttnType.MOBA: MoBAModel, + AttnType.XATTN: XAttnModel, +} + + +def load_train_attn_config(train_attn_config_path: str) -> MInferenceConfig: + if train_attn_config_path is None or train_attn_config_path.lower() == 'none': + train_attn_config_path = None + + if train_attn_config_path is None: + print(f"{__name__} | Use empty Training Attention config") + train_attn_config = {} + elif os.path.exists(train_attn_config_path): + print(f"{__name__} | Training Attention config found in {train_attn_config_path}.") + with open(train_attn_config_path, 'r') as f: + train_attn_config = yaml.safe_load(f) + print('-' * 20) + print("Training Attention Config:") + print(train_attn_config) + print('-' * 20) + else: + raise FileNotFoundError(f"Training Attention config {train_attn_config_path} not found. Exit.") + return train_attn_config + +def build_model_args(args, train_attn_config: MInferenceConfig) -> Dict: + model_args = { + 'model_id': args.model_id, + 'config_path': args.model_config_path, + "active_param_config_path": args.active_param_config_path, + } + if args.attn_type == AttnType.MINFER: + model_args['minfer_config'] = train_attn_config + elif args.attn_type == AttnType.XATTN: + model_args['xattn_params'] = train_attn_config + elif args.attn_type == AttnType.MOBA: + model_args['moba_config_dict'] = train_attn_config + + return model_args + + +def main(args): + update_expr_data_save_path(args.ckpt_save_dir, args.compile_save_path) + update_expr_data(args) + + local_rank = int(os.environ["LOCAL_RANK"]) + if local_rank == 0: load_comm_profile_data(args) + + init_by_attn_type(args.model_id, args.attn_type) + train_attn_config = load_train_attn_config(args.train_attn_config_path) + + # --------------------------------- + # Compute config + if args.run_mode == 'compile': + if args.runtime_ngpus is None: + raise ValueError('runtime_ngpus must be specified in compile mode') + runtime_ngpus = args.runtime_ngpus + elif args.run_mode == 'run': + world_size = int(os.getenv('WORLD_SIZE')) + if args.runtime_ngpus is None: + runtime_ngpus = world_size + else: + if args.runtime_ngpus != world_size: + raise ValueError(f'runtime_ngpus ({args.runtime_ngpus}) must match the number of GPUs in run mode ({world_size})') + runtime_ngpus = args.runtime_ngpus + + if runtime_ngpus % args.plan_ngpus != 0: + raise ValueError('runtime_ngpus must be a multiple of plan_ngpus') + + scaling_factor: int = runtime_ngpus // args.plan_ngpus + grad_accu_step: int = args.global_batch_size // (args.micro_batch_size * scaling_factor) + + model_prefix = MODEL_ID_TO_PREFIX[args.model_id] + pas_config = { + 'recompute_modules': f'{model_prefix}DecoderLayer', + } + if args.mem_constraint > 0: pas_config['mem_constraint'] = args.mem_constraint + compute_config = ComputeConfig( + plan_ngpus=args.plan_ngpus, + 
trace_strategy=args.trace_strategy, + runtime_ngpus=runtime_ngpus, + constant_folding=True, + use_zero=True, + use_end2end=True, + pas_config=pas_config, + ) + + # --------------------------------- + ## Setup Dataset ## + dataset = load_from_disk(args.dataset_path) + tokenizer = get_tokenizer(args.model_id) + data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) + + def collate_fn(samples): + if len(samples) == 0: + return {} + + mini_batch = data_collator(samples) + _mini_batch = {} + + src_tokens = mini_batch.pop('input_ids') + seq_len = src_tokens.size(-1) + _mini_batch['src_tokens'] = src_tokens + + shift_labels = mini_batch['labels'][..., 1:] + _mini_batch['labels'] = torch.nn.functional.pad(shift_labels, (0, 1), 'constant', IGNORE_IDX).contiguous() + + return { + "nsentences": len(samples), + "ntokens": len(samples) * seq_len, + "net_input": _mini_batch, + "target": _mini_batch.pop('labels'), + } + dataset_config = DatasetConfig( + type=(lambda split: dataset), + train_args={'split': 'train'}, + ) + dataloader_config = DataloaderConfig( + train_args={ + 'collate_fn': collate_fn, + 'drop_last': True, + }, + ) + sampler_config = DatasetSamplerConfig( + train_args={ + 'shuffle': True, + 'seed': args.seed, + }, + ) + + # --------------------------------- + # Model Config + model_args = build_model_args(args, train_attn_config) + model_config = ModelConfig( + type=ATTN_TO_MODEL[args.attn_type], + args=model_args, + ) + + # --------------------------------- + # optimizer hyperparameters are from YaRN + optimizer_config = OptimizerConfig( + type=MixedPrecisionAdamW, + args={ + 'lr': 2e-5, + 'betas': (0.9, 0.95), + 'weight_decay': 0.0, + 'fused': True + }, + clip_gnorm=1.0, + loss_reduction='sum', + grad_reduction='per-token-mean', + aggregate_outputs_fn=aggregate_outputs_fn, + ) + + + # --------------------------------- + # Checkpoint Config + checkpoint_config = CheckpointConfig( + save_dir=args.ckpt_save_dir if args.ckpt_save_dir else f'./checkpoints_{args.name}', + every_n_epochs=args.ckpt_n_epoch, + every_n_train_steps=args.ckpt_n_step, + save_type='deduped', + resume_from=args.resume_from, + ) + + # --------------------------------- + # Log Config + log_config = LogConfig( + type=TensorBoardLogger, + args={ + 'name': args.name, + 'root_dir': args.tf_log_dir or f'./runs_{args.name}', + }, + ) + + # --------------------------------- + trainer_args = TrainerArgs( + global_batch_size=args.global_batch_size, + micro_batch_size=args.micro_batch_size, + grad_accumulation_steps=grad_accu_step, + + pas_policy='autodist', + precision='bf16', + seed=args.seed, + gen_reuse=args.reuse_type, + + gen_savedir=args.compile_save_path, + instance_name=args.name, + run_mode=args.run_mode, + max_epochs=args.n_epochs, + max_train_steps=args.n_iter, + enable_progress_bar=not args.disable_progressbar, + + compute_config=compute_config, + model=model_config, + optimizer=optimizer_config, + dataset=dataset_config, + dataloader=dataloader_config, + checkpoint=checkpoint_config, + log=[log_config], + broadcast_strategy='all', + dataset_sampler=sampler_config, + + transfer_config={ + "transfer_config_dir": args.transfer_config_dir, + "transfer_force": args.transfer_force, + }, + merged_ckpt_path=args.resume_merged_ckpt, + ) + + trainer = Trainer(train_args=trainer_args) + trainer.run() + +def print_args(args: argparse.Namespace): + print("=" * 80) + print(f"Start Experiment:\t{args.name}") + print(f"Seed:\t{args.seed}") + print(f"Reuse Type:\t{args.reuse_type}") + print(f"Run 
Mode:\t{args.run_mode}") + print(f"Total number of GPUs:\t{args.runtime_ngpus}") + print(f"GPU unit size:\t{args.plan_ngpus}") + print(f"Model ID:\t{args.model_id}") + + print('-' * 40) + if args.n_iter: + print(f"Number of Iterations:\t{args.n_iter} (number of tokens: {args.n_iter * args.global_batch_size * args.seq_len})") + else: + print(f"Number of Epochs:\t{args.n_epochs}") + + + print(f'Global Batch Size:\t{args.global_batch_size}') + print(f'Micro Batch Size:\t{args.micro_batch_size}') + + scaling_factor = args.runtime_ngpus // args.plan_ngpus + grad_accu_step = args.global_batch_size // (args.micro_batch_size * scaling_factor) + print(f"Scaling Factor (INFERRED):\t{scaling_factor}") + print(f"Gradient Accumulation Steps (INFERRED):\t{grad_accu_step}") + + print('-' * 40) + print(f"Model Config Path:\t{args.model_config_path}") + print(f"Dataset path:\t{args.dataset_path}") + print(f'Training Attention Config Path:\t{args.train_attn_config_path}') + print(f"Compile Save Path:\t{args.compile_save_path}") + print(f"Tensorboard Log Path:\t{args.tf_log_dir}") + print(f"Checkpoint Save Path:\t{args.ckpt_save_dir}") + print(f"Resume from Checkpoint:\t{args.check_resume}") + print(f"Path to the checkpoint to resume from:\t{args.resume_from}") + print(f"Path to the merged checkpoint to resume from:\t{args.resume_merged_ckpt}") + + print(f"Trace Strategy:\t{args.trace_strategy}") + if args.transfer_config_dir: + print(f"Transfer Configs from another experiment:\t{args.transfer_config_dir}") + print(f"Force Transfer Configs:\t{args.transfer_force}") + + if args.active_param_config_path: + print(f"Active Param Config Path:\t{args.active_param_config_path}") + + if args.ckpt_n_step: + print(f"Checkpoint Save Every {args.ckpt_n_step} Steps") + else: + print(f"Checkpoint Save Every {args.ckpt_n_epoch} Epochs") + print("=" * 80, flush=True) + +if __name__ == '__main__': + ## Parse Args ## + parser = argparse.ArgumentParser() + parser.add_argument('--seed', type=int, default=0, help='random seed') + parser.add_argument('--name', type=str, default='phi-grad', help='name of the experiment') + parser.add_argument('--seq_len', type=int, default=131072, help='sequence length') + parser.add_argument('--attn_type', type=str, default=AttnType.BASELINE, choices=AttnType.__dict__.values(), help='minference type') + parser.add_argument('--reuse_type', type=str, default='match', choices=['match', 'override', 'moo', 'graph'], help='reuse type') + parser.add_argument('--run_mode', type=str, default='run', choices=['run', 'compile'], help='run or compile') + parser.add_argument('--trace_strategy', type=str, default='cuda_run_cpu_offload', + choices=['cpu', 'cuda', 'meta', 'cuda_run_cpu_offload', 'reuse_cache'], + help='trace strategy') + parser.add_argument('--plan_ngpus', type=int, required=True, help='specify the scale unit size') + parser.add_argument('--runtime_ngpus', type=int, required=True, help='specify the number of GPUs to use') + + parser.add_argument('--n_iter', type=int, default=0, help='Number of iterations') + parser.add_argument('--n_epochs', type=int, default=0, help='Number of epochs') + parser.add_argument('--nB_tokens', type=int, default=0, help='Number of tokens (in B) to process') + parser.add_argument('--global_batch_size', type=int, default=4, help='global batch size') + parser.add_argument('--micro_batch_size', type=int, default=1, help='micro batch size') + parser.add_argument('--mem_constraint', type=int, default=0, help='memory constraint') + + parser.add_argument('--model_id', 
type=str, default='microsoft/Phi-3-mini-4k-instruct', help='transformers model id') + parser.add_argument('--model_config_path', type=str, default=None, help='path to the model config') + + parser.add_argument('--train_attn_config_path', type=str, default=None, help='Name of Minference config file') + parser.add_argument('--compile_save_path', type=str, default='./.nnscaler', help='path to save compiled code') + + parser.add_argument('--tf_log_dir', type=str, default=None, help='path to save tensorboard logs') + parser.add_argument('--dataset_path', type=str, default=None, help='path to the dataset') + parser.add_argument('--check_resume', action='store_true', help='whether to resume from checkpoint') + parser.add_argument('--resume_from', type=str, default=None, help='path to the checkpoint to resume from') + parser.add_argument('--resume_merged_ckpt', type=str, default=None, help='path (dir) to the merged checkpoint to resume from') + + parser.add_argument('--ckpt_save_dir', type=str, default=None, help='path to save checkpoints') + parser.add_argument('--ckpt_n_epoch', type=int, default=1, help='save checkpoint every n epochs') + parser.add_argument('--ckpt_n_step', type=int, default=0, help='save checkpoint every n steps') + parser.add_argument('--transfer_config_dir', type=str, default="none", help='path to transfer configs from another experiment') + parser.add_argument('--transfer_force', action='store_true', help='force transfer configs') + parser.add_argument('--active_param_config_path', type=str, default=None, help='path to the active param list') + + parser.add_argument('-p', '--disable_progressbar', action='store_true', help='transformers model id',) + + args = parser.parse_args() + + # ------------------------------------------------- + # Preprocessing args + if args.ckpt_n_epoch <= 0: args.ckpt_n_epoch = None + if args.ckpt_n_step <= 0: args.ckpt_n_step = None + + if args.nB_tokens > 0: + args.n_iter = args.nB_tokens * 1e9 // args.global_batch_size // args.seq_len + 1 + args.n_epochs = 0 + if args.n_iter <= 0: args.n_iter = None + if args.n_epochs <= 0: args.n_epochs = None + + if args.transfer_config_dir.lower() == 'none': args.transfer_config_dir = None + + # set a new field of args 'args.orig_resume_from' to store the original resume_from value + args.orig_resume_from = args.resume_from + args.resume_from = get_resume_path( + args.check_resume, args.resume_from, args.ckpt_save_dir, args.runtime_ngpus + ) + + print_args(args) + main(args) \ No newline at end of file diff --git a/mtraining/trainer.py b/mtraining/trainer.py new file mode 100644 index 0000000..ed03cc1 --- /dev/null +++ b/mtraining/trainer.py @@ -0,0 +1,517 @@ +import os +import time +import copy +import torch +import logging +import pandas as pd +from tqdm import tqdm +from dataclasses import dataclass +from datetime import timedelta +from collections import defaultdict +from dataclasses import asdict +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Dict, List, Optional, Union, Optional, Callable + +import torch.distributed +from torch.profiler import profile, ProfilerActivity + +import nnscaler +from nnscaler.utils import accum_mode +from nnscaler.runtime.utils import microbatches +from nnscaler.runtime.module import ParallelModule +from nnscaler.utils import is_running_distributed +from nnscaler.cli.trainer import ( + Trainer, _StepStat, TrainerArgs, TrainStatus, AggregatedTrainHook, TrainHook +) + +from mtraining.utils.paths import EXPR_DATA_SAVE_PATH +from 
mtraining.utils.general import fix_model_state_dict +from mtraining.utils.custom_parallel import parallelize as custom_parallelize + +logger = logging.getLogger(__name__) + +@dataclass +class CustomTrainerArgs(TrainerArgs): + transfer_config: Optional[Dict[str, Any]] = None + merged_ckpt_path: Optional[str] = None + +ITERATOR_COUNTER = defaultdict(int) +def get_iter_cnt(rank: int): + global ITERATOR_COUNTER + return ITERATOR_COUNTER.get(rank, 0) + +ITER_BATCH_IDX_DICT = {} +def get_iter_batch_idx(rank: int, iter_cnt: int): + global ITER_BATCH_IDX_DICT + return ITER_BATCH_IDX_DICT.get(rank, {}).get(iter_cnt, 0) + +def custom_train_step( + model: ParallelModule, + rank: int, iter_idx: int, + samples: List[Any], + is_dummy_batch: Optional[List[bool]] = None, + scale_fn: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, +) -> List[Any]: + """ + The training step function. It should be called in the training loop. + Please note: + 1. This function is only supported in end2end mode. + 2. Gradient accumulation is done inside this function. + You shouldn't do gradient accumulation outside this function, + because the gradients will be cleared in the beginning of this function + Args: + samples (List[Any]): a list of samples. + if pipeline is used, it must have the same length as configured to pas policy + is_dummy_batch (Optional[List[bool]]): indicates whether the each micro-batch is dummya + scale_fn (Optional[Callable[[torch.Tensor], torch.Tensor]]): the function to scale the loss + Results: + List[Any]: a list of outputs for each sample + """ + global ITER_BATCH_IDX_DICT + model._warn_uninitialized_non_persistent_buffers(raise_error=True) + + if not model.compute_config.use_end2end: + raise RuntimeError("train_step() is only supported in end2end mode") + if is_dummy_batch and len(samples) != len(is_dummy_batch): + raise ValueError("The length of samples and is_dummy_batch should be the same") + + model._scale_loss(is_dummy_batch, scale_fn) + + # sync_grad will be done in _train_step + # so we never need to call it manually + model._sync_grad_required = False + sample_count = len(samples) + dataloader = microbatches(samples, cycle=False) + + if model.use_scheduler: + if len(samples) != model.nmicros_per_scheduler_step: + raise ValueError(f"Expected {model.nmicros_per_scheduler_step} samples, but got {sample_count}") + # only one step, so begin/end are both True + with accum_mode(begin=True, end=True): + return model._train_step(dataloader), None + else: + outputs = [] + latencies = [] + for idx in range(sample_count): + ITER_BATCH_IDX_DICT[rank][iter_idx] = idx + + sample_start_time = time.perf_counter() + with accum_mode(begin=(idx==0), end=(idx==sample_count-1)): + # loss, loss.data, samples['ntokens'], samples['nsentences'] + output = model._train_step(dataloader) + sample_time = time.perf_counter() - sample_start_time + latencies.append(sample_time) + + num_tokens = output[2] + if rank == 0: + print( + f"| {__name__} | rank={rank} | iter_idx={iter_idx}, batch_idx={idx}, loss={output[1] / num_tokens:.4f}," + f" num_tokens={num_tokens}, latency={sample_time:.4f}s" + ) + + outputs.append(output) + return outputs, latencies + +class CustomTrainer(Trainer): + def __init__( + self, + argv: Optional[List[str]] = None, + *, + train_args: Optional[Union[Dict[str, Any], CustomTrainerArgs]] = None, + ): + """ + Custom trainer with an additional parameter. + + Args: + argv (Optional[List[str]]): Command line arguments. If not specified, sys.argv[1:] will be used. 
+ train_args: A dict used to construct TrainerArgs or a TrainerArgs object itself. + additional_param (Optional[Any]): Additional parameter for custom functionality. + """ + # Call the parent class's initializer with the existing parameters + super().__init__(argv=argv, train_args=train_args) + self.train_args: CustomTrainerArgs + + torch.distributed.init_process_group( + backend='nccl', + timeout=timedelta(hours=2), + ) + self.train_step_func = custom_train_step + + def _train_epoch(self, epoch): + VAL_STATUS_NO = 0 # not validated or saved + VAL_STATUS_VAL = 1 # validated but not saved + VAL_STATUS_SAVE = 2 # validated and saved + has_validated = VAL_STATUS_NO # 3 states + + resume_from_idx = self.train_status.finished_train_steps % self.total_train_steps_per_epoch + data_iter = enumerate(self._global_batch_iterator(num_skip_first=resume_from_idx)) + + max_epoch = self.max_train_steps // self.total_train_steps_per_epoch + if self.max_train_steps % self.total_train_steps_per_epoch != 0: + max_epoch += 1 + ndigits = len(str(max_epoch)) + epoch_format = f"0{ndigits}d" + epoch_desc = f'Epoch {format(epoch, epoch_format)}' + + if self.rank == 0: + progress = tqdm( + None, + total=self.total_train_steps_per_epoch, + initial=resume_from_idx, + desc=epoch_desc, + disable=not self.train_args.enable_progress_bar, + ) + else: + progress = None + + + # --------------------------------------------------------------------------------- + train_info_save_path = os.path.join(EXPR_DATA_SAVE_PATH['base_path'], 'train_info', f"epoch_{epoch}.log") + os.makedirs(os.path.dirname(train_info_save_path), exist_ok=True) + if self.rank == 0: + # Check whether the file already exists + # If it exists, assume existing log file has name 'epoch__.log' ('epoch_.log` is assumed to have num 0) + # Find the greatest for the current epoch and increment it to build the new file name + existing_files = [f for f in os.listdir(os.path.dirname(train_info_save_path)) \ + if f.startswith(f'epoch_{epoch}_') or f.startswith(f'epoch_{epoch}.log')] + if existing_files: + # Extract the numbers from the filenames + existing_nums = [int(f.split('_')[-1].split('.')[0]) for f in existing_files if f.startswith(f'epoch_{epoch}_')] + if not existing_nums: + existing_nums = [0] + new_num = max(existing_nums) + 1 + train_info_save_path = os.path.join(os.path.dirname(train_info_save_path), f'epoch_{epoch}_{new_num}.log') + else: + # If no existing files, use the original path + train_info_save_path = os.path.join(os.path.dirname(train_info_save_path), f'epoch_{epoch}.log') + with open(train_info_save_path, 'w') as f: f.write('') + + step_stat: Optional[_StepStat] = None + num_tokens_trained = 0 + for i, batches in data_iter: + idx = i + resume_from_idx + + global ITERATOR_COUNTER, ITER_BATCH_IDX_DICT + ITERATOR_COUNTER[self.rank] = idx + ITER_BATCH_IDX_DICT[self.rank] = {idx: 0} + if self.rank == 0: + progress.update(1) + step_start_at = time.perf_counter() + step_stat = _StepStat() + step_metrics = {} + has_validated = VAL_STATUS_NO + num_batches = len(batches) + batches, is_dummy_batch = self._fix_batches(batches) + + self.model.train() + + self.hook.before_zero_grad(self) + self.optimizer.zero_grad() + self.hook.after_zero_grad(self) + + self.hook.on_train_step_start(self, batches[:num_batches], idx) + losses, latencies = self.train_step_func(self.model, self.rank, idx, batches, is_dummy_batch) + self.hook.on_train_step_end(self, losses[:num_batches], batches[:num_batches], idx) + + aggregate_outputs = 
self.train_args.resolved_aggregate_outputs_fn or self.aggregate_outputs + aggregated_outputs = aggregate_outputs(losses[:num_batches], self.sync_group) + if self.train_args.optimizer.loss_reduction == 'mean': + loss = aggregated_outputs.loss_sum / aggregated_outputs.num_batches + else: + loss = aggregated_outputs.loss_sum + step_stat.train_loss = loss + num_tokens_trained += aggregated_outputs.num_tokens + self.hook.after_aggregate_train_step_outputs(self, aggregated_outputs, loss, idx) + + self.hook.before_sync_grad(self) + self.optimizer.sync_shard_grad() + self.hook.after_sync_grad(self) + + # scale gradients + multiplier = self.train_args.scaling_factor + if self.train_args.optimizer.grad_reduction == 'sum': + # do nothing. `multiplier` is already correct + pass + elif self.train_args.optimizer.grad_reduction == 'mean': + if not aggregated_outputs.num_batches: + raise RuntimeError("`aggregate_outputs` doesn't set `num_batches` field") + multiplier /= aggregated_outputs.num_batches + else: + assert self.train_args.optimizer.grad_reduction == 'per-token-mean' + if not aggregated_outputs.num_tokens: + raise RuntimeError("`aggregate_outputs` doesn't set `num_tokens` field") + multiplier /= aggregated_outputs.num_tokens + self.optimizer.scale_grads(multiplier) + + # clip gradients + self.hook.before_gnorm_clip(self) + if self.train_args.optimizer.clip_gnorm: + step_stat.gnorm = self.optimizer.clip_gnorm(self.train_args.optimizer.clip_gnorm) + else: + step_stat.gnorm = self.optimizer.clip_gnorm() + self.hook.after_gnorm_clip(self, step_stat.gnorm) + step_stat.gnorm = step_stat.gnorm.item() + + # update parameters + step_stat.lr = self.optimizer.param_groups[0]['lr'] + self.hook.before_optimizer_step(self) + self.optimizer.step() + self.hook.after_optimizer_step(self) + if self.lr_scheduler and self.train_args.lr_scheduler.interval == 'step': + self.lr_scheduler.step() + + self.train_status.finished_train_steps += 1 + self._log_mem_stats(tag='train') + step_metrics = {k:v for k, v in asdict(step_stat).items() if v is not None} + step_metrics['train_wall'] = time.perf_counter() - step_start_at + step_metrics['num_tokens_processed'] = num_tokens_trained + self.log_metrics(step_metrics, tag='train') + if self.rank == 0: + progress.set_postfix(step_metrics) + formatted_metrics = self._format_metrics(epoch_desc, idx + 1, step_metrics) + with open(train_info_save_path, 'a') as f: + f.write(f"{formatted_metrics}\n") + + if self.train_args.enable_log_progress \ + and self.train_status.finished_train_steps % self.train_args.log_progress_every_n_train_steps == 0: + + logger.info(formatted_metrics) + step_metrics = {} + + # validate and save checkpoint + if self.train_args.checkpoint.every_n_train_steps and \ + self.train_status.finished_train_steps % self.train_args.checkpoint.every_n_train_steps == 0: + self._validate_and_save(step_stat) + has_validated = VAL_STATUS_SAVE + + # max_train_steps is reached + if self.train_status.finished_train_steps >= self.max_train_steps: + if step_metrics and self.train_args.enable_log_progress: + logger.info(self._format_metrics(epoch_desc, idx + 1, step_metrics)) + step_metrics = {} + if not has_validated: + self._validate_and_save(step_stat) + has_validated = VAL_STATUS_SAVE + if self.rank == 0: + # disable refresh the progress bar to avoid redundant progress bar + progress.leave = False + progress.close() + break + + if not has_validated and self.train_args.val_every_n_train_steps and \ + self.train_status.finished_train_steps % 
self.train_args.val_every_n_train_steps == 0: + self._validate(step_stat) + has_validated = VAL_STATUS_VAL + + # time.sleep(1) + else: + # Do per-epoch operations here. + # if the loop exits with `break` (max_train_steps is reached) + # those operations have done in the loop + if step_stat is None: + return # no train step runs. Nothing to do. + if has_validated < VAL_STATUS_SAVE \ + and self.train_args.checkpoint.every_n_epochs \ + and (epoch + 1) % self.train_args.checkpoint.every_n_epochs == 0: + self._validate_and_save(step_stat) + has_validated = VAL_STATUS_SAVE + if not has_validated and self.train_args.val_every_n_epochs \ + and (epoch + 1) % self.train_args.val_every_n_epochs == 0: + self._validate(step_stat) + has_validated = VAL_STATUS_VAL + + def _setup(self): + self.train_args.init_env(self) + compile_only = self.train_args.compile_mode + + if is_running_distributed(): + nnscaler.init() + if torch.distributed.get_rank() == 0: + logging.getLogger().setLevel(logging.INFO) + else: + logging.getLogger().setLevel(logging.WARNING) + + def _create_model(): + model = self.train_args.create_model() + if self.train_args.param_dtype == self.train_args.buffer_dtype: + if self.train_args.param_dtype is not None: + model = model.to(self.train_args.param_dtype) + else: + # separate param and buffer dtype + # TODO: a little hacky. A better way? + # 3 kinds of tensors are converted in Module._apply: + # model parameters, its grad, and buffer + # param_dtype controls the first two, (but grad is `None` here) + # and buffer_dtype controls the last one + buf_ids = { id(buf) for buf in model.buffers(recurse=True) } + if self.train_args.param_dtype is not None: + model._apply( + lambda t: t.to(self.train_args.param_dtype) + if t.is_floating_point() and id(t) not in buf_ids + else t) + if self.train_args.buffer_dtype is not None: + model._apply( + lambda t: t.to(self.train_args.buffer_dtype) + if t.is_floating_point() and id(t) in buf_ids + else t) + if self.train_args.tracing_from_weights: + model.load_state_dict(torch.load(self.train_args.tracing_from_weights)) + return model + + # create dataset and dataloader + for stage in ['train', 'val', 'test']: + self.dataset[stage] = self.train_args.create_dataset(stage) + + # load a dummy input from training dataset + self.dummy_input = self._load_dummy_input() + self.dummy_input = self._fix_input(self.dummy_input) + + for stage in ['train', 'val', 'test']: + self.dataloader[stage] = self.train_args.create_dataloader(stage, self.dataset[stage]) + if self.dataloader[stage] is not None \ + and not self.dataloader[stage].drop_last \ + and len(self.dataset[stage]) % (self.train_args.micro_batch_size * self.train_args.scaling_factor) != 0: + warnings.warn( + f"Length of {stage} dataset ({len(self.dataset[stage])}) " + f"is not multiple of micro_batch_size * scale_factor ({self.train_args.micro_batch_size * self.train_args.scaling_factor}). " + f"In this case, the train_step for the last batch of samples can fail! " + f"You can specify `drop_last=True` in DataLoader to fix this problem." 
+ ) + + # setup compute config + compute_config = copy.deepcopy(self.train_args.compute_config) + compute_config.pas_config['__pas_name'] = self.train_args.pas_policy + # autodist configs + compute_config.pas_config['update_freq'] = self.train_args.update_freq + compute_config.pas_config['use_bf16'] = self.train_args.param_dtype == torch.bfloat16 + compute_config.pas_config['use_fp16'] = self.train_args.param_dtype == torch.float16 + + compute_config.user_config['__from_trainer_args'] = { + 'mbs': self.train_args.micro_batch_size, + 'gbs': self.train_args.global_batch_size, + 'precision': self.train_args.precision, + 'model_args': self.train_args.model.args, + } + + # parallalize model + pmodel_class = custom_parallelize( + self.train_args.model_type, + self._create_dummy_forward_args(), + self.train_args.resolved_pas_policy, + compute_config, + module_fn=_create_model, + gen_savedir=self.train_args.gen_savedir, + reuse=self.train_args.gen_reuse, + instance_name=self.train_args.instance_name, + broadcast_strategy=self.train_args.broadcast_strategy, + load_module=not compile_only, + transfer_config=self.train_args.transfer_config, + ) + if compile_only: + return + + torch.distributed.barrier() + self.rank = torch.distributed.get_rank() + + self.total_train_steps_per_epoch = len(self.dataloader['train']) // self.train_args.update_freq + if len(self.dataloader['train']) % self.train_args.update_freq != 0: + self.total_train_steps_per_epoch += 1 # will add extra dummy batches + + if self.train_args.max_epochs and self.train_args.max_train_steps: + self.max_train_steps = min( + self.total_train_steps_per_epoch * self.train_args.max_epochs, + self.train_args.max_train_steps + ) + elif self.train_args.max_train_steps: + self.max_train_steps = self.train_args.max_train_steps + else: + assert self.train_args.max_epochs, "max_epochs or max_train_steps should be specified" + self.max_train_steps = self.total_train_steps_per_epoch * self.train_args.max_epochs + + _, self.sync_group = self.train_args.compute_config.get_sync_group() + self.model = pmodel_class() + self.model.cuda() + self.optimizer = self.train_args.create_parallel_optimizer(self.model) + + def reducer_pre_hook(reducer, grad): + grad.div_(self.train_args.scaling_factor) + self.optimizer.register_reducer_pre_hook(reducer_pre_hook) + self.lr_scheduler = self.train_args.create_lr_scheduler(self.optimizer) + self.loggers = self.train_args.create_loggers() + + supported_hook_components = [ + self.model, + self.optimizer, + self.lr_scheduler, + ] + self.hook = AggregatedTrainHook( + [x for x in supported_hook_components if isinstance(x, TrainHook)] + + [self.train_args.create_hook()] + ) + + self._log_config(self.train_args.to_dict()) + self._load_checkpoint() + + if self.train_args.merged_ckpt_path is not None: + print(f"Rank {self.rank} | {__name__} | loading merged checkpoint from {self.train_args.merged_ckpt_path}") + merged_ckpt_path = os.path.join(self.train_args.merged_ckpt_path, "pytorch_model.bin") + model_state_dict = torch.load(merged_ckpt_path, map_location='cpu') + + first_key = list(model_state_dict.keys())[0] + if len(first_key.split('.')) == 1: + # For Ring-Attention models, the merged checkpoint is directly copied from one of the shards and has different key names. 
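+                # (Assumed behavior) fix_model_state_dict translates these shard-local parameter names back into
+                # fully-qualified module-path keys, so that the prefix handling and load_merged_state_dict below
+                # can match them against self.model.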
+ model_state_dict = fix_model_state_dict(self.model, model_state_dict) + + first_key = list(model_state_dict.keys())[0] + if 'model.model' not in first_key: + # Our merging logic also removes the prefix `model.` from the state dict keys when saving + model_state_dict = {'model.' + k: v for k, v in model_state_dict.items()} + if self.rank % int(os.getenv("GPU_PER_NODE", "8")) == 0: + print(f"Rank {self.rank} | {__name__} | loaded model state dict.keys(): {model_state_dict.keys()}") + + # in our merge program, `model` is poped out and we directly pass the model_state_dict instead of model_state_dict['model'] + nnscaler.load_merged_state_dict( + self.model, model_state_dict, + self.optimizer, None, + ) + + self.hook.after_setup(self) + + def _load_checkpoint(self): + resume_from = self.train_args.checkpoint.get_resume_checkpoint_dir() + if not resume_from: + return + logger.info(f"Resuming from {resume_from}") + if resume_from.is_file(): + resume_from = resume_from # when we load from merged checkpoint + else: + resume_from = resume_from / f'{self.rank}.ckpt' + state_dict = torch.load(resume_from, map_location='cpu') + self.hook.on_load_checkpoint(self, state_dict) + ckpt_save_type = state_dict['train_args']['checkpoint']['save_type'] + + if ckpt_save_type == 'merged': # it is a merged state dict + nnscaler.load_merged_state_dict( + self.model, state_dict['model'], + self.optimizer, state_dict['optimizer'], + ) + elif ckpt_save_type == 'sharded': + nnscaler.load_sharded_state_dict( + self.model, state_dict['model'], + self.optimizer, state_dict['optimizer'], + ) + elif ckpt_save_type == 'deduped': + nnscaler.load_deduped_state_dict( + self.model, state_dict['model'], + self.optimizer, state_dict['optimizer'], + ) + else: + raise ValueError(f"Unknown checkpoint type: {ckpt_save_type}") + + if 'lr_scheduler' in state_dict: + if state_dict['lr_scheduler'] and not self.lr_scheduler: + raise ValueError("lr_scheduler is not set in the current trainer") + if self.lr_scheduler: + self.lr_scheduler.load_state_dict(state_dict['lr_scheduler']) + self.train_status = TrainStatus(**state_dict['train_status']) + self.rng_states_from_resume = state_dict.get('rng_states') # resumed in _global_batch_iterator() \ No newline at end of file diff --git a/mtraining/utils/__init__.py b/mtraining/utils/__init__.py new file mode 100644 index 0000000..fc13228 --- /dev/null +++ b/mtraining/utils/__init__.py @@ -0,0 +1,2 @@ +from .loss import chunk_linear_cross_entropy, linear_cross_entropy +from .general import * \ No newline at end of file diff --git a/mtraining/utils/custom_parallel.py b/mtraining/utils/custom_parallel.py new file mode 100644 index 0000000..cd3d4bf --- /dev/null +++ b/mtraining/utils/custom_parallel.py @@ -0,0 +1,467 @@ +import os +import torch +import shutil +import inspect +import torch.distributed as dist + +from pathlib import Path +from typing import Callable, Any, Dict, Optional, Tuple, Type, Union, TypeVar, List, Set, Literal + +from nnscaler.graph import IRGraph +from nnscaler.graph.parser import FxModuleParser +from nnscaler.runtime.device import DeviceGroup +from nnscaler.runtime.module import AttrMeta, CubeModule, ParallelModule, OriginModuleMetadata, ExtraState + +from nnscaler.parallel import ( + ComputeConfig, ReuseType, BroadcastGenFilesStrategy, RegenStatus, + _prepare_namespace, _compile_flags, _clean_files, _is_any_gencode_loaded, _gencode, + _broadcast_gen_files, _load_parallel_module_class, + _GENCODE_FILE_TEMPLATE, _GRAPH_DUMP_FILE, _FORWARD_ARGS_DUMP_FILE, _PREDEFINED_POLICIES + 
+) + +import logging +logger = logging.getLogger(__name__) + +def compute_config_safe_equals(a: Optional['ComputeConfig'], b: Optional['ComputeConfig']) -> bool: + """ + Return False if a and b are from incompatible version of ComputeConfig + This is only for backward compatibility, and will be removed in future + and can use `==` when we save dict version of ComputeConfig to file. + """ + res = True + try: + for key in a.__dataclass_fields__: + if getattr(a, key) != getattr(b, key): + print(f"{key} not equal: {getattr(a, key)} (old_config) != {getattr(b, key)} (current_config)") + + if key == "user_config": + continue + else: + print(f"compute_config_safe_equals | {key} not equal: {getattr(a, key)} (old_config) != {getattr(b, key)} (current_config)") + res = False + return res + except AttributeError: + logger.warning("Failed to compare ComputeConfig. They are incompatible.") + return False + +GRAPH_CONFIG_FIELDS = ['constant_folding', 'user_config', 'inference_only', 'end2end_mode', 'trace_strategy'] +def graph_config_equals(a: Dict[str, Any], b: Dict[str, Any]) -> bool: + """ + Return False if a and b are from incompatible version of ComputeConfig + This is only for backward compatibility, and will be removed in future + and can use `==` when we save dict version of ComputeConfig to file. + """ + res = True + try: + for key in GRAPH_CONFIG_FIELDS: + if getattr(a, key) != getattr(b, key): + print(f"graph_config_equals | {key} not equal: {getattr(a, key)} (old_config) != {getattr(b, key)} (current_config)") + if key != "user_config": + res = False + return res + except AttributeError: + logger.warning("Failed to compare GraphConfig. They are incompatible.") + return False + + +TRACE_FILE_EXTENSIONS = [ + FxModuleParser.ATTR_CONTENT_FILE_0, # init weights file(fullmodel.pt.*), + FxModuleParser.ATTR_MAP_FILE, # param name mapping (dist_param_map.pt)\ + _GRAPH_DUMP_FILE, # graph dump (graph.ckp), + _FORWARD_ARGS_DUMP_FILE, # forward args dump(forward_args.pkl), + ParallelModule.ORIGIN_MODULE_METADATA_FILE # origin module metadata (origin_module_metadata.pt), +] +def transfer_metadata(out_dir, transfer_config: Dict[str, Any]): + transfer_config_dir, transfer_force = transfer_config['transfer_config_dir'], transfer_config['transfer_force'] + if not os.path.exists(transfer_config_dir): + # if transfer_config_dir is not set, use the default directory + transfer_config_dir = transfer_config_dir.replace("compile_config/", "compile_config/rank_0/") + assert os.path.exists(transfer_config_dir), f"Source directory {transfer_config_dir} for transferring does not exist" + + # transfer files in src_dir with postfix not being .py by executing `cp` + print(f"{__name__} | Transfering files from {transfer_config_dir} to {out_dir} (local_rank={os.getenv('LOCAL_RANK')})") + for file in os.listdir(transfer_config_dir): + if file in TRACE_FILE_EXTENSIONS or file.startswith(FxModuleParser.ATTR_CONTENT_FILE_STEM): + src_file = os.path.join(transfer_config_dir, file) + dst_file = os.path.join(out_dir, file) + + print(f"{__name__} | Copying {src_file} to {dst_file} (local_rank={os.getenv('LOCAL_RANK')})" ) + if not os.path.exists(dst_file) or transfer_force: + shutil.copyfile(src_file, dst_file) + + if not os.path.exists(dst_file): + raise FileNotFoundError(f"{__name__} | Copy failed ({dst_file} does not exist after copying)") + + # Create a file 'transferred.sign' to indicate that the transfer is done + with open(os.path.join(out_dir, "transferred.sign"), 'w') as f: + f.write("Transferred from " + 
transfer_config_dir) + +def _prepare_and_check_reusable( + gen_savedir: str, + module_or_module_class: Union[Type[torch.nn.Module], torch.nn.Module], + compute_config: ComputeConfig, + instance_name: Optional[str] = None, + reuse: ReuseType = ReuseType.MATCH, + transfer_config: Dict[str, Any] = None, + ) -> Tuple[str, bool, bool]: + """ + Prepare the output directory for code generation, and also check if the existing code is reusable. + + Args: + gen_savedir (str): the directory to save generated code + module_or_module_class (Union[Type[torch.nn.Module], torch.nn.Module]): the original module or module class + compute_config (ComputeConfig): the environment resource + instance_name (Optional[str]): the instance name of the generated module. If it is None, will use the default name. + reuse (ReuseType): specify which part can be reused. + + Returns: + Tuple[str, bool]: the output directory and whether the existing code is reusable. + + Raises: + RuntimeError: if the existing code is not reusable, + will raise RuntimeError if the code is not reusable but the module is already loaded. + """ + namespace, outdir = _prepare_namespace(gen_savedir, module_or_module_class, instance_name) + reusable = False + transferred = False + + config_file = outdir / ParallelModule.COMPUTE_CONFIG_FILE + + # Empty + Transfer -> config match, graph match, tracing file present -> generate code by MATCH or MOO + # Empty w.o. Transfer -> Empty -> generate code by MATCH or MOO + has_transferred = os.path.exists(os.path.join(outdir, "transferred.sign")) + if transfer_config is not None and transfer_config.get("transfer_config_dir", None) is not None \ + and (not has_transferred or transfer_config['transfer_force']): + # transfer_config_dir: Optional[str] = None, + transfer_metadata(outdir, transfer_config) + ComputeConfig.safe_dump_to_file(compute_config, config_file) + transferred = True + + # decision matrix for code generation + # reuse flag | dir condition(imported, empty, match, unmatched) | action + # --------------------------------------------------------- + # OVERRIDE | empty | generate + # OVERRIDE | imported | raise error + # OVERRIDE | whatever match | generate + # OVERRIDE | unmatch | generate + # GRAPH | empty | generate + # GRAPH | imported | raise error + # GRAPH | graph match | reuse graph, and regenerate code + # GRAPH | all match | reuse graph, and regenerate code + # GRAPH | unmatch | generate + # MATCH | empty | generate + # MATCH | match | reuse(do nothing) + # MATCH* | whatever unmatch| raise error (except when there's no python source code, see below) + # MATCH | imported | doesn't matter + # MOO | empty | generate + # MOO | match | reuse(do nothing) + # MOO | match graph | reuse graph, and regenerate code + # MOO | imported | raise error if whatever unmatch + # *: The precondition for `except` part is the compute config should match. + # you can take it as a continous operation after a failed generation. 
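+    # Illustrative reading of the matrix (assumed): with reuse=MATCH and a prior successful generation
+    # under an identical ComputeConfig, every expected file already exists and the cached code is reused
+    # as-is; on a config mismatch, MATCH raises rather than silently regenerating, while MOO falls back to
+    # reusing the graph dump (when the graph config still matches) or to a full regeneration.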
+ old_config: Optional[ComputeConfig] = ComputeConfig.safe_load_from_file(config_file) + is_config_match = compute_config_safe_equals(old_config, compute_config) + # is_graph_config_match = old_config is not None and old_config.graph_config == compute_config.graph_config + is_graph_config_match = old_config is not None and graph_config_equals(old_config.graph_config, compute_config.graph_config) + trace_meta_files = [ + outdir / FxModuleParser.ATTR_CONTENT_FILE_0, # init weights file(fullmodel.pt.*), + outdir / FxModuleParser.ATTR_MAP_FILE, # param name mapping (dist_param_map.pt) + ] + + if reuse == ReuseType.MATCH or reuse == ReuseType.MOO: + # check if the module is already generated + expected_output_files = [outdir / _GENCODE_FILE_TEMPLATE.format(rank) for rank in range(compute_config.runtime_ngpus)] + expected_output_files.extend(trace_meta_files) + expected_output_files.append(config_file) + expected_output_files.append(outdir / _GRAPH_DUMP_FILE) # graph dump (graph.ckp), + expected_output_files.append(outdir / _FORWARD_ARGS_DUMP_FILE) # forward args dump(forward_args.pkl), + expected_output_files.append(outdir / ParallelModule.ORIGIN_MODULE_METADATA_FILE) # origin module metadata (origin_module_metadata.pt), + existing_output_files = [ + f for f in outdir.glob('*') + if f.is_file() and ( # just take fullmodel.pt.0 to compare + not f.name.startswith(FxModuleParser.ATTR_CONTENT_FILE_STEM) + or f.name == FxModuleParser.ATTR_CONTENT_FILE_0 + ) and not f.name.endswith('.sign') + ] + + print(f"{__name__} | compute config match: {is_config_match}") + print(f"{__name__} | graph config match: {is_graph_config_match}") + print(f"{__name__} | existing output files: {existing_output_files}") + print(f"{__name__} | expected output files: {expected_output_files}") + + if existing_output_files: # if the directory is not empty + if is_config_match \ + and all([output_file.exists() for output_file in expected_output_files]) \ + and len(existing_output_files) == len(expected_output_files): + + print(f"{__name__} | Reuse existing files in {outdir}") + reusable = True # everything is matched. + elif is_config_match \ + and all(f.suffix != '.py' for f in existing_output_files): + # No python source code is generated. + # which means its last generation failed. + # in this case, we can reuse the same directory safely. + logger.info(f'Output directory {outdir} is not empty. ' + f'But no python source code is present. ' + f'Will reuse the directory and the graph dump if present.') + # we have to trace the graph again if not all meta files are present. + print(f"{__name__} | compute config match but no python code exists in {outdir}") + if not all([meta_file.exists() for meta_file in trace_meta_files]): + print(f"{__name__} | compute config match but no python code exists in {outdir} and not all meta files are present") + _clean_files(outdir) + elif reuse == ReuseType.MATCH: + raise RuntimeError(f'Output directory {outdir} is not empty. ' + f'And the existing files do not match with current config. ' + f'You can remove the directory and try again, ' + f'or set reuse to ReuseType.NONE/ReuseType.OVERRIDE to regenerate the code.') + else: + assert reuse == ReuseType.MOO + if _is_any_gencode_loaded(namespace): + raise RuntimeError(f'Output directory {outdir} is already loaded. 
' + f'You can not override a loaded module.') + elif is_graph_config_match: + # reuse the graph dump + print(f"{__name__} | MOO | graph match -> reuse graph but clean the current code") + _clean_files(outdir, '*.py') + else: + _clean_files(outdir) + else: + # check if the module is already loaded + if _is_any_gencode_loaded(namespace): + raise RuntimeError(f'Output directory {outdir} is already loaded. ' + f'You can not override a loaded module.') + # clear existing generated files + if reuse == ReuseType.OVERRIDE \ + or not is_graph_config_match \ + or not all([meta_file.exists() for meta_file in trace_meta_files]): + # we have to trace the graph again if not all meta files are present even when reuse=graph. + print(f"{__name__} | OVERRIDE | Override existing files in {outdir}") + glob_pattern = '*' + else: + print(f"{__name__} | GRAPH | keep the graph dump in {outdir} and regenerate the code") + glob_pattern = '*.py' # so we can keep graph dumps. + _clean_files(outdir, glob_pattern) + + return outdir, reusable, transferred + + +def parallelize( + module_or_module_class: Union[torch.nn.Module, Type[torch.nn.Module]], + dummy_forward_args: Dict[str, Any], + pas_policy: Union[str, Callable[[IRGraph, ComputeConfig], IRGraph]], + compute_config: ComputeConfig, + *, + gen_savedir: Union[str, Path] = './.nnscaler', + reuse: Union[ReuseType, str] = ReuseType.MATCH, + instance_name: Optional[str] = None, + load_module: bool = True, + module_dtype: Optional[torch.dtype] = None, + module_fn: Optional[Callable[[], torch.nn.Module]] = None, + init_module_params: bool = True, + broadcast_strategy: Union[str, BroadcastGenFilesStrategy] = 'none', + transfer_config: Optional[Dict[str, Any]] = None, +) -> Union[None, ParallelModule, Type[ParallelModule]]: + """ + Convert a torch.nn.Module object or class to a ParallelModule object or class. + + If you want to save multiple instances of the same module, + you can specify instance_name to distinguish them. + + Currently you must use a shared file system to share the generated files (like a mounted Azure Blob container), + or you can unset the load_module flag and manually copy the generated files to the other nodes. + After all nodes have the generated files, you can call parallelize() again with the load_module flag set. + + Note: if reuse is not set to ReuseType.MATCH, + the generated code in outdir will be removed EVEN IF the code generation fails in this call. + + If the input is a module object: + * The module object will be copied to CPU to handle possibly insufficient GPU memory. + * The training flag will be the same as the original module's. + + This function can be used to convert both a module object and a module class to a parallel module or parallel module class. + Among the keyword arguments, + module_fn and module_dtype control how to create the module object, + whereas init_module_params controls how to load the parallel module object after the conversion is done. + + 1. If the input is a module object, it will return a ParallelModule object if load_module is True. + This is useful when the module is created by a factory function. + + a. module_fn is ignored. + b. module_dtype is used to control the dtype of the input module. + c. init_module_params is used to control whether to initialize the parallel module parameters when loading it. + + 2. If the input is a module class, it will return a ParallelModule subclass if load_module is True. + + a. module_fn is used to create the module object, or the module's __init__ if not present. + b. 
module_dtype is used to control the dtype of the created module (by the constructor or module_fn). + Alternatively, this can be handled inside module_fn itself. + c. init_module_params is ignored. + + After the module class is converted, you can create a module object by calling the returned class like any module class. + The module class is defined like: + + :: + + class GenModule(nnscaler.runtime.module.ParallelModule): + def __init__(self, init_params=True): + super().__init__() + ... + ... + + So you can use `init_params` in `__init__` to control whether to initialize the module parameters. + For example, if you don't want to initialize module params: + + :: + + module = GenModule(init_params=False) + + Args: + module_or_module_class (Union[torch.nn.Module, Type[torch.nn.Module]]): the module or module class to be compiled + dummy_forward_args (Dict[str, Any]): the dummy input for the module forward + pas_policy (Union[str, Callable[[IRGraph, ComputeConfig], IRGraph]]): the pas policy, + it can be the name of a builtin policy, or a custom policy function. + compute_config (ComputeConfig): the environment resource + reuse (ReuseType): specify which part can be reused. + gen_savedir (Union[str, Path]): the directory to save generated code + instance_name (Optional[str]): the instance name of the generated module. If it is None, will use the default name. + load_module (bool): whether to load the generated module or module class after conversion is done. + init_module_params (bool): If True, all parameters are initialized with the same values as when the module was traced when the parallel module is constructed. + Otherwise, they will be empty tensors. + This parameter will be passed to the module constructor, + so it is only used when module_or_module_class is a module object, and load_module is True. + module_dtype (Optional[torch.dtype]): the dtype of the module. Keep the module as it is if it is None. + module_fn (Optional[Callable[[], torch.nn.Module]]): the function to create the module. Will use __init__ if it is None. + broadcast_strategy (Union[str, BroadcastGenFilesStrategy]): the broadcast strategy for generated files. + Please note that the broadcasting will only be done in a torchrun environment, + and will throw an error if dist is not initialized and broadcast_strategy is not NONE. + transfer_config (Optional[Dict[str, Any]]): optional settings for transferring previously generated metadata from another directory, + see _prepare_and_check_reusable for how it is interpreted. + Returns: + Union[ParallelModule, Type[ParallelModule], None]: + if the load_module flag is set, return the converted ParallelModule object or class; + if the load_module flag is not set, return None + """ + if ( + isinstance(module_or_module_class, ParallelModule) or + (inspect.isclass(module_or_module_class) and issubclass(module_or_module_class, ParallelModule)) + ): + # already done + return module_or_module_class if load_module else None + + if ( + isinstance(module_or_module_class, CubeModule) or + (inspect.isclass(module_or_module_class) and issubclass(module_or_module_class, CubeModule)) + ): + raise RuntimeError("Old style CubeModule is not supported") + + if isinstance(pas_policy, str): + if pas_policy not in _PREDEFINED_POLICIES: + raise ValueError(f"Invalid pas_policy: {pas_policy}") + pas_policy = _PREDEFINED_POLICIES[pas_policy] + + is_module_class = inspect.isclass(module_or_module_class) + module_class = module_or_module_class if is_module_class else module_or_module_class.__class__ + reuse = ReuseType(reuse) if isinstance(reuse, str) else reuse + broadcast_strategy = BroadcastGenFilesStrategy(broadcast_strategy) if isinstance(broadcast_strategy, str) else broadcast_strategy + + # Call it here just to ensure the device group is initialized. 
+ # If the user initializes dist + # and doesn't call `nnscaler.init()` before calling this function, this is necessary. + if dist.is_initialized(): + _ = DeviceGroup() + + # generate code only in node0 + # if it is not in a torchrun environment, just generate. + if not dist.is_initialized() or dist.get_rank() == 0: + outdir, reusable, transferred = _prepare_and_check_reusable( + gen_savedir, module_class, compute_config, instance_name, reuse, + transfer_config + ) + if not reusable: + config_file = outdir / ParallelModule.COMPUTE_CONFIG_FILE + ComputeConfig.safe_dump_to_file(compute_config, config_file) # always refresh compute config + with _compile_flags(compute_config): + regen_status = _gencode( + module_or_module_class, + dummy_forward_args, + pas_policy, + compute_config, + outdir, + module_dtype=module_dtype, + module_fn=module_fn, + ) + else: + regen_status = RegenStatus.NONE + logger.info(f"Reuse generated code in {outdir}") + + if regen_status == RegenStatus.CODE and transferred: + regen_status = RegenStatus.ALL + + if dist.is_initialized(): + # code generation can take very long time (for example, over 1 hour) + # It is not always OK to use dist.barrier() directly. + # because the default timeout for nccl is 30 minutes + # (we can't control the timeout setting if dist is not initialized by us) + DeviceGroup().long_barrier() + + if broadcast_strategy != BroadcastGenFilesStrategy.NONE: + if not dist.is_initialized(): # we only support loading in torchrun environment + raise RuntimeError("Broadcast generated files failed: dist is not initialized.") + dist.barrier() + # sync regen_status + curr_rank = dist.get_rank() + if curr_rank == 0: + sent_obj = [regen_status] + else: + sent_obj = [None] + dist.broadcast_object_list( + sent_obj, + src=0, + ) + if curr_rank != 0: + regen_status = sent_obj[0] + + # narrow down broadcast_strategy according to regen_status + if regen_status == RegenStatus.NONE: + # we don't need to broadcast anything + broadcast_strategy = BroadcastGenFilesStrategy.NONE + elif regen_status == RegenStatus.CODE: + # narrow ALL/NO_WEIGHTS down to code + broadcast_strategy = BroadcastGenFilesStrategy.CODE + else: + # we don't need to narrow broadcast_strategy in this case + # keep the original broadcast_strategy + assert regen_status == RegenStatus.ALL + + # broadcast generated files according to regen_status + if broadcast_strategy != BroadcastGenFilesStrategy.NONE: + _broadcast_gen_files( + module_class, + gen_savedir=gen_savedir, + instance_name=instance_name, + broadcast_strategy=broadcast_strategy, + ) + elif os.getenv("FORCE_BROADCAST") == "1": + # force broadcast generated files + print(f"Force broadcast generated files in {gen_savedir}") + _broadcast_gen_files( + module_class, + gen_savedir=gen_savedir, + instance_name=instance_name, + broadcast_strategy=BroadcastGenFilesStrategy.ALL, + ) + + if load_module: + if not dist.is_initialized(): # we only support loading in torchrun environment + raise RuntimeError("Load ParallelModule failed: dist is not initialized.") + dist.barrier() + parallel_module_class = _load_parallel_module_class( + module_class, + gen_savedir=gen_savedir, + instance_name=instance_name, + ) + if is_module_class: + return parallel_module_class + else: + parallel_module = parallel_module_class(init_module_params) + parallel_module.train(module_or_module_class.training) # set training state to the same as original module + return parallel_module diff --git a/mtraining/utils/data_utils/__init__.py b/mtraining/utils/data_utils/__init__.py new 
file mode 100644 index 0000000..e69de29 diff --git a/mtraining/utils/data_utils/bookcorpus.py b/mtraining/utils/data_utils/bookcorpus.py new file mode 100644 index 0000000..69429a5 --- /dev/null +++ b/mtraining/utils/data_utils/bookcorpus.py @@ -0,0 +1,70 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import numpy +import torch +import argparse +from typing import List, Dict +from datasets import load_dataset, Dataset +from transformers import AutoTokenizer, PreTrainedTokenizer + +def get_tokenizer(model_path): + return AutoTokenizer.from_pretrained(model_path) + +BOS_TOKEN = '' + +def tokenize(sample: Dict[str, str], tokenizer: PreTrainedTokenizer, text_key: str): + input_ids = tokenizer.encode(BOS_TOKEN + sample[text_key] + tokenizer.eos_token, add_special_tokens=False) + return {"input_ids": input_ids} + +def concate_split(samples: Dict[str, List[List[int]]], sample_len: int, text_key: str): + buffer: List[int] = [] # start empty; seeding with samples[text_key][0] would duplicate the first sample of each batch + resized_ids = [] + length = [] + for in_ids in samples[text_key]: + buffer.extend(in_ids) + while len(buffer) >= sample_len: + resized_ids.append(buffer[:sample_len]) + length.append(sample_len) + buffer = buffer[sample_len:] + return {"input_ids": resized_ids, "length": length} + +def create_dataset(tokenizer: PreTrainedTokenizer, raw_dataset: Dataset, text_key: str, sample_len: int = 8 * 1024, batch_size=10000): + tokenized_dataset = raw_dataset.map( + tokenize, remove_columns=raw_dataset.column_names, num_proc=32, + fn_kwargs={'tokenizer': tokenizer, 'text_key': text_key} + ) + return tokenized_dataset.map( + concate_split, remove_columns=tokenized_dataset.column_names, + num_proc=32, batched=True, + batch_size=batch_size, fn_kwargs={'sample_len': sample_len, 'text_key': 'input_ids'} + ) + + +def modify_bos_token(tokenizer: PreTrainedTokenizer): + # https://huggingface.co/Qwen/Qwen2-7B-Instruct/discussions/15 + global BOS_TOKEN + if tokenizer.bos_token is None: + BOS_TOKEN = "<|endoftext|>" + else: + BOS_TOKEN = tokenizer.bos_token + +if __name__ == '__main__': + # python bookcorpus.py --data_path_or_name "bookcorpus/bookcorpus" --tokenizer_path_or_name "meta-llama/Llama-2-7b-hf" --save_path "bookcorpus-llama2-2k-hf" --sequence_length 2048 + parser = argparse.ArgumentParser() + parser.add_argument('--data_path_or_name', help='the path or name of the raw dataset, for example, "bookcorpus/bookcorpus"', type=str, required=True) + parser.add_argument('--tokenizer_path_or_name', help='the tokenizer path or name, for example, "meta-llama/Llama-2-7b-hf"', type=str, required=True) + parser.add_argument('--save_path', help='the path to save the tokenized dataset', type=str, required=True) + parser.add_argument('--sequence_length', help='the length of each sample in the tokenized dataset, usually set to the max sequence length', type=int, required=True) + args = parser.parse_args() + + data_path_or_name = args.data_path_or_name + tokenizer_path_or_name = args.tokenizer_path_or_name + save_path = args.save_path + sequence_length = args.sequence_length + + raw_dataset = load_dataset(data_path_or_name)["train"] + tokenizer = get_tokenizer(tokenizer_path_or_name) + modify_bos_token(tokenizer) + + dataset = create_dataset(tokenizer, raw_dataset, "text", sequence_length) + dataset.save_to_disk(save_path) diff --git a/mtraining/utils/data_utils/prolong.py b/mtraining/utils/data_utils/prolong.py new file mode 100644 index 0000000..83672b9 --- /dev/null +++ b/mtraining/utils/data_utils/prolong.py @@ -0,0 +1,122 @@ +import os +import logging +import 
argparse + +from tqdm import tqdm +from typing import Dict +from streaming import StreamingDataset +from transformers import PreTrainedTokenizer +from datasets import Dataset, concatenate_datasets + +from mtraining.utils.general import get_tokenizer +# ------------------------------------------------ + +logger = logging.getLogger(__name__) + +LLAMA3_MODEL_ID = "meta-llama/Meta-Llama-3-8B" +LLAMA_TOKENZIER = None + +def tokenize(sample: Dict[str, str], tokenizer: PreTrainedTokenizer, seq_len: int=524288): + text = sample['text'] + for token_k, token_v in LLAMA_TOKENZIER.special_tokens_map.items(): + if token_k in tokenizer.special_tokens_map: + text = text.replace(token_v, tokenizer.special_tokens_map[token_k]) + + input_ids = tokenizer.encode( + text, + add_special_tokens=False, + truncation=True, + max_length=seq_len + ) + return {"input_ids": input_ids, "length": len(input_ids)} + + +DOMAINS = [ + "thestackv1_concat_by_repo-524288@0.15", + "thestackv1_concat_by_repo-65536@0.15", + "book-524288@0.05", + "book-65536@0.25", + "fineweb-edu@0.1", + "fineweb-2023-50@0.1", + "stackexchange@0.04", + "dolmawiki@0.04", + "tuluv2@0.03", + "arxiv@0.03", + "openwebmath@0.03", + "textbooks@0.03", +] +FIXED_512K = [ + "thestackv1_concat_by_repo-524288", + "book-524288" +] + +DOMAIN_MIX_DICT = { + "full": DOMAINS, + "fixed_524288": FIXED_512K +} + +def main(args): + global LLAMA_TOKENZIER + + seq_len = args.sequence_length + LLAMA_TOKENZIER = get_tokenizer(LLAMA3_MODEL_ID) + model_tokenizer = get_tokenizer(args.model_id) + if model_tokenizer.bos_token is None: + model_tokenizer.bos_token = "<|endoftext|>" + + domains = DOMAIN_MIX_DICT[args.dataset_mix] + dataset_paths = [os.path.join(args.dataset_path, domain) for domain in domains] + tokenized_datasets = [] + for idx, dataset_path in enumerate(dataset_paths): + print('-' * 50) + print(f"Processing {domains[idx]} from {dataset_path}...") + texts = [] + dataset = StreamingDataset( + local=dataset_path, + remote=None, + shuffle=False, + batch_size=1, + ) + + for ix, sample in tqdm(enumerate(dataset)): + if ix % args.sample_interval != 0: continue + + sample_input_ids = sample["input_ids"] + sample_splits = [sample_input_ids[i:i+seq_len] for i in range(0, len(sample_input_ids), seq_len)] + for sample_split in sample_splits: + # De-tokenization + text = LLAMA_TOKENZIER.decode(sample_split) + texts.append(text) + + hf_dataset = Dataset.from_dict( + { + "text": texts + } + ) + + tokenized_dataset = hf_dataset.map( + tokenize, + remove_columns=hf_dataset.column_names, + num_proc=64, + fn_kwargs={'tokenizer': model_tokenizer, 'seq_len': seq_len} + ) + tokenized_datasets.append(tokenized_dataset) + + print('-' * 50) + print(f"Concatenating and Saving tokenized datasets to {args.save_path}...") + concat_dataset = concatenate_datasets(tokenized_datasets) + filtered_concat_dataset = concat_dataset.filter(lambda x: x['length'] == seq_len, num_proc=128) + filtered_concat_dataset.save_to_disk(args.save_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model_id", type=str, default="microsoft/Phi-3-mini-4k-instruct") + parser.add_argument('--dataset_mix', type=str, default="fixed_524288") + parser.add_argument('--dataset_path', type=str) + parser.add_argument('--save_path', type=str) + parser.add_argument('--sequence_length', type=int, default=524288) + parser.add_argument("--sample_interval", type=int, default=1) + args = parser.parse_args() + + main(args) diff --git a/mtraining/utils/expr_data.py 
b/mtraining/utils/expr_data.py new file mode 100644 index 0000000..e82f025 --- /dev/null +++ b/mtraining/utils/expr_data.py @@ -0,0 +1,15 @@ +import os +from dataclasses import dataclass + +@dataclass +class ExprData: + global_batch_size: int + micro_batch_size: int + reuse_type: str + +EXPR_DATA = ExprData(64, 1, "match") + +def update_expr_data(args): + global EXPR_DATA + EXPR_DATA = ExprData(args.global_batch_size, args.micro_batch_size, args.reuse_type) + # print(f"Updated global batch size to {EXPR_DATA.global_batch_size}, micro batch size to {EXPR_DATA.micro_batch_size}, reuse type to {EXPR_DATA.reuse_type}") diff --git a/mtraining/utils/general.py b/mtraining/utils/general.py new file mode 100644 index 0000000..955a38e --- /dev/null +++ b/mtraining/utils/general.py @@ -0,0 +1,200 @@ +import os +import torch +import torch.distributed as dist + +from typing import List +from transformers import AutoModelForCausalLM, AutoTokenizer + +from nnscaler.cli.trainer_args import AggregatedOutputs +from nnscaler.runtime.module import ParallelModule + +from .paths import BASE_DIR + +import logging +logger = logging.getLogger(__name__) + +def get_tokenizer(tokenizer_name_or_path, + model_max_length=None, + default_bos_token="", + default_eos_token="", + default_pad_token="[PAD]", + default_unk_token=""): + + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, trust_remote_code=True) + special_tokens_dict = dict() + if tokenizer.pad_token is None: + special_tokens_dict["pad_token"] = default_pad_token + if tokenizer.eos_token is None: + special_tokens_dict["eos_token"] = default_eos_token + if tokenizer.bos_token is None: + special_tokens_dict["bos_token"] = default_bos_token + if tokenizer.unk_token is None: + special_tokens_dict["unk_token"] = default_unk_token + + tokenizer.add_special_tokens(special_tokens_dict) + if model_max_length: + tokenizer.model_max_length = model_max_length + return tokenizer + +def get_module_path(model_id: str): + model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True) + module_path = str(model.__class__.__module__) + del model + + return module_path + +def aggregate_outputs_fn(loss_outputs, sync_group) -> AggregatedOutputs: + losses, ntokens_info = [], [] + for _, loss, ntokens, _ in loss_outputs: + losses.append(loss) + ntokens_info.append(ntokens) + + + loss_sum = torch.sum(torch.stack(losses), dtype=torch.float64) + dist.all_reduce(loss_sum, group=sync_group) + + ntokens_sum = torch.sum(torch.tensor(ntokens_info, dtype=torch.float64, device=torch.cuda.current_device())) + dist.all_reduce(ntokens_sum, group=sync_group) + + num_batches = torch.tensor(len(losses), device=torch.cuda.current_device()) + dist.all_reduce(num_batches, group=sync_group) + + return AggregatedOutputs( + loss_sum=loss_sum.item() / ntokens_sum.item(), + num_batches=num_batches.item(), + num_tokens=ntokens_sum.item(), + ) + + +def load_comm_profile_data(args): + if args.plan_ngpus in [2, 4, 8, 16]: + logger.info(f"Use nnscaler's built-in communication profiling data for {args.plan_ngpus} GPUs") + return + + from nnscaler.autodist.util import get_default_profile_path + profile_dir = os.path.join(get_default_profile_path(), 'comm') + profile_path = os.path.join(profile_dir, f"intra_{args.plan_ngpus}.json") + + if not os.path.exists(profile_path): + import shutil + logger.info(f"Communication profiling data not found in {profile_dir} for {args.plan_ngpus} GPUs. 
Use built-in communication profiling data (collected on A100-SXM4-40GB)") + src_file_path = os.path.join(BASE_DIR, "utils/comm_prof/NVIDIA_A100-SXM4-40GB", f"intra_{args.plan_ngpus}.json") + if not os.path.exists(src_file_path): + raise FileNotFoundError(f"Communication profiling data not found in {src_file_path} nor in nnscaler's built-in library for {args.plan_ngpus} GPUs") + os.makedirs(profile_dir, exist_ok=True) + + num_dev = 2 + while num_dev <= args.plan_ngpus: + src_file_path = os.path.join(BASE_DIR, "utils/comm_prof/NVIDIA_A100-SXM4-40GB", f"intra_{num_dev}.json") + profile_path = os.path.join(profile_dir, f"intra_{num_dev}.json") + if os.path.exists(profile_path): + logger.info(f"Communication profiling data already exists in {profile_path} for {num_dev} GPUs") + num_dev *= 2 + continue + else: + logger.info(f"Copying {src_file_path} to {profile_path}") + shutil.copy(src_file_path, profile_path) + num_dev *= 2 + + + +def is_active(module_name: str, keep_active: List[str]): + for active_module_subname in keep_active: + if active_module_subname.lower() in module_name.lower(): + return True + return False + +def freeze_model_params_(model, keep_active: List[str], prefix=""): + if dist.get_rank() == 0: + print("-" * 80) + print(f"Only keeping parameters with substring in {keep_active} active...") + + for name, module in model._modules.items(): + if len(list(module.children())) > 0: + freeze_model_params_(module, keep_active, prefix + name + ".") + else: + param_name = prefix + name + if not is_active(param_name, keep_active): + print(f"Freezing {param_name}...") + for param in module.parameters(): + param.requires_grad = False + else: + print(f"Keeping {param_name} active...") + + if dist.get_rank() == 0: + print("-" * 80) + + +def freeze_model_params(model, active_param_config_path: str, prefix=""): + with open(active_param_config_path, "r") as f: + keep_active = f.read().splitlines() + print(f"freeze_model_params | keep active: {keep_active}") + + freeze_model_params_(model, keep_active, prefix) + +def get_resume_path( + check_resume: bool, + resume_from: str, + ckpt_save_dir: str, + num_gpus: int, +): + if not check_resume: + return None + elif resume_from is not None: + return resume_from + + # Detect the last checkpoint in CKPT_PATH + ckpt_dirs = [ckpt_dir for ckpt_dir in os.listdir(ckpt_save_dir) if len(ckpt_dir.split('-')) == 2 and ckpt_dir.split('-')[0].isdigit()] + + # Filter out directories that do not contain number of ckpts (file name ending with .ckpt) equal to num_gpus (check by os.listdir) + filtered_ckpt_dirs = [] + for ckpt_dir in ckpt_dirs: + ckpt_dir_path = os.path.join(ckpt_save_dir, ckpt_dir) + if os.path.isdir(ckpt_dir_path): + ckpt_files = [f for f in os.listdir(ckpt_dir_path) if f.endswith('.ckpt')] + if len(ckpt_files) == num_gpus: + filtered_ckpt_dirs.append(ckpt_dir) + + print(f"get_resume_path | filtered_ckpt_dirs = {filtered_ckpt_dirs}") + if len(filtered_ckpt_dirs) == 0: + return None + + target_ckpt_dir = sorted(filtered_ckpt_dirs, key=lambda x: (int(x.split('-')[0]), int(x.split('-')[1])))[-1] + target_ckpt_dir = os.path.join(ckpt_save_dir, target_ckpt_dir) + print(f"get_resume_path | target_ckpt_dir = {target_ckpt_dir}") + return target_ckpt_dir + + + +def fix_model_state_dict(model, model_state_dict): + if isinstance(model, ParallelModule): + required_keys = list(model.dist_param_map.values()) + required_keys_under = {k[6:]: v for k, v in model.dist_param_map.items()} + else: + required_keys = model.state_dict().keys() + required_keys_under = 
{k.replace('.', '_'): k for k in required_keys} + + has_model_prefix = 'model' in model_state_dict + model_state_dict = model_state_dict if not has_model_prefix else model_state_dict['model'] + model_sd_copy = model_state_dict.copy() + + if dist.is_initialized() and dist.get_rank() % int(os.getenv("GPU_PER_NODE", "8")) == 0: + print(f"{__name__} | required_keys[:10]: {required_keys[:10]}") + print(f"{__name__} | required_keys_under: {required_keys_under}") + print(f"{__name__} | model_state_dict.keys()[:10]: {list(model_state_dict.keys())[:10]}") + + for k in model_state_dict.keys(): + model_sd_copy.pop(k) + + under_k_start = 0 if not has_model_prefix else 1 + under_k = '_'.join(k.split('_')[under_k_start:-1]) + + if under_k in required_keys_under: + model_sd_copy[required_keys_under[under_k]] = model_state_dict[k] + + if 'lm_head_weight' in required_keys_under: + for k in model_state_dict.keys(): + if 'model_embed_tokens_weight' in k: + model_sd_copy[required_keys_under['lm_head_weight']] = model_state_dict[k] + + return model_sd_copy diff --git a/mtraining/utils/loss.py b/mtraining/utils/loss.py new file mode 100644 index 0000000..5fd3a54 --- /dev/null +++ b/mtraining/utils/loss.py @@ -0,0 +1,68 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +import torch.utils.checkpoint as ckpt + +from nnscaler.graph.parser.register import register_op + + +def linear_cross_entropy(x: torch.Tensor, w: torch.Tensor, y: torch.Tensor, padding_idx: int = 0) -> torch.Tensor: + """ + Compute the cross entropy loss of a linear layer. + + Args: + + x: [token_num, hidden_size], the last hidden state of the model + w: [dict_size, hidden_size], the weight matrix of the last linear layer + y: [token_num], the target token index + padding_idx: int, the index of padding token + + Returns: + + losses: [token_num], the cross entropy loss of each token + """ + logits = torch.nn.functional.linear(x, w) + normalized_logits = torch.nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32) + losses = torch.nn.functional.nll_loss(normalized_logits, y, reduction='none', ignore_index=padding_idx) + return losses + + +def chunk_linear_cross_entropy(x: torch.Tensor, w: torch.Tensor, y: torch.Tensor, padding_idx: int, chunk_size: int) -> torch.Tensor: + """ + In order to reduce the memory usage when the sequence length and dictionary size are large, we can split the input + tensor into chunks and compute the cross entropy loss of each chunk separately. + You can register this function with annotation 'b l d^, n^ d^, b l -> b l'. 
+ + Args: + + x: [bsz, seq_len, hidden_size], the last hidden state of the model + w: [dict_size, hidden_size], the weight matrix of the last linear layer + y: [bsz, seq_len], the target token index + padding_idx: int, the index of padding token + chunk_size: int, the size of each chunk + + Returns: + + losses: [bsz, seq_len], the cross entropy loss of each token + """ + bsz, seq_len, hidden_size = x.size() + token_num = bsz * seq_len + x = x.view(token_num, hidden_size) + y = y.view(token_num) + + if token_num % chunk_size != 0: + raise ValueError(f"token_num {token_num} is not divisible by chunk_size {chunk_size}") + + chunk_num = token_num // chunk_size + xs = x.view(chunk_num, chunk_size, hidden_size) + ys = y.view(chunk_num, chunk_size) + losses = [] + for i in range(chunk_num): + loss = ckpt.checkpoint(linear_cross_entropy, xs[i], w, ys[i], padding_idx, use_reentrant=False) + losses.append(loss) + losses = torch.stack(losses).view(bsz, seq_len) + return losses + + +register_op('b l d^, n^ d^, b l -> b l')(chunk_linear_cross_entropy) diff --git a/mtraining/utils/paths.py b/mtraining/utils/paths.py new file mode 100644 index 0000000..716e7d4 --- /dev/null +++ b/mtraining/utils/paths.py @@ -0,0 +1,27 @@ +import os + +BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + + +EXPR_DATA_SAVE_PATH = { + 'base_path': None, + "ckpt_save_path": None, + "compile_save_path": None, +} + +def update_expr_data_by_base_path(base_path): + EXPR_DATA_SAVE_PATH['base_path'] = base_path + +def update_expr_data_save_path( + ckpt_save_path, + compile_save_path, + ): + if ckpt_save_path is None: + EXPR_DATA_SAVE_PATH['base_path'] = os.getenv("EFFI_EXPR_STORE_DIR") + else: + EXPR_DATA_SAVE_PATH['base_path'] = os.path.dirname(ckpt_save_path) + if "rank_" in ckpt_save_path: + EXPR_DATA_SAVE_PATH['base_path'] = os.path.dirname(os.path.dirname(ckpt_save_path)) + + EXPR_DATA_SAVE_PATH['ckpt_save_path'] = ckpt_save_path + EXPR_DATA_SAVE_PATH['compile_save_path'] = compile_save_path \ No newline at end of file
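For reference, a minimal sketch of how the `reuse` semantics documented in the decision matrix of `_prepare_and_check_reusable` are expected to be driven from user code. This is illustrative only: `MyModel`, the dummy batch, the `nnscaler.parallel` import path, the `'autodist'` policy name, and the `ComputeConfig(plan_ngpus=..., runtime_ngpus=...)` keyword form are assumptions for the sketch, not something this diff defines.

```python
# Illustrative sketch (not part of this diff). Assumes a torchrun environment
# and that parallelize/ComputeConfig/ReuseType are importable from nnscaler.parallel.
import torch
import nnscaler
from nnscaler.parallel import parallelize, ComputeConfig, ReuseType

class MyModel(torch.nn.Module):  # hypothetical toy module
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)
    def forward(self, x):
        return self.linear(x).sum()

nnscaler.init()
compute_config = ComputeConfig(plan_ngpus=2, runtime_ngpus=2)  # assumed constructor fields

# First call: gen_savedir is empty -> "empty | generate" in the matrix above.
pm = parallelize(
    MyModel(),
    dummy_forward_args={'x': torch.randn(4, 16)},
    pas_policy='autodist',       # assumed builtin policy name
    compute_config=compute_config,
    gen_savedir='./.nnscaler',
    reuse=ReuseType.MATCH,       # reuse only if config, graph, and files all match
)

# Second call with identical config: "match | reuse (do nothing)".
# Changing compute_config while keeping ReuseType.MATCH raises RuntimeError;
# ReuseType.OVERRIDE or ReuseType.MOO regenerates as described in the matrix.
```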
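The packing step in `bookcorpus.py` (`concate_split`) concatenates tokenized samples into one running buffer and emits fixed-length slices, dropping any tail shorter than `sample_len`. A standalone toy illustration of that behavior (plain Python, no `datasets` dependency; the token lists are made up):

```python
from typing import Dict, List

def pack(batch: Dict[str, List[List[int]]], sample_len: int) -> Dict[str, List[List[int]]]:
    """Toy restatement of concate_split: concatenate, then cut into sample_len chunks."""
    buffer: List[int] = []
    resized_ids, length = [], []
    for in_ids in batch["input_ids"]:
        buffer.extend(in_ids)
        while len(buffer) >= sample_len:
            resized_ids.append(buffer[:sample_len])
            length.append(sample_len)
            buffer = buffer[sample_len:]
    return {"input_ids": resized_ids, "length": length}

# three "tokenized documents" of uneven length, packed into chunks of 4 tokens
batch = {"input_ids": [[1, 2, 3], [4, 5, 6, 7, 8], [9, 10]]}
print(pack(batch, sample_len=4))
# {'input_ids': [[1, 2, 3, 4], [5, 6, 7, 8]], 'length': [4, 4]}  -> trailing [9, 10] is dropped
```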
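`prolong.py` re-tokenizes Llama-3-tokenized ProLong shards with the target model's tokenizer; before re-encoding, Llama-3 special tokens in the decoded text are swapped for the target tokenizer's equivalents. A minimal sketch of that substitution step, using made-up stand-ins for the two `special_tokens_map` dicts so no `transformers` download is needed:

```python
# Hypothetical special-token maps standing in for
# LLAMA_TOKENZIER.special_tokens_map and tokenizer.special_tokens_map in prolong.py.
llama3_map = {"bos_token": "<|begin_of_text|>", "eos_token": "<|end_of_text|>"}
target_map = {"bos_token": "<s>", "eos_token": "</s>"}

text = "<|begin_of_text|>some decoded shard text<|end_of_text|>"
for token_k, token_v in llama3_map.items():
    if token_k in target_map:  # same guard as in tokenize(): only remap shared token kinds
        text = text.replace(token_v, target_map[token_k])

print(text)  # "<s>some decoded shard text</s>"
```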
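`freeze_model_params_` in `general.py` walks the module tree and disables gradients for every leaf module whose dotted name does not contain one of the `keep_active` substrings. The sketch below is a simplified, standalone version of the same idea (single process, no rank checks, substring matching over `named_parameters` instead of the recursive module walk), just to show the resulting `requires_grad` flags:

```python
import torch
from torch import nn

class Toy(nn.Module):  # hypothetical toy model with descriptive attribute names
    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(100, 8)
        self.attn_proj = nn.Linear(8, 8)
        self.lm_head = nn.Linear(8, 100)

def freeze_except(model: nn.Module, keep_active):
    # Simplified stand-in for freeze_model_params_: case-insensitive substring match on names.
    for name, param in model.named_parameters():
        param.requires_grad = any(s.lower() in name.lower() for s in keep_active)

model = Toy()
freeze_except(model, keep_active=["lm_head"])
for name, param in model.named_parameters():
    print(name, param.requires_grad)
# embed_tokens.weight False, attn_proj.weight/bias False, lm_head.weight/bias True
```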
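`chunk_linear_cross_entropy` in `loss.py` trades memory for recompute: the vocabulary projection and loss are evaluated one fixed-size token chunk at a time under `torch.utils.checkpoint`. A quick sanity sketch (plain PyTorch only, no nnscaler registration, activation checkpointing left out) showing that chunked and unchunked per-token losses agree:

```python
import torch
import torch.nn.functional as F

def full_loss(x, w, y, padding_idx=0):
    # same math as linear_cross_entropy above: project, log-softmax in fp32, per-token NLL
    logits = F.linear(x, w)
    logp = F.log_softmax(logits, dim=-1, dtype=torch.float32)
    return F.nll_loss(logp, y, reduction='none', ignore_index=padding_idx)

torch.manual_seed(0)
bsz, seq_len, hidden, vocab, chunk = 2, 8, 16, 32, 4
x = torch.randn(bsz, seq_len, hidden)
w = torch.randn(vocab, hidden)
y = torch.randint(0, vocab, (bsz, seq_len))

# unchunked reference over all bsz * seq_len tokens
ref = full_loss(x.view(-1, hidden), w, y.view(-1))

# chunked version: identical math, one chunk of tokens at a time
xs = x.view(-1, chunk, hidden)
ys = y.view(-1, chunk)
chunked = torch.cat([full_loss(xs[i], w, ys[i]) for i in range(xs.size(0))])

print(torch.allclose(ref, chunked))  # True
```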