Skip to content

Commit 671b2cb

Browse files
committed
Don't use [[gnu::target("avx2")]]
1 parent e8d490b commit 671b2cb

File tree

11 files changed

+100
-79
lines changed

11 files changed

+100
-79
lines changed

.verify-helper/config.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
[[languages.cpp.environments]]
22
CXX = "g++"
3-
CXXFLAGS = ["-std=c++23", "-Wall", "-Wextra", "-Wconversion", "-Werror", "-pedantic", "-O2"]
3+
CXXFLAGS = ["-std=c++23", "-Wall", "-Wextra", "-Wconversion", "-Werror", "-Wno-psabi", "-pedantic", "-O2"]

cp-algo/math/cvector.hpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ namespace cp_algo::math::fft {
1515
using point = complex<ftype>;
1616
using vpoint = complex<vftype>;
1717
static constexpr vftype vz = {};
18-
[[gnu::target("avx2")]] vpoint vi(vpoint const& r) {
18+
simd_target vpoint vi(vpoint const& r) {
1919
return {-imag(r), real(r)};
2020
}
2121

@@ -30,7 +30,7 @@ namespace cp_algo::math::fft {
3030
vpoint& at(size_t k) {return r[k / flen];}
3131
vpoint at(size_t k) const {return r[k / flen];}
3232
template<class pt = point>
33-
void set(size_t k, pt const& t) {
33+
simd_inline void set(size_t k, pt const& t) {
3434
if constexpr(std::is_same_v<pt, point>) {
3535
real(r[k / flen])[k % flen] = real(t);
3636
imag(r[k / flen])[k % flen] = imag(t);
@@ -39,7 +39,7 @@ namespace cp_algo::math::fft {
3939
}
4040
}
4141
template<class pt = point>
42-
[[gnu::target("avx2")]] pt get(size_t k) const {
42+
simd_inline pt get(size_t k) const {
4343
if constexpr(std::is_same_v<pt, point>) {
4444
return {real(r[k / flen])[k % flen], imag(r[k / flen])[k % flen]};
4545
} else {
@@ -79,18 +79,18 @@ namespace cp_algo::math::fft {
7979
return roots[std::bit_width(n)];
8080
}
8181
template<int step>
82-
[[gnu::target("avx2")]] static void exec_on_eval(size_t n, size_t k, auto &&callback) {
82+
simd_target static void exec_on_eval(size_t n, size_t k, auto &&callback) {
8383
callback(k, root(4 * step * n) * eval_point(step * k));
8484
}
8585
template<int step>
86-
[[gnu::target("avx2")]] static void exec_on_evals(size_t n, auto &&callback) {
86+
simd_target static void exec_on_evals(size_t n, auto &&callback) {
8787
point factor = root(4 * step * n);
8888
for(size_t i = 0; i < n; i++) {
8989
callback(i, factor * eval_point(step * i));
9090
}
9191
}
9292

93-
[[gnu::target("avx2")]] static void do_dot_iter(point rt, vpoint& Bv, vpoint const& Av, vpoint& res) {
93+
simd_target static void do_dot_iter(point rt, vpoint& Bv, vpoint const& Av, vpoint& res) {
9494
res += Av * Bv;
9595
real(Bv) = rotate_right(real(Bv));
9696
imag(Bv) = rotate_right(imag(Bv));
@@ -99,7 +99,7 @@ namespace cp_algo::math::fft {
9999
imag(Bv)[0] = x * imag(rt) + y * real(rt);
100100
}
101101

102-
[[gnu::target("avx2")]] void dot(cvector const& t) {
102+
simd_target void dot(cvector const& t) {
103103
size_t n = this->size();
104104
exec_on_evals<1>(n / flen, [&](size_t k, point rt) {
105105
k *= flen;
@@ -115,7 +115,7 @@ namespace cp_algo::math::fft {
115115
checkpoint("dot");
116116
}
117117
template<bool partial = true>
118-
[[gnu::target("avx2")]] void ifft() {
118+
simd_target void ifft() {
119119
size_t n = size();
120120
if constexpr (!partial) {
121121
point pi(0, 1);
@@ -177,7 +177,7 @@ namespace cp_algo::math::fft {
177177
}
178178
}
179179
template<bool partial = true>
180-
[[gnu::target("avx2")]] void fft() {
180+
simd_target void fft() {
181181
size_t n = size();
182182
bool parity = std::countr_zero(n) % 2;
183183
for(size_t leaf = 0; leaf < n; leaf += 4 * flen) {

cp-algo/math/factorials.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
namespace cp_algo::math {
1111
template<bool use_bump_alloc = false, int maxn = -1>
12-
[[gnu::target("avx2")]] auto facts(auto const& args) {
12+
simd_target auto facts(auto const& args) {
1313
static_assert(!use_bump_alloc || maxn > 0, "maxn must be set if use_bump_alloc is true");
1414
constexpr int max_mod = 1'000'000'000;
1515
constexpr int accum = 4;

cp-algo/math/fft.hpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ namespace cp_algo::math::fft {
2929
}
3030
}
3131

32-
[[gnu::target("avx2")]] static std::pair<vftype, vftype>
32+
simd_target static std::pair<vftype, vftype>
3333
do_split(auto const& a, size_t idx, u64x4 mul) {
3434
if(idx >= std::size(a)) {
3535
return std::pair{vftype(), vftype()};
@@ -48,7 +48,7 @@ namespace cp_algo::math::fft {
4848
}
4949

5050
dft(size_t n): A(n), B(n) {init();}
51-
[[gnu::target("avx2")]] dft(auto const& a, size_t n, bool partial = true): A(n), B(n) {
51+
simd_target dft(auto const& a, size_t n, bool partial = true): A(n), B(n) {
5252
init();
5353
base b2x32 = bpow(base(2), 32);
5454
u64x4 cur = {
@@ -77,7 +77,7 @@ namespace cp_algo::math::fft {
7777
}
7878
}
7979
}
80-
[[gnu::target("avx2")]] static void do_dot_iter(point rt, vpoint& Cv, vpoint& Dv, vpoint const& Av, vpoint const& Bv, vpoint& AC, vpoint& AD, vpoint& BC, vpoint& BD) {
80+
simd_target static void do_dot_iter(point rt, vpoint& Cv, vpoint& Dv, vpoint const& Av, vpoint const& Bv, vpoint& AC, vpoint& AD, vpoint& BC, vpoint& BD) {
8181
AC += Av * Cv; AD += Av * Dv;
8282
BC += Bv * Cv; BD += Bv * Dv;
8383
real(Cv) = rotate_right(real(Cv));
@@ -93,7 +93,7 @@ namespace cp_algo::math::fft {
9393
}
9494

9595
template<bool overwrite = true, bool partial = true>
96-
[[gnu::target("avx2")]] void dot(auto const& C, auto const& D, auto &Aout, auto &Bout, auto &Cout) const {
96+
simd_target void dot(auto const& C, auto const& D, auto &Aout, auto &Bout, auto &Cout) const {
9797
cvector::exec_on_evals<1>(A.size() / flen, [&](size_t k, point rt) {
9898
k *= flen;
9999
vpoint AC, AD, BC, BD;
@@ -129,7 +129,7 @@ namespace cp_algo::math::fft {
129129
dot(C, D, A, B, C);
130130
}
131131

132-
[[gnu::target("avx2")]] static void do_recover_iter(size_t idx, auto A, auto B, auto C, auto mul, uint64_t splitsplit, auto &res) {
132+
simd_target static void do_recover_iter(size_t idx, auto A, auto B, auto C, auto mul, uint64_t splitsplit, auto &res) {
133133
auto A0 = lround(A), A1 = lround(C), A2 = lround(B);
134134
auto Ai = A0 + A1 * split() + A2 * splitsplit + uint64_t(base::modmod());
135135
auto Au = montgomery_reduce(u64x4(Ai), mod, imod);
@@ -140,7 +140,7 @@ namespace cp_algo::math::fft {
140140
}
141141
}
142142

143-
[[gnu::target("avx2")]] void recover_mod(auto &&C, auto &res, size_t k) {
143+
simd_target void recover_mod(auto &&C, auto &res, size_t k) {
144144
size_t check = (k + flen - 1) / flen * flen;
145145
assert(res.size() >= check);
146146
size_t n = A.size();
@@ -168,7 +168,7 @@ namespace cp_algo::math::fft {
168168
checkpoint("recover mod");
169169
}
170170

171-
[[gnu::target("avx2")]] void mul(auto &&C, auto const& D, auto &res, size_t k) {
171+
simd_target void mul(auto &&C, auto const& D, auto &res, size_t k) {
172172
assert(A.size() == C.size());
173173
size_t n = A.size();
174174
if(!n) {
@@ -181,10 +181,10 @@ namespace cp_algo::math::fft {
181181
C.ifft();
182182
recover_mod(C, res, k);
183183
}
184-
[[gnu::target("avx2")]] void mul_inplace(auto &&B, auto& res, size_t k) {
184+
simd_target void mul_inplace(auto &&B, auto& res, size_t k) {
185185
mul(B.A, B.B, res, k);
186186
}
187-
[[gnu::target("avx2")]] void mul(auto const& B, auto& res, size_t k) {
187+
simd_target void mul(auto const& B, auto& res, size_t k) {
188188
mul(cvector(B.A), B.B, res, k);
189189
}
190190
big_vector<base> operator *= (dft &B) {
@@ -247,7 +247,7 @@ namespace cp_algo::math::fft {
247247
}
248248

249249
// store mod x^n-k in first half, x^n+k in second half
250-
[[gnu::target("avx2")]] void mod_split(auto &&x, size_t n, auto k) {
250+
simd_target void mod_split(auto &&x, size_t n, auto k) {
251251
using base = std::decay_t<decltype(k)>;
252252
dft<base>::init();
253253
assert(std::size(x) == 2 * n);

cp-algo/math/fft64.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ namespace cp_algo::math::fft {
4646
}
4747
}
4848

49-
[[gnu::target("avx2")]] static void do_dot_iter(point rt, std::array<vpoint, 4>& B, std::array<vpoint, 4> const& A, std::array<vpoint, 4>& C) {
49+
simd_target static void do_dot_iter(point rt, std::array<vpoint, 4>& B, std::array<vpoint, 4> const& A, std::array<vpoint, 4>& C) {
5050
for(size_t k = 0; k < 4; k++) {
5151
for(size_t i = 0; i <= k; i++) {
5252
C[k] += A[i] * B[k - i];

cp-algo/util/bit.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,14 @@ namespace cp_algo {
3838
});
3939
}
4040

41-
[[gnu::target("avx2")]] inline uint32_t read_bits(char const* p) {
41+
simd_inline uint32_t read_bits(char const* p) {
4242
return _mm256_movemask_epi8(__m256i(vector_cast<u8x32 const>(p[0]) + (127 - '0')));
4343
}
44-
[[gnu::target("avx2")]] inline uint64_t read_bits64(char const* p) {
44+
simd_inline uint64_t read_bits64(char const* p) {
4545
return read_bits(p) | (uint64_t(read_bits(p + 32)) << 32);
4646
}
4747

48-
[[gnu::target("avx2")]] inline void write_bits(char *p, uint32_t bits) {
48+
simd_inline void write_bits(char *p, uint32_t bits) {
4949
static constexpr u8x32 shuffler = {
5050
0, 0, 0, 0, 0, 0, 0, 0,
5151
1, 1, 1, 1, 1, 1, 1, 1,
@@ -63,7 +63,7 @@ namespace cp_algo {
6363
p[z] = shuffled[z] & mask[z] ? '1' : '0';
6464
}
6565
}
66-
[[gnu::target("avx2")]] inline void write_bits64(char *p, uint64_t bits) {
66+
simd_inline void write_bits64(char *p, uint64_t bits) {
6767
write_bits(p, uint32_t(bits));
6868
write_bits(p + 32, uint32_t(bits >> 32));
6969
}

cp-algo/util/checkpoint.hpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@
55
#include <string>
66
#include <map>
77
namespace cp_algo {
8+
#ifdef CP_ALGO_CHECKPOINT
9+
std::map<std::string, double> checkpoints;
10+
double last;
11+
#endif
812
template<bool final = false>
9-
void checkpoint([[maybe_unused]] auto const& _msg = "") {
13+
void checkpoint([[maybe_unused]] auto const& _msg) {
1014
#ifdef CP_ALGO_CHECKPOINT
1115
std::string msg = _msg;
12-
static std::map<std::string, double> checkpoints;
13-
static double last = 0;
1416
double now = (double)clock() / CLOCKS_PER_SEC;
1517
double delta = now - last;
1618
last = now;
@@ -25,5 +27,9 @@ namespace cp_algo {
2527
}
2628
#endif
2729
}
30+
template<bool final = false>
31+
void checkpoint() {
32+
checkpoint<final>("");
33+
}
2834
}
2935
#endif // CP_ALGO_UTIL_CHECKPOINT_HPP

cp-algo/util/complex.hpp

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -9,45 +9,45 @@ namespace cp_algo {
99
struct complex {
1010
using value_type = T;
1111
T x, y;
12-
constexpr complex(): x(), y() {}
13-
constexpr complex(T const& x): x(x), y() {}
14-
constexpr complex(T const& x, T const& y): x(x), y(y) {}
15-
[[gnu::target("avx2")]] complex& operator *= (T const& t) {x *= t; y *= t; return *this;}
16-
[[gnu::target("avx2")]] complex& operator /= (T const& t) {x /= t; y /= t; return *this;}
17-
[[gnu::target("avx2")]] complex operator * (T const& t) const {return complex(*this) *= t;}
18-
[[gnu::target("avx2")]] complex operator / (T const& t) const {return complex(*this) /= t;}
19-
[[gnu::target("avx2")]] complex& operator += (complex const& t) {x += t.x; y += t.y; return *this;}
20-
[[gnu::target("avx2")]] complex& operator -= (complex const& t) {x -= t.x; y -= t.y; return *this;}
21-
[[gnu::target("avx2")]] complex operator * (complex const& t) const {return {x * t.x - y * t.y, x * t.y + y * t.x};}
22-
[[gnu::target("avx2")]] complex operator / (complex const& t) const {return *this * t.conj() / t.norm();}
23-
[[gnu::target("avx2")]] complex operator + (complex const& t) const {return complex(*this) += t;}
24-
[[gnu::target("avx2")]] complex operator - (complex const& t) const {return complex(*this) -= t;}
25-
[[gnu::target("avx2")]] complex& operator *= (complex const& t) {return *this = *this * t;}
26-
[[gnu::target("avx2")]] complex& operator /= (complex const& t) {return *this = *this / t;}
27-
[[gnu::target("avx2")]] complex operator - () const {return {-x, -y};}
28-
[[gnu::target("avx2")]] complex conj() const {return {x, -y};}
29-
[[gnu::target("avx2")]] T norm() const {return x * x + y * y;}
30-
[[gnu::target("avx2")]] T abs() const {return std::sqrt(norm());}
31-
[[gnu::target("avx2")]] T const real() const {return x;}
32-
[[gnu::target("avx2")]] T const imag() const {return y;}
33-
[[gnu::target("avx2")]] T& real() {return x;}
34-
[[gnu::target("avx2")]] T& imag() {return y;}
35-
[[gnu::target("avx2")]] static constexpr complex polar(T r, T theta) {return {T(r * cos(theta)), T(r * sin(theta))};}
36-
[[gnu::target("avx2")]] auto operator <=> (complex const& t) const = default;
12+
inline constexpr complex(): x(), y() {}
13+
inline constexpr complex(T const& x): x(x), y() {}
14+
inline constexpr complex(T const& x, T const& y): x(x), y(y) {}
15+
inline complex& operator *= (T const& t) {x *= t; y *= t; return *this;}
16+
inline complex& operator /= (T const& t) {x /= t; y /= t; return *this;}
17+
inline complex operator * (T const& t) const {return complex(*this) *= t;}
18+
inline complex operator / (T const& t) const {return complex(*this) /= t;}
19+
inline complex& operator += (complex const& t) {x += t.x; y += t.y; return *this;}
20+
inline complex& operator -= (complex const& t) {x -= t.x; y -= t.y; return *this;}
21+
inline complex operator * (complex const& t) const {return {x * t.x - y * t.y, x * t.y + y * t.x};}
22+
inline complex operator / (complex const& t) const {return *this * t.conj() / t.norm();}
23+
inline complex operator + (complex const& t) const {return complex(*this) += t;}
24+
inline complex operator - (complex const& t) const {return complex(*this) -= t;}
25+
inline complex& operator *= (complex const& t) {return *this = *this * t;}
26+
inline complex& operator /= (complex const& t) {return *this = *this / t;}
27+
inline complex operator - () const {return {-x, -y};}
28+
inline complex conj() const {return {x, -y};}
29+
inline T norm() const {return x * x + y * y;}
30+
inline T abs() const {return std::sqrt(norm());}
31+
inline T const real() const {return x;}
32+
inline T const imag() const {return y;}
33+
inline T& real() {return x;}
34+
inline T& imag() {return y;}
35+
inline static constexpr complex polar(T r, T theta) {return {T(r * cos(theta)), T(r * sin(theta))};}
36+
inline auto operator <=> (complex const& t) const = default;
3737
};
38-
template<typename T> [[gnu::target("avx2")]] complex<T> conj(complex<T> const& x) {return x.conj();}
39-
template<typename T> [[gnu::target("avx2")]] T norm(complex<T> const& x) {return x.norm();}
40-
template<typename T> [[gnu::target("avx2")]] T abs(complex<T> const& x) {return x.abs();}
41-
template<typename T> [[gnu::target("avx2")]] T& real(complex<T> &x) {return x.real();}
42-
template<typename T> [[gnu::target("avx2")]] T& imag(complex<T> &x) {return x.imag();}
43-
template<typename T> [[gnu::target("avx2")]] T const real(complex<T> const& x) {return x.real();}
44-
template<typename T> [[gnu::target("avx2")]] T const imag(complex<T> const& x) {return x.imag();}
38+
template<typename T> inline complex<T> conj(complex<T> const& x) {return x.conj();}
39+
template<typename T> inline T norm(complex<T> const& x) {return x.norm();}
40+
template<typename T> inline T abs(complex<T> const& x) {return x.abs();}
41+
template<typename T> inline T& real(complex<T> &x) {return x.real();}
42+
template<typename T> inline T& imag(complex<T> &x) {return x.imag();}
43+
template<typename T> inline T const real(complex<T> const& x) {return x.real();}
44+
template<typename T> inline T const imag(complex<T> const& x) {return x.imag();}
4545
template<typename T>
46-
[[gnu::target("avx2")]] constexpr complex<T> polar(T r, T theta) {
46+
inline constexpr complex<T> polar(T r, T theta) {
4747
return complex<T>::polar(r, theta);
4848
}
4949
template<typename T>
50-
std::ostream& operator << (std::ostream &out, complex<T> const& x) {
50+
inline std::ostream& operator << (std::ostream &out, complex<T> const& x) {
5151
return out << x.real() << ' ' << x.imag();
5252
}
5353
}

0 commit comments

Comments
 (0)