-
Notifications
You must be signed in to change notification settings - Fork 3
/
poly.h
117 lines (98 loc) · 3.25 KB
/
poly.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#pragma once
#include "ntt.h"
#include "singleton.h"
#include "snippets/min_pow_of_two.h"
#include <algorithm>
#include <cstddef>
#include <cstdlib>
#include <initializer_list>
#include <utility>
#include <vector>
/// Dense univariate polynomial whose coefficients have type `Mod_`.
///
/// Element i of the underlying std::vector is the coefficient of x^i.
/// A default-constructed polynomial holds the single coefficient {0}
/// (deg() == 0). An empty coefficient vector, e.g. PolyT(size_t{0}), has
/// deg() == -1. Trailing zero coefficients are never trimmed.
///
/// Large products are computed with the transform provided by NttT<Mod>
/// (by its name, a number-theoretic transform), using the scratch buffers of
/// the instance returned by singleton<NttT<Mod>>().
/// NOTE(review): since those scratch buffers are shared, operator* is
/// presumably not safe to call concurrently from several threads — confirm
/// against NttT / singleton.
template <typename Mod_> struct PolyT : public std::vector<Mod_> {
  using Mod = Mod_;
  using Vector = std::vector<Mod>;

  /// Forwards to Ntt::assert_power_of_two(n).
  static void assert_power_of_two(int n) { Ntt::assert_power_of_two(n); }

  /// Asks the shared Ntt instance to get ready for transforms of length n
  /// (presumably growing its tables / scratch buffers — confirm in ntt.h).
  /// NOTE(review): this static member hides std::vector::reserve, so
  /// `poly.reserve(k)` does NOT reserve coefficient storage.
  static void reserve(int n) { ntt().reserve(n); }

  /// Scratch buffer number i owned by the shared Ntt instance.
  template <int i> static Mod *raw_buffer() {
    return ntt().template raw_buffer<i>();
  }

  /// Forwards to Ntt::power_of_two_inv(n) (presumably n^{-1} in Mod).
  static Mod power_of_two_inv(int n) { return ntt().power_of_two_inv(n); }

  /// In-place forward transform of the n elements at a (Ntt::dif).
  static void dif(int n, Mod *a) { ntt().dif(n, a); }

  /// In-place inverse-direction transform of the n elements at a (Ntt::dit).
  static void dit(int n, Mod *a) { ntt().dit(n, a); }

  /// Copies the first min(n, m) elements of src into dst, then sets
  /// dst[min(n, m) .. n) to zero. dst must have room for n elements.
  static void copy_and_fill0(int n, Mod *dst, int m, const Mod *src) {
    m = std::min(n, m);
    std::copy(src, src + m, dst);
    std::fill(dst + m, dst + n, Mod{0});
  }

  /// Vector overload: copies at most n coefficients of src, zero-pads to n.
  static void copy_and_fill0(int n, Mod *dst, const std::vector<Mod> &src) {
    copy_and_fill0(n, dst, static_cast<int>(src.size()), src.data());
  }

  /// Sets out[i] = power_of_two_inv(n) * a[i] * b[i] for i in [0, n), then
  /// applies dit to out. The loop is element-wise, so out may be the same
  /// pointer as a or b.
  static void dot_product_and_dit(int n, Mod *out, const Mod *a, const Mod *b) {
    const auto inv_n = power_of_two_inv(n);
    for (int i = 0; i < n; ++i) {
      out[i] = inv_n * a[i] * b[i];
    }
    dit(n, out);
  }

  /// The zero polynomial stored as one coefficient {Mod{}}.
  PolyT() : std::vector<Mod>{Mod{}} {}
  /// `size` value-initialized coefficients.
  explicit PolyT(size_t size) : std::vector<Mod>(size) {}
  explicit PolyT(Vector &&v) : std::vector<Mod>{std::move(v)} {}
  explicit PolyT(const Vector &v) : std::vector<Mod>{v} {}
  explicit PolyT(const std::initializer_list<Mod> &v) : std::vector<Mod>{v} {}

  /// The coefficients as a plain std::vector.
  const Vector &vector() const { return *this; }

  /// Number of stored coefficients minus one (-1 when empty). Trailing zeros
  /// are counted, so this is an upper bound on the mathematical degree.
  int deg() const { return static_cast<int>(std::vector<Mod>::size()) - 1; }

  /// Exact comparison of the coefficient vectors: polynomials that differ
  /// only in trailing zero coefficients compare unequal.
  bool operator==(const PolyT &o) const { return vector() == o.vector(); }
  bool operator!=(const PolyT &o) const { return !(*this == o); }

  PolyT operator+(const PolyT &o) const {
    auto copy = *this;
    copy += o;
    return copy;
  }

  /// Adds o coefficient-wise, first growing *this to o's length if shorter.
  PolyT &operator+=(const PolyT &o) {
    if (deg() < o.deg()) {
      this->resize(o.size());
    }
    for (int i = 0; i <= o.deg(); i++) {
      (*this)[i] += o[i];
    }
    return *this;
  }

  /// Coefficient-wise difference; the result has max(size(), o.size())
  /// coefficients.
  PolyT operator-(const PolyT &o) const {
    int max_deg = std::max(deg(), o.deg());
    PolyT r(max_deg + 1);
    int min_deg = std::min(deg(), o.deg());
    for (int i = 0; i <= min_deg; ++i) {
      r[i] = (*this)[i] - o[i];
    }
    // At most one of the next two loops runs: the tail of the longer operand.
    for (int i = min_deg + 1; i <= deg(); ++i) {
      r[i] = (*this)[i];
    }
    for (int i = min_deg + 1; i <= o.deg(); ++i) {
      r[i] = -o[i];
    }
    return r;
  }

  PolyT &operator-=(const PolyT &o) { return *this = *this - o; }

  /// Polynomial product with deg() + o.deg() + 1 coefficients, clamped at 0.
  /// Because of the clamp, multiplying two empty polynomials yields an empty
  /// polynomial instead of requesting a negative (wrapped) size.
  /// Products with at most 16 coefficients use the schoolbook double loop.
  /// Larger ones transform both zero-padded operands with dif, multiply
  /// point-wise, and transform back with dit.
  PolyT operator*(const PolyT &o) const {
    // deg() is -1 for an empty polynomial, so two empty operands would give
    // -1 here; clamping keeps the size_t constructor below from wrapping.
    const int deg_plus_1 = std::max(deg() + o.deg() + 1, 0);
    if (deg_plus_1 <= 16) {
      PolyT result(static_cast<size_t>(deg_plus_1));
      for (int i = 0; i <= deg(); ++i) {
        for (int j = 0; j <= o.deg(); ++j) {
          result[i + j] += (*this)[i] * o[j];
        }
      }
      return result;
    }
    // Transform length; presumably the smallest power of two >= deg_plus_1,
    // which keeps the cyclic convolution free of wrap-around.
    const int n = min_pow_of_two(deg_plus_1);
    reserve(n);
    Mod *b0 = raw_buffer<0>();
    Mod *b1 = raw_buffer<1>();
    copy_and_fill0(n, b0, *this);
    dif(n, b0);
    copy_and_fill0(n, b1, o);
    dif(n, b1);
    dot_product_and_dit(n, b0, b0, b1);
    // Parentheses, not braces: list-initialization would consider the
    // initializer_list<Mod> constructor first and could try to treat the two
    // pointers as coefficients instead of as an iterator range.
    return PolyT(Vector(b0, b0 + deg_plus_1));
  }

  PolyT &operator*=(const PolyT &o) { return *this = *this * o; }

private:
  using Ntt = NttT<Mod>;
  /// The shared NttT<Mod> instance returned by singleton<Ntt>().
  static Ntt &ntt() { return singleton<Ntt>(); }
};