@@ -29,7 +29,7 @@ namespace cp_algo::math::fft {
2929 }
3030 }
3131
32- [[gnu::target( " avx2 " )]] static std::pair<vftype, vftype>
32+ simd_target static std::pair<vftype, vftype>
3333 do_split (auto const & a, size_t idx, u64x4 mul) {
3434 if (idx >= std::size (a)) {
3535 return std::pair{vftype (), vftype ()};
@@ -48,7 +48,7 @@ namespace cp_algo::math::fft {
4848 }
4949
// Size-only constructor: initializes A and B from n, then calls init().
5050 dft (size_t n): A(n), B(n) {init ();}
51- [[gnu::target( " avx2 " )]] dft(auto const & a, size_t n, bool partial = true ): A(n), B(n) {
51+ simd_target dft (auto const & a, size_t n, bool partial = true ): A(n), B(n) {
5252 init ();
5353 base b2x32 = bpow (base (2 ), 32 );
5454 u64x4 cur = {
@@ -77,7 +77,7 @@ namespace cp_algo::math::fft {
7777 }
7878 }
7979 }
80- [[gnu::target( " avx2 " )]] static void do_dot_iter (point rt, vpoint& Cv, vpoint& Dv, vpoint const & Av, vpoint const & Bv, vpoint& AC, vpoint& AD, vpoint& BC, vpoint& BD) {
80+ simd_target static void do_dot_iter (point rt, vpoint& Cv, vpoint& Dv, vpoint const & Av, vpoint const & Bv, vpoint& AC, vpoint& AD, vpoint& BC, vpoint& BD) {
8181 AC += Av * Cv; AD += Av * Dv;
8282 BC += Bv * Cv; BD += Bv * Dv;
8383 real (Cv) = rotate_right (real (Cv));
@@ -93,7 +93,7 @@ namespace cp_algo::math::fft {
9393 }
9494
9595 template <bool overwrite = true , bool partial = true >
96- [[gnu::target( " avx2 " )]] void dot (auto const & C, auto const & D, auto &Aout, auto &Bout, auto &Cout) const {
96+ simd_target void dot (auto const & C, auto const & D, auto &Aout, auto &Bout, auto &Cout) const {
9797 cvector::exec_on_evals<1 >(A.size () / flen, [&](size_t k, point rt) {
9898 k *= flen;
9999 vpoint AC, AD, BC, BD;
@@ -129,7 +129,7 @@ namespace cp_algo::math::fft {
129129 dot (C, D, A, B, C);
130130 }
131131
132- [[gnu::target( " avx2 " )]] static void do_recover_iter (size_t idx, auto A, auto B, auto C, auto mul, uint64_t splitsplit, auto &res) {
132+ simd_target static void do_recover_iter (size_t idx, auto A, auto B, auto C, auto mul, uint64_t splitsplit, auto &res) {
133133 auto A0 = lround (A), A1 = lround (C), A2 = lround (B);
134134 auto Ai = A0 + A1 * split () + A2 * splitsplit + uint64_t (base::modmod ());
135135 auto Au = montgomery_reduce (u64x4 (Ai), mod, imod);
@@ -140,7 +140,7 @@ namespace cp_algo::math::fft {
140140 }
141141 }
142142
143- [[gnu::target( " avx2 " )]] void recover_mod (auto &&C, auto &res, size_t k) {
143+ simd_target void recover_mod (auto &&C, auto &res, size_t k) {
144144 size_t check = (k + flen - 1 ) / flen * flen;
145145 assert (res.size () >= check);
146146 size_t n = A.size ();
@@ -168,7 +168,7 @@ namespace cp_algo::math::fft {
168168 checkpoint (" recover mod" );
169169 }
170170
171- [[gnu::target( " avx2 " )]] void mul (auto &&C, auto const & D, auto &res, size_t k) {
171+ simd_target void mul (auto &&C, auto const & D, auto &res, size_t k) {
172172 assert (A.size () == C.size ());
173173 size_t n = A.size ();
174174 if (!n) {
@@ -181,10 +181,10 @@ namespace cp_algo::math::fft {
181181 C.ifft ();
182182 recover_mod (C, res, k);
183183 }
// mul_inplace: forwards B.A and B.B directly, together with res and k, to the
// four-argument mul (no cvector(...) wrapping of B.A, unlike the
// mul(auto const & B, ...) overload that follows).
// NOTE(review): the four-argument mul calls C.ifft() on its first argument, so
// B.A is presumably overwritten by this call -- confirm against the full file.
184- [[gnu::target( " avx2 " )]] void mul_inplace (auto &&B, auto & res, size_t k) {
184+ simd_target void mul_inplace (auto &&B, auto & res, size_t k) {
185185 mul (B.A , B.B , res, k);
186186 }
// mul overload taking B by const reference: passes cvector (B.A) and B.B,
// together with res and k, to the four-argument mul.
// NOTE(review): cvector (B.A) presumably builds a temporary copy so that B.A
// itself is left untouched -- confirm against cvector's constructors.
187- [[gnu::target( " avx2 " )]] void mul (auto const & B, auto & res, size_t k) {
187+ simd_target void mul (auto const & B, auto & res, size_t k) {
188188 mul (cvector (B.A ), B.B , res, k);
189189 }
190190 big_vector<base> operator *= (dft &B) {
@@ -247,7 +247,7 @@ namespace cp_algo::math::fft {
247247 }
248248
249249 // store mod x^n-k in first half, x^n+k in second half
250- [[gnu::target( " avx2 " )]] void mod_split (auto &&x, size_t n, auto k) {
250+ simd_target void mod_split (auto &&x, size_t n, auto k) {
251251 using base = std::decay_t <decltype (k)>;
252252 dft<base>::init ();
253253 assert (std::size (x) == 2 * n);
0 commit comments