@@ -29,8 +29,26 @@ namespace cp_algo::math::fft {
2929 }
3030 }
3131
32+ [[gnu::target(" avx2" )]] static std::pair<vftype, vftype>
33+ do_split (auto const & a, size_t idx, size_t n, u64x4 mul) {
34+ if (idx >= std::size (a)) {
35+ return std::pair{vftype (), vftype ()};
36+ }
37+ u64x4 au = {
38+ idx < std::size (a) ? a[idx].getr () : 0 ,
39+ idx + 1 < std::size (a) ? a[idx + 1 ].getr () : 0 ,
40+ idx + 2 < std::size (a) ? a[idx + 2 ].getr () : 0 ,
41+ idx + 3 < std::size (a) ? a[idx + 3 ].getr () : 0
42+ };
43+ au = montgomery_mul (au, mul, mod, imod);
44+ au = au >= base::mod () ? au - base::mod () : au;
45+ auto ai = to_double (i64x4 (au >= base::mod () / 2 ? au - base::mod () : au));
46+ auto quo = round (ai / split ());
47+ return std::pair{ai - quo * split (), quo};
48+ }
49+
3250 dft (size_t n): A(n), B(n) {init ();}
33- dft (auto const & a, size_t n, bool partial = true ): A(n), B(n) {
51+ [[gnu::target( " avx2 " )]] dft(auto const & a, size_t n, bool partial = true ): A(n), B(n) {
3452 init ();
3553 base b2x32 = bpow (base (2 ), 32 );
3654 u64x4 cur = {
@@ -42,24 +60,8 @@ namespace cp_algo::math::fft {
4260 u64x4 step4 = u64x4{} + (bpow (factor, 4 ) * b2x32).getr ();
4361 u64x4 stepn = u64x4{} + (bpow (factor, n) * b2x32).getr ();
4462 for (size_t i = 0 ; i < std::min (n, std::size (a)); i += flen) {
45- auto splt = [&](size_t i, auto mul) {
46- if (i >= std::size (a)) {
47- return std::pair{vftype (), vftype ()};
48- }
49- u64x4 au = {
50- i < std::size (a) ? a[i].getr () : 0 ,
51- i + 1 < std::size (a) ? a[i + 1 ].getr () : 0 ,
52- i + 2 < std::size (a) ? a[i + 2 ].getr () : 0 ,
53- i + 3 < std::size (a) ? a[i + 3 ].getr () : 0
54- };
55- au = montgomery_mul (au, mul, mod, imod);
56- au = au >= base::mod () ? au - base::mod () : au;
57- auto ai = to_double (i64x4 (au >= base::mod () / 2 ? au - base::mod () : au));
58- auto quo = round (ai / split ());
59- return std::pair{ai - quo * split (), quo};
60- };
61- auto [rai, qai] = splt (i, cur);
62- auto [rani, qani] = splt (n + i, montgomery_mul (cur, stepn, mod, imod));
63+ auto [rai, qai] = do_split (a, i, n, cur);
64+ auto [rani, qani] = do_split (a, n + i, n, montgomery_mul (cur, stepn, mod, imod));
6365 A.at (i) = vpoint (rai, rani);
6466 B.at (i) = vpoint (qai, qani);
6567 cur = montgomery_mul (cur, step4, mod, imod);
@@ -75,8 +77,23 @@ namespace cp_algo::math::fft {
7577 }
7678 }
7779 }
80+ [[gnu::target(" avx2" )]] static void do_dot_iter (size_t i, point rt, vpoint& Cv, vpoint& Dv, vpoint const & Av, vpoint const & Bv, vpoint& AC, vpoint& AD, vpoint& BC, vpoint& BD) {
81+ AC += Av * Cv; AD += Av * Dv;
82+ BC += Bv * Cv; BD += Bv * Dv;
83+ real (Cv) = rotate_right (real (Cv));
84+ imag (Cv) = rotate_right (imag (Cv));
85+ real (Dv) = rotate_right (real (Dv));
86+ imag (Dv) = rotate_right (imag (Dv));
87+ auto cx = real (Cv)[0 ], cy = imag (Cv)[0 ];
88+ auto dx = real (Dv)[0 ], dy = imag (Dv)[0 ];
89+ real (Cv)[0 ] = cx * real (rt) - cy * imag (rt);
90+ imag (Cv)[0 ] = cx * imag (rt) + cy * real (rt);
91+ real (Dv)[0 ] = dx * real (rt) - dy * imag (rt);
92+ imag (Dv)[0 ] = dx * imag (rt) + dy * real (rt);
93+ }
94+
7895 template <bool overwrite = true , bool partial = true >
79- void dot (auto const & C, auto const & D, auto &Aout, auto &Bout, auto &Cout) const {
96+ [[gnu::target( " avx2 " )]] void dot (auto const & C, auto const & D, auto &Aout, auto &Bout, auto &Cout) const {
8097 cvector::exec_on_evals<1 >(A.size () / flen, [&](size_t k, point rt) {
8198 k *= flen;
8299 vpoint AC, AD, BC, BD;
@@ -87,18 +104,7 @@ namespace cp_algo::math::fft {
87104 auto [Bx, By] = B.at (k);
88105 for (size_t i = 0 ; i < flen; i++) {
89106 vpoint Av = {vz + Ax[i], vz + Ay[i]}, Bv = {vz + Bx[i], vz + By[i]};
90- AC += Av * Cv; AD += Av * Dv;
91- BC += Bv * Cv; BD += Bv * Dv;
92- real (Cv) = rotate_right (real (Cv));
93- imag (Cv) = rotate_right (imag (Cv));
94- real (Dv) = rotate_right (real (Dv));
95- imag (Dv) = rotate_right (imag (Dv));
96- auto cx = real (Cv)[0 ], cy = imag (Cv)[0 ];
97- auto dx = real (Dv)[0 ], dy = imag (Dv)[0 ];
98- real (Cv)[0 ] = cx * real (rt) - cy * imag (rt);
99- imag (Cv)[0 ] = cx * imag (rt) + cy * real (rt);
100- real (Dv)[0 ] = dx * real (rt) - dy * imag (rt);
101- imag (Dv)[0 ] = dx * imag (rt) + dy * real (rt);
107+ do_dot_iter (i, rt, Cv, Dv, Av, Bv, AC, AD, BC, BD);
102108 }
103109 } else {
104110 AC = A.at (k) * Cv;
@@ -123,7 +129,18 @@ namespace cp_algo::math::fft {
123129 dot (C, D, A, B, C);
124130 }
125131
126- void recover_mod (auto &&C, auto &res, size_t k) {
132+ [[gnu::target(" avx2" )]] static void do_recover_iter (size_t idx, auto A, auto B, auto C, auto mul, uint64_t splitsplit, auto &res) {
133+ auto A0 = lround (A), A1 = lround (C), A2 = lround (B);
134+ auto Ai = A0 + A1 * split () + A2 * splitsplit + uint64_t (base::modmod ());
135+ auto Au = montgomery_reduce (u64x4 (Ai), mod, imod);
136+ Au = montgomery_mul (Au, mul, mod, imod);
137+ Au = Au >= base::mod () ? Au - base::mod () : Au;
138+ for (size_t j = 0 ; j < flen; j++) {
139+ res[idx + j].setr (typename base::UInt (Au[j]));
140+ }
141+ }
142+
143+ [[gnu::target(" avx2" )]] void recover_mod (auto &&C, auto &res, size_t k) {
127144 size_t check = (k + flen - 1 ) / flen * flen;
128145 assert (res.size () >= check);
129146 size_t n = A.size ();
@@ -142,26 +159,16 @@ namespace cp_algo::math::fft {
142159 auto [Ax, Ay] = A.at (i);
143160 auto [Bx, By] = B.at (i);
144161 auto [Cx, Cy] = C.at (i);
145- auto set_i = [&](size_t i, auto A, auto B, auto C, auto mul) {
146- auto A0 = lround (A), A1 = lround (C), A2 = lround (B);
147- auto Ai = A0 + A1 * split () + A2 * splitsplit + uint64_t (base::modmod ());
148- auto Au = montgomery_reduce (u64x4 (Ai), mod, imod);
149- Au = montgomery_mul (Au, mul, mod, imod);
150- Au = Au >= base::mod () ? Au - base::mod () : Au;
151- for (size_t j = 0 ; j < flen; j++) {
152- res[i + j].setr (typename base::UInt (Au[j]));
153- }
154- };
155- set_i (i, Ax, Bx, Cx, cur);
162+ do_recover_iter (i, Ax, Bx, Cx, cur, splitsplit, res);
156163 if (i + n < k) {
157- set_i (i + n, Ay, By, Cy, montgomery_mul (cur, stepn, mod, imod));
164+ do_recover_iter (i + n, Ay, By, Cy, montgomery_mul (cur, stepn, mod, imod), splitsplit, res );
158165 }
159166 cur = montgomery_mul (cur, step4, mod, imod);
160167 }
161168 checkpoint (" recover mod" );
162169 }
163170
164- void mul (auto &&C, auto const & D, auto &res, size_t k) {
171+ [[gnu::target( " avx2 " )]] void mul (auto &&C, auto const & D, auto &res, size_t k) {
165172 assert (A.size () == C.size ());
166173 size_t n = A.size ();
167174 if (!n) {
@@ -174,10 +181,10 @@ namespace cp_algo::math::fft {
174181 C.ifft ();
175182 recover_mod (C, res, k);
176183 }
177- void mul_inplace (auto &&B, auto & res, size_t k) {
184+ [[gnu::target( " avx2 " )]] void mul_inplace (auto &&B, auto & res, size_t k) {
178185 mul (B.A , B.B , res, k);
179186 }
180- void mul (auto const & B, auto & res, size_t k) {
187+ [[gnu::target( " avx2 " )]] void mul (auto const & B, auto & res, size_t k) {
181188 mul (cvector (B.A ), B.B , res, k);
182189 }
183190 std::vector<base, big_alloc<base>> operator *= (dft &B) {
@@ -240,7 +247,7 @@ namespace cp_algo::math::fft {
240247 }
241248
242249 // store mod x^n-k in first half, x^n+k in second half
243- void mod_split (auto &&x, size_t n, auto k) {
250+ [[gnu::target( " avx2 " )]] void mod_split (auto &&x, size_t n, auto k) {
244251 using base = std::decay_t <decltype (k)>;
245252 dft<base>::init ();
246253 assert (std::size (x) == 2 * n);
0 commit comments