Skip to content

Commit 0949e24

Browse files
committed
Put avx2 target on convolution code
1 parent 51aa5b2 commit 0949e24

File tree

5 files changed

+251
-157
lines changed

5 files changed

+251
-157
lines changed

cp-algo/math/convolution.hpp

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
#ifndef CP_ALGO_MATH_CONVOLUTION_HPP
2+
#define CP_ALGO_MATH_CONVOLUTION_HPP
3+
#include "fft.hpp"
4+
#include "cvector.hpp"
5+
#include <vector>
6+
#include <algorithm>
7+
#include <bit>
8+
#include <type_traits>
9+
#include <ranges>
10+
11+
namespace cp_algo::math {
12+
13+
// Convolution limited to the first `need` coefficients.
14+
// Writes the result into `a`; performs in-place when possible (modint path).
15+
template<class VecA, class VecB>
16+
void convolution_prefix(VecA& a, VecB const& b, size_t need) {
17+
using T = typename std::decay_t<VecA>::value_type;
18+
size_t na = std::min(need, std::size(a));
19+
size_t nb = std::min(need, std::size(b));
20+
a.resize(na);
21+
auto bv = b | std::views::take(nb);
22+
23+
if(na == 0 || nb == 0) {
24+
a.clear();
25+
return;
26+
}
27+
28+
if constexpr (modint_type<T>) {
29+
// Use NTT-based truncated multiplication. Works in-place on `a`.
30+
fft::mul_truncate(a, bv, need);
31+
} else if constexpr (std::is_same_v<T, fft::point>) {
32+
size_t conv_len = na + nb - 1;
33+
size_t n = std::bit_ceil(conv_len);
34+
n = std::max(n, (size_t)fft::flen);
35+
fft::cvector A(n), B(n);
36+
for(size_t i = 0; i < na; i++) {
37+
A.set(i, a[i]);
38+
}
39+
for(size_t i = 0; i < nb; i++) {
40+
B.set(i, bv[i]);
41+
}
42+
A.fft();
43+
B.fft();
44+
A.dot(B);
45+
A.ifft();
46+
a.assign(need, T(0));
47+
for(size_t i = 0; i < std::min(need, conv_len); i++) {
48+
a[i] = A.template get<fft::point>(i);
49+
}
50+
} else if constexpr (std::is_same_v<T, fft::ftype>) {
51+
// Imaginary-cyclic convolution modulo x^n-i to compute acyclic convolution
52+
// Represents polynomials as point(a[i], a[i+n]) to work in x^n-i basis
53+
size_t conv_len = na + nb - 1;
54+
size_t n = std::bit_ceil(conv_len) / 2;
55+
n = std::max(n, (size_t)fft::flen);
56+
57+
fft::cvector A(n), B(n);
58+
// Pack as modulo x^n-i: A[i] = point(a[i], a[i+n])
59+
for(size_t i = 0; i < std::min(n, na); i++) {
60+
fft::ftype re = a[i], im = 0;
61+
if(i + n < na) im = a[i + n];
62+
A.set(i, fft::point(re, im));
63+
}
64+
for(size_t i = 0; i < std::min(n, nb); i++) {
65+
fft::ftype re = bv[i], im = 0;
66+
if(i + n < nb) im = bv[i + n];
67+
B.set(i, fft::point(re, im));
68+
}
69+
A.fft();
70+
B.fft();
71+
A.dot(B);
72+
A.ifft();
73+
a.assign(2 * n, T(0));
74+
for(size_t i = 0; i < n; i++) {
75+
auto v = A.template get<fft::point>(i);
76+
a[i] = v.real();
77+
a[i + n] = v.imag();
78+
}
79+
a.resize(need);
80+
} else {
81+
// Generic fallback: use simple O(n^2) convolution from fft utilities.
82+
fft::mul_slow(a, bv, need);
83+
}
84+
}
85+
86+
} // namespace cp_algo::math
87+
88+
#endif // CP_ALGO_MATH_CONVOLUTION_HPP

