Line data Source code
1 : #include <stddef.h>
2 : #include <stdint.h>
3 : #include <stdlib.h>
4 : #include <string.h>
5 : #include <stdio.h>
6 :
7 : #include <nut/debug.h>
8 : #include <nut/modular_math.h>
9 : #include <nut/factorization.h>
10 : #include <nut/dirichlet.h>
11 : #include <nut/dirichlet_powerful.h>
12 : #include <nut/sieves.h>
13 :
14 21 : static bool PfIt_init(nut_PfIt *self, uint64_t max, uint64_t modulus, uint64_t small_primes){
15 21 : if(!max){
16 0 : *self = (nut_PfIt){};
17 0 : return true;
18 : }
19 21 : self->small_primes = small_primes;
20 21 : self->cap = 63 - __builtin_clzll(max);// floor(log_2(max))
21 21 : self->rt_max = nut_u64_nth_root(max, 2);
22 21 : if(!(self->entries = malloc(self->cap*sizeof(*self->entries)))){
23 : return false;
24 21 : }else if(!(self->primes = nut_sieve_primes(self->rt_max, &self->num_primes))){
25 0 : free(self->entries);
26 0 : self->entries = NULL;
27 0 : return false;
28 : }
29 21 : self->entries[0] = (nut_PfStackEnt){.n=1, .hn=1, .i=0};
30 21 : self->len = 1;
31 21 : self->max = max;
32 21 : self->modulus = modulus;
33 21 : return true;
34 : }
35 :
36 2 : bool nut_PfIt_init_fn(nut_PfIt *self, uint64_t max, uint64_t modulus, uint64_t small_primes, int64_t (*h_fn)(uint64_t p, uint64_t pp, uint64_t e, uint64_t m)){
37 2 : if(!PfIt_init(self, max, modulus, small_primes)){
38 : return false;
39 : }
40 2 : self->h_kind = NUT_DIRI_H_FN;
41 2 : self->h_fn = h_fn;
42 2 : return true;
43 : }
44 :
45 18 : bool nut_PfIt_init_hvals(nut_PfIt *restrict self, uint64_t max, uint64_t modulus, uint64_t small_primes, const int64_t *restrict h_vals){
46 18 : if(!PfIt_init(self, max, modulus, small_primes)){
47 : return false;
48 : }
49 18 : self->h_kind = NUT_DIRI_H_VALS;
50 18 : self->h_vals = h_vals;
51 18 : return true;
52 : }
53 :
54 1 : bool nut_PfIt_init_hseqs(nut_PfIt *restrict self, uint64_t max, uint64_t modulus, uint64_t small_primes,
55 : int64_t (*f_fn)(uint64_t p, uint64_t pp, uint64_t e, uint64_t m), int64_t (*g_fn)(uint64_t p, uint64_t pp, uint64_t e, uint64_t m)
56 : ){
57 1 : if(!PfIt_init(self, max, modulus, small_primes)){
58 : return false;
59 : }
60 1 : uint64_t values_cap = 3*self->num_primes;
61 1 : self->h_kind = NUT_DIRI_H_SEQS;
62 1 : self->h_seqs.offsets = malloc(self->num_primes*sizeof(uint64_t));
63 1 : self->h_seqs.values = malloc(values_cap*sizeof(int64_t));
64 1 : if(!self->h_seqs.offsets || !self->h_seqs.values){
65 0 : nut_PfIt_destroy(self);
66 0 : return false;
67 : }
68 1 : uint64_t base_offset = 0;
69 1 : uint64_t curr_max_pow = 63 - __builtin_clzll(max);
70 2 : int64_t *f_series [[gnu::cleanup(cleanup_free)]] = malloc((curr_max_pow + 1)*sizeof(int64_t));
71 2 : int64_t *g_series [[gnu::cleanup(cleanup_free)]] = malloc((curr_max_pow + 1)*sizeof(int64_t));
72 2 : int64_t *h_series [[gnu::cleanup(cleanup_free)]] = malloc((curr_max_pow + 1)*sizeof(int64_t));
73 1 : if(!f_series || !g_series || !h_series){
74 0 : nut_PfIt_destroy(self);
75 0 : return false;
76 : }
77 1 : f_series[0] = g_series[0] = 1;
78 1 : uint64_t curr_max_prime = 2;
79 3 : for(uint64_t i = 0; i < self->num_primes; ++i){
80 2 : uint64_t p = self->primes[i];
81 : // ensure that p^curr_max_pow is still <= max
82 3 : while(curr_max_pow > 2 && p > curr_max_prime){
83 : // this must be a while loop because for very small primes like 2 it's possible for the max power to drop more than 1
84 : // also when the max power is 2 we don't need to check anymore since we've only gathered the primes up to sqrt(max)
85 1 : --curr_max_pow;
86 1 : curr_max_prime = nut_u64_nth_root(max, curr_max_pow);
87 : }
88 : // ensure that the values table is large enough
89 2 : uint64_t min_pow = i >= self->small_primes ? 2 : 1;
90 2 : if(base_offset + curr_max_pow - min_pow >= values_cap){
91 0 : uint64_t new_cap = base_offset + 2*(self->num_primes - i);
92 0 : int64_t *tmp = realloc(self->h_seqs.values, new_cap*sizeof(int64_t));
93 0 : if(!tmp){
94 0 : nut_PfIt_destroy(self);
95 0 : return false;
96 : }
97 0 : self->h_seqs.values = tmp;
98 0 : values_cap = new_cap;
99 : }
100 7 : for(uint64_t pp = 1, e = 1; e <= curr_max_pow; ++e){
101 5 : pp *= p;
102 5 : f_series[e] = f_fn(p, pp, e, modulus);
103 5 : g_series[e] = g_fn(p, pp, e, modulus);
104 : }
105 2 : nut_series_div(curr_max_pow + 1, modulus, h_series, f_series, g_series);
106 2 : self->h_seqs.offsets[i] = base_offset;
107 2 : memcpy(self->h_seqs.values + base_offset, h_series + min_pow, (curr_max_pow + 1 - min_pow)*sizeof(int64_t));
108 2 : base_offset += curr_max_pow + 1 - min_pow;
109 : }
110 : return true;
111 : }
112 :
113 21 : void nut_PfIt_destroy(nut_PfIt *self){
114 21 : free(self->entries);
115 21 : free(self->primes);
116 21 : if(self->h_kind == NUT_DIRI_H_SEQS){
117 1 : free(self->h_seqs.offsets);
118 1 : free(self->h_seqs.values);
119 : }
120 21 : *self = (nut_PfIt){};
121 21 : }
122 :
123 656 : bool nut_PfStack_push(nut_PfIt *restrict self, const nut_PfStackEnt *restrict ent){
124 656 : if(self->len == self->cap){
125 0 : uint64_t new_cap = (self->cap << 1) ?: 8;// it shouldn't be possible to get here with cap = 0, but work around it anyway
126 0 : void *tmp = realloc(self->entries, new_cap*sizeof(*self->entries));
127 0 : if(!tmp){
128 : return false;
129 : }
130 0 : self->cap = new_cap;
131 0 : self->entries = tmp;
132 : }
133 656 : memcpy(self->entries + self->len++, ent, sizeof(*self->entries));
134 656 : return true;
135 : }
136 :
137 677 : bool nut_PfStack_pop(nut_PfIt *restrict self, nut_PfStackEnt *restrict out){
138 677 : if(!self->len){
139 : return false;
140 : }
141 677 : memcpy(out, self->entries + --self->len, sizeof(*self->entries));
142 677 : return true;
143 : }
144 :
145 466 : bool nut_PfIt_next(nut_PfIt *restrict self, nut_PfStackEnt *restrict out){
146 : // Yields (n, h(n) mod m) where n are the O(sqrt x) powerful numbers
147 : // up to x, and h is any multiplicative function.
148 698 : while(self->len){
149 677 : nut_PfStack_pop(self, out);
150 677 : if(out->i >= self->num_primes){
151 : return true;
152 : }
153 635 : uint64_t p = self->primes[out->i];
154 635 : uint64_t min_pow = out->i >= self->small_primes ? 2 : 1;
155 635 : if(min_pow == 2){
156 635 : if(p*p > self->max/out->n){
157 : return true;
158 : }
159 : }else{
160 0 : if(p > self->max/out->n){
161 : return true;
162 : }
163 : }
164 232 : if(!nut_PfStack_push(self, &(nut_PfStackEnt){.n=out->n, .hn=out->hn, .i= out->i + 1})){
165 0 : return false;
166 : }
167 888 : for(uint64_t pp = min_pow == 2 ? p : 1, e = min_pow; !__builtin_mul_overflow(pp, p, &pp) && pp <= self->max/out->n; ++e){
168 424 : int64_t v = out->hn;
169 424 : if(self->h_kind == NUT_DIRI_H_VALS){
170 396 : v *= self->h_vals[e];
171 28 : }else if(self->h_kind == NUT_DIRI_H_FN){
172 25 : v *= self->h_fn(p, pp, e, self->modulus);
173 3 : }else if(self->h_kind == NUT_DIRI_H_SEQS){
174 3 : v *= self->h_seqs.values[self->h_seqs.offsets[out->i] + e - min_pow];
175 : }else{
176 0 : return false;
177 : }
178 424 : if(self->modulus){
179 396 : v = nut_i64_mod(v, self->modulus);
180 : }
181 424 : if(!nut_PfStack_push(self, &(nut_PfStackEnt){.n = out->n*pp, .hn = v, .i = out->i + 1})){
182 0 : return false;
183 : }
184 : }
185 : }
186 : return false;
187 : }
188 :
189 19 : bool nut_Diri_sum_adjusted(int64_t *restrict out, const nut_Diri *restrict g_tbl, nut_PfIt *pf_it){
190 38 : int64_t *g_dense [[gnu::cleanup(cleanup_free)]] = malloc((g_tbl->y + 1)*sizeof(int64_t));
191 19 : if(!g_dense || pf_it->modulus > INT64_MAX){
192 0 : return false;
193 : }
194 19 : g_dense[0] = 0;
195 19 : int64_t m = pf_it->modulus;
196 285 : for(int64_t i = 1; i <= g_tbl->y; ++i){
197 266 : int64_t term = g_dense[i - 1] + nut_Diri_get_dense(g_tbl, i);
198 266 : g_dense[i] = m ? nut_i64_mod(term, m) : term;
199 : }
200 19 : int64_t res = 0;
201 19 : nut_PfStackEnt ent;
202 456 : while(nut_PfIt_next(pf_it, &ent)){
203 437 : int64_t Gn;
204 437 : if(ent.n >= (uint64_t)g_tbl->yinv){
205 361 : Gn = g_dense[g_tbl->x/ent.n];
206 : }else{
207 76 : Gn = nut_Diri_get_sparse(g_tbl, ent.n);
208 : }
209 437 : res += ent.hn*Gn;
210 437 : if(m){
211 414 : res = nut_i64_mod(res, m);
212 : }
213 : }
214 19 : *out = res;
215 19 : return true;
216 : }
217 :
218 0 : bool nut_Diri_sum_u_adjusted(int64_t *restrict out, nut_PfIt *pf_it){
219 0 : int64_t m = pf_it->modulus;
220 0 : int64_t res = 0;
221 0 : uint64_t max = pf_it->max;
222 0 : nut_PfStackEnt ent;
223 0 : while(nut_PfIt_next(pf_it, &ent)){
224 0 : res += ent.hn*(max/ent.n);
225 0 : if(m){
226 0 : res = nut_i64_mod(res, m);
227 : }
228 : }
229 0 : *out = res;
230 0 : return true;
231 : }
232 :
// Fill h[0..n) from f and g using the recurrence
//   h[e] = f[e] - sum_{k=1..e} g[k]*h[e-k]
// which is the coefficient-wise quotient f/g of two power series when g[0] == 1
// (g[0] itself is never read).
// If m is nonzero every h[e] is reduced into [0, m); if m is zero plain integer arithmetic is used.
// h must not overlap f or g.
void nut_series_div(uint64_t n, int64_t m, int64_t h[restrict static n], int64_t f[restrict static n], int64_t g[restrict static n]){
	for(uint64_t e = 0; e < n; ++e){
		// reduce up front so that h[0] (which has no inner-loop iterations) is canonical too
		int64_t term = m ? f[e]%m : f[e];
		for(uint64_t k = 1; k <= e; ++k){
			term -= g[k]*h[e - k];
			if(m){
				term %= m;
			}
		}
		// C's % keeps the sign of the dividend, so shift negative remainders into [0, m)
		h[e] = (m && term < 0) ? m + term : term;
	}
}
245 :
|