cp-algo/math/cvector.hpp

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ namespace cp_algo::math::fft {
1515
using point = complex<ftype>;
1616
using vpoint = complex<vftype>;
1717
static constexpr vftype vz = {};
18-
vpoint vi(vpoint const& r) {
18+
// Multiplies r by the imaginary unit: (x + iy) * i = -y + ix.
// Compiled with the "avx2" target attribute.
[[gnu::target("avx2")]] vpoint vi(vpoint const& r) {
1919
return {-imag(r), real(r)};
2020
}
2121

@@ -39,7 +39,7 @@ namespace cp_algo::math::fft {
3939
}
4040
}
4141
template<class pt = point>
42-
pt get(size_t k) const {
42+
[[gnu::target("avx2")]] pt get(size_t k) const {
4343
if constexpr(std::is_same_v<pt, point>) {
4444
return {real(r[k / flen])[k % flen], imag(r[k / flen])[k % flen]};
4545
} else {
@@ -79,31 +79,36 @@ namespace cp_algo::math::fft {
7979
return roots[std::bit_width(n)];
8080
}
8181
template<int step>
82-
static void exec_on_eval(size_t n, size_t k, auto &&callback) {
82+
// Invokes callback(k, w) once, for the single index k, where
// w = root(4 * step * n) * eval_point(step * k).
// Compiled with the "avx2" target attribute.
[[gnu::target("avx2")]] static void exec_on_eval(size_t n, size_t k, auto &&callback) {
8383
callback(k, root(4 * step * n) * eval_point(step * k));
8484
}
8585
template<int step>
86-
static void exec_on_evals(size_t n, auto &&callback) {
86+
// Invokes callback(i, w_i) for every i in [0, n), where
// w_i = root(4 * step * n) * eval_point(step * i).
// Compiled with the "avx2" target attribute.
[[gnu::target("avx2")]] static void exec_on_evals(size_t n, auto &&callback) {
8787
// The root factor does not depend on i, so it is computed once up front.
point factor = root(4 * step * n);
8888
for(size_t i = 0; i < n; i++) {
8989
callback(i, factor * eval_point(step * i));
9090
}
9191
}
9292

93-
void dot(cvector const& t) {
93+
// One inner step of cvector::dot, split out so it carries the "avx2" target attribute.
//   rt  - complex factor applied to the lane that ends up at index 0
//   Bv  - running operand; updated in place for the next step
//   Av  - operand multiplied with the current Bv
//   res - accumulator, receives Av * Bv
// The parameter i is not referenced in the body.
[[gnu::target("avx2")]] static void do_dot_iter(size_t i, point rt, vpoint& Bv, vpoint const& Av, vpoint& res) {
94+
// Accumulate first: the product must use Bv before it is advanced below.
res += Av * Bv;
95+
real(Bv) = rotate_right(real(Bv));
96+
imag(Bv) = rotate_right(imag(Bv));
97+
// After the rotation, lane 0 is replaced by its complex product with rt:
// (x + iy) * rt.
auto x = real(Bv)[0], y = imag(Bv)[0];
98+
real(Bv)[0] = x * real(rt) - y * imag(rt);
99+
imag(Bv)[0] = x * imag(rt) + y * real(rt);
100+
}
101+
102+
[[gnu::target("avx2")]] void dot(cvector const& t) {
94103
size_t n = this->size();
95104
exec_on_evals<1>(n / flen, [&](size_t k, point rt) {
96105
k *= flen;
97106
auto [Ax, Ay] = at(k);
98107
auto Bv = t.at(k);
99108
vpoint res = vz;
100109
for (size_t i = 0; i < flen; i++) {
101-
res += vpoint(vz + Ax[i], vz + Ay[i]) * Bv;
102-
real(Bv) = rotate_right(real(Bv));
103-
imag(Bv) = rotate_right(imag(Bv));
104-
auto x = real(Bv)[0], y = imag(Bv)[0];
105-
real(Bv)[0] = x * real(rt) - y * imag(rt);
106-
imag(Bv)[0] = x * imag(rt) + y * real(rt);
110+
vpoint Av = vpoint(vz + Ax[i], vz + Ay[i]);
111+
do_dot_iter(i, rt, Bv, Av, res);
107112
}
108113
set(k, res);
109114
});

cp-algo/math/fft.hpp

Lines changed: 56 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,26 @@ namespace cp_algo::math::fft {
2929
}
3030
}
3131

32+
// Reads up to four consecutive elements of `a` starting at `idx`, scales them
// lane-wise by `mul`, and splits each value into a (remainder, quotient) pair
// with respect to split().
// Returns {ai - quo * split(), quo} where quo = round(ai / split()); if idx is
// at or past the end of `a`, returns a pair of value-initialized vectors.
// The parameter n is not referenced in the body.
// Compiled with the "avx2" target attribute.
[[gnu::target("avx2")]] static std::pair<vftype, vftype>
33+
do_split(auto const& a, size_t idx, size_t n, u64x4 mul) {
34+
if(idx >= std::size(a)) {
35+
return std::pair{vftype(), vftype()};
36+
}
37+
// Lanes past the end of `a` are filled with 0. (The check on the first lane
// always holds here because of the early return above.)
u64x4 au = {
38+
idx < std::size(a) ? a[idx].getr() : 0,
39+
idx + 1 < std::size(a) ? a[idx + 1].getr() : 0,
40+
idx + 2 < std::size(a) ? a[idx + 2].getr() : 0,
41+
idx + 3 < std::size(a) ? a[idx + 3].getr() : 0
42+
};
43+
au = montgomery_mul(au, mul, mod, imod);
44+
// Conditional subtraction of the modulus.
// NOTE(review): presumably brings the product into [0, mod) - confirm against montgomery_mul.
au = au >= base::mod() ? au - base::mod() : au;
45+
// Values >= mod / 2 have the modulus subtracted once more before the signed
// conversion. NOTE(review): looks like this centres the residue around zero - confirm.
auto ai = to_double(i64x4(au >= base::mod() / 2 ? au - base::mod() : au));
46+
auto quo = round(ai / split());
47+
return std::pair{ai - quo * split(), quo};
48+
}
49+
3250
// Constructs both member vectors A and B with size n, then calls init().
dft(size_t n): A(n), B(n) {init();}
33-
dft(auto const& a, size_t n, bool partial = true): A(n), B(n) {
51+
[[gnu::target("avx2")]] dft(auto const& a, size_t n, bool partial = true): A(n), B(n) {
3452
init();
3553
base b2x32 = bpow(base(2), 32);
3654
u64x4 cur = {
@@ -42,24 +60,8 @@ namespace cp_algo::math::fft {
4260
u64x4 step4 = u64x4{} + (bpow(factor, 4) * b2x32).getr();
4361
u64x4 stepn = u64x4{} + (bpow(factor, n) * b2x32).getr();
4462
for(size_t i = 0; i < std::min(n, std::size(a)); i += flen) {
45-
auto splt = [&](size_t i, auto mul) {
46-
if(i >= std::size(a)) {
47-
return std::pair{vftype(), vftype()};
48-
}
49-
u64x4 au = {
50-
i < std::size(a) ? a[i].getr() : 0,
51-
i + 1 < std::size(a) ? a[i + 1].getr() : 0,
52-
i + 2 < std::size(a) ? a[i + 2].getr() : 0,
53-
i + 3 < std::size(a) ? a[i + 3].getr() : 0
54-
};
55-
au = montgomery_mul(au, mul, mod, imod);
56-
au = au >= base::mod() ? au - base::mod() : au;
57-
auto ai = to_double(i64x4(au >= base::mod() / 2 ? au - base::mod() : au));
58-
auto quo = round(ai / split());
59-
return std::pair{ai - quo * split(), quo};
60-
};
61-
auto [rai, qai] = splt(i, cur);
62-
auto [rani, qani] = splt(n + i, montgomery_mul(cur, stepn, mod, imod));
63+
auto [rai, qai] = do_split(a, i, n, cur);
64+
auto [rani, qani] = do_split(a, n + i, n, montgomery_mul(cur, stepn, mod, imod));
6365
A.at(i) = vpoint(rai, rani);
6466
B.at(i) = vpoint(qai, qani);
6567
cur = montgomery_mul(cur, step4, mod, imod);
@@ -75,8 +77,23 @@ namespace cp_algo::math::fft {
7577
}
7678
}
7779
}
80+
// One inner step of dft::dot, split out so it carries the "avx2" target attribute.
// Accumulates the four pairwise products of (Av, Bv) with (Cv, Dv) into
// AC, AD, BC and BD, then advances Cv and Dv in place for the next step.
//   rt - complex factor applied to the lane that ends up at index 0
// The parameter i is not referenced in the body.
[[gnu::target("avx2")]] static void do_dot_iter(size_t i, point rt, vpoint& Cv, vpoint& Dv, vpoint const& Av, vpoint const& Bv, vpoint& AC, vpoint& AD, vpoint& BC, vpoint& BD) {
81+
// Accumulate first: the products must use Cv/Dv before they are advanced below.
AC += Av * Cv; AD += Av * Dv;
82+
BC += Bv * Cv; BD += Bv * Dv;
83+
real(Cv) = rotate_right(real(Cv));
84+
imag(Cv) = rotate_right(imag(Cv));
85+
real(Dv) = rotate_right(real(Dv));
86+
imag(Dv) = rotate_right(imag(Dv));
87+
// After the rotation, lane 0 of Cv and of Dv is replaced by its complex
// product with rt.
auto cx = real(Cv)[0], cy = imag(Cv)[0];
88+
auto dx = real(Dv)[0], dy = imag(Dv)[0];
89+
real(Cv)[0] = cx * real(rt) - cy * imag(rt);
90+
imag(Cv)[0] = cx * imag(rt) + cy * real(rt);
91+
real(Dv)[0] = dx * real(rt) - dy * imag(rt);
92+
imag(Dv)[0] = dx * imag(rt) + dy * real(rt);
93+
}
94+
7895
template<bool overwrite = true, bool partial = true>
79-
void dot(auto const& C, auto const& D, auto &Aout, auto &Bout, auto &Cout) const {
96+
[[gnu::target("avx2")]] void dot(auto const& C, auto const& D, auto &Aout, auto &Bout, auto &Cout) const {
8097
cvector::exec_on_evals<1>(A.size() / flen, [&](size_t k, point rt) {
8198
k *= flen;
8299
vpoint AC, AD, BC, BD;
@@ -87,18 +104,7 @@ namespace cp_algo::math::fft {
87104
auto [Bx, By] = B.at(k);
88105
for (size_t i = 0; i < flen; i++) {
89106
vpoint Av = {vz + Ax[i], vz + Ay[i]}, Bv = {vz + Bx[i], vz + By[i]};
90-
AC += Av * Cv; AD += Av * Dv;
91-
BC += Bv * Cv; BD += Bv * Dv;
92-
real(Cv) = rotate_right(real(Cv));
93-
imag(Cv) = rotate_right(imag(Cv));
94-
real(Dv) = rotate_right(real(Dv));
95-
imag(Dv) = rotate_right(imag(Dv));
96-
auto cx = real(Cv)[0], cy = imag(Cv)[0];
97-
auto dx = real(Dv)[0], dy = imag(Dv)[0];
98-
real(Cv)[0] = cx * real(rt) - cy * imag(rt);
99-
imag(Cv)[0] = cx * imag(rt) + cy * real(rt);
100-
real(Dv)[0] = dx * real(rt) - dy * imag(rt);
101-
imag(Dv)[0] = dx * imag(rt) + dy * real(rt);
107+
do_dot_iter(i, rt, Cv, Dv, Av, Bv, AC, AD, BC, BD);
102108
}
103109
} else {
104110
AC = A.at(k) * Cv;
@@ -123,7 +129,18 @@ namespace cp_algo::math::fft {
123129
dot(C, D, A, B, C);
124130
}
125131

126-
void recover_mod(auto &&C, auto &res, size_t k) {
132+
// Recombines one group of flen lanes into modular values and stores them in
// res[idx .. idx + flen).
//   A, B, C    - lane vectors that are rounded to integers; note that C feeds
//                the split() term and B feeds the splitsplit term
//   mul        - lane-wise factor applied through montgomery_mul
//   splitsplit - multiplier of the rounded B term
//                (NOTE(review): presumably split() squared - confirm at the call site)
// No bounds check is done here; res must already hold idx + flen elements.
// Compiled with the "avx2" target attribute.
[[gnu::target("avx2")]] static void do_recover_iter(size_t idx, auto A, auto B, auto C, auto mul, uint64_t splitsplit, auto &res) {
133+
auto A0 = lround(A), A1 = lround(C), A2 = lround(B);
134+
// base::modmod() is added to the sum before the reduction.
// NOTE(review): presumably to keep the value non-negative - confirm.
auto Ai = A0 + A1 * split() + A2 * splitsplit + uint64_t(base::modmod());
135+
auto Au = montgomery_reduce(u64x4(Ai), mod, imod);
136+
Au = montgomery_mul(Au, mul, mod, imod);
137+
// Conditional subtraction of the modulus before storing.
Au = Au >= base::mod() ? Au - base::mod() : Au;
138+
for(size_t j = 0; j < flen; j++) {
139+
res[idx + j].setr(typename base::UInt(Au[j]));
140+
}
141+
}
142+
143+
[[gnu::target("avx2")]] void recover_mod(auto &&C, auto &res, size_t k) {
127144
size_t check = (k + flen - 1) / flen * flen;
128145
assert(res.size() >= check);
129146
size_t n = A.size();
@@ -142,26 +159,16 @@ namespace cp_algo::math::fft {
142159
auto [Ax, Ay] = A.at(i);
143160
auto [Bx, By] = B.at(i);
144161
auto [Cx, Cy] = C.at(i);
145-
auto set_i = [&](size_t i, auto A, auto B, auto C, auto mul) {
146-
auto A0 = lround(A), A1 = lround(C), A2 = lround(B);
147-
auto Ai = A0 + A1 * split() + A2 * splitsplit + uint64_t(base::modmod());
148-
auto Au = montgomery_reduce(u64x4(Ai), mod, imod);
149-
Au = montgomery_mul(Au, mul, mod, imod);
150-
Au = Au >= base::mod() ? Au - base::mod() : Au;
151-
for(size_t j = 0; j < flen; j++) {
152-
res[i + j].setr(typename base::UInt(Au[j]));
153-
}
154-
};
155-
set_i(i, Ax, Bx, Cx, cur);
162+
do_recover_iter(i, Ax, Bx, Cx, cur, splitsplit, res);
156163
if(i + n < k) {
157-
set_i(i + n, Ay, By, Cy, montgomery_mul(cur, stepn, mod, imod));
164+
do_recover_iter(i + n, Ay, By, Cy, montgomery_mul(cur, stepn, mod, imod), splitsplit, res);
158165
}
159166
cur = montgomery_mul(cur, step4, mod, imod);
160167
}
161168
checkpoint("recover mod");
162169
}
163170

164-
void mul(auto &&C, auto const& D, auto &res, size_t k) {
171+
[[gnu::target("avx2")]] void mul(auto &&C, auto const& D, auto &res, size_t k) {
165172
assert(A.size() == C.size());
166173
size_t n = A.size();
167174
if(!n) {
@@ -174,10 +181,10 @@ namespace cp_algo::math::fft {
174181
C.ifft();
175182
recover_mod(C, res, k);
176183
}
177-
void mul_inplace(auto &&B, auto& res, size_t k) {
184+
// Forwards to mul(C, D, res, k) with C = B.A and D = B.B. B.A is handed over
// directly (no copy), so the callee may overwrite it.
// Compiled with the "avx2" target attribute.
[[gnu::target("avx2")]] void mul_inplace(auto &&B, auto& res, size_t k) {
178185
mul(B.A, B.B, res, k);
179186
}
180-
void mul(auto const& B, auto& res, size_t k) {
187+
// Overload for a const B: forwards to mul(C, D, res, k) with a temporary copy
// of B.A as C and B.B as D, so B itself is left unmodified.
// Compiled with the "avx2" target attribute.
[[gnu::target("avx2")]] void mul(auto const& B, auto& res, size_t k) {
181188
mul(cvector(B.A), B.B, res, k);
182189
}
183190
std::vector<base, big_alloc<base>> operator *= (dft &B) {
@@ -240,7 +247,7 @@ namespace cp_algo::math::fft {
240247
}
241248

242249
// store mod x^n-k in first half, x^n+k in second half
243-
void mod_split(auto &&x, size_t n, auto k) {
250+
[[gnu::target("avx2")]] void mod_split(auto &&x, size_t n, auto k) {
244251
using base = std::decay_t<decltype(k)>;
245252
dft<base>::init();
246253
assert(std::size(x) == 2 * n);

0 commit comments

Comments
 (0)