Line data Source code
1 : #include <string.h>
2 : #include <ctype.h>
3 :
4 : #include <nut/modular_math.h>
5 : #include <nut/factorization.h>
6 : #include <nut/polynomial.h>
7 :
8 30438 : bool nut_Poly_init(nut_Poly *f, uint64_t reserve){
9 30438 : reserve = reserve ?: 4;
10 30438 : f->coeffs = malloc(reserve*sizeof(int64_t));
11 30438 : if(!f->coeffs){
12 : return false;
13 : }
14 30438 : f->len = 1;
15 30438 : f->coeffs[0] = 0;
16 30438 : f->cap = reserve;
17 30438 : return true;
18 : }
19 :
20 30444 : void nut_Poly_destroy(nut_Poly *f){
21 30444 : free(f->coeffs);
22 30444 : *f = (nut_Poly){};
23 30444 : }
24 :
25 0 : int nut_Poly_cmp(const nut_Poly *a, const nut_Poly *b){
26 0 : uint64_t min_len;
27 0 : if(a->len < b->len){
28 0 : for(uint64_t i = b->len - 1; i >= a->len; --i){
29 0 : if(b->coeffs[i] > 0){
30 : return -1;
31 0 : }else if(b->coeffs[i] < 0){
32 : return 1;
33 0 : }else if(!i){
34 : return 0;
35 : }
36 : }
37 0 : min_len = a->len;
38 0 : }else if(a->len > b->len){
39 0 : for(uint64_t i = a->len - 1; i >= b->len; --i){
40 0 : if(a->coeffs[i] > 0){
41 : return 1;
42 0 : }else if(a->coeffs[i] < 0){
43 : return -1;
44 0 : }else if(!i){
45 : return 0;
46 : }
47 : }
48 0 : min_len = b->len;
49 : }else{
50 0 : min_len = a->len;
51 : }
52 0 : for(uint64_t i = min_len; i-- > 0;){
53 0 : if(a->coeffs[i] < b->coeffs[i]){
54 : return -1;
55 0 : }else if(a->coeffs[i] < b->coeffs[i]){
56 : return 1;
57 : }
58 : }
59 : return 0;
60 : }
61 :
62 1 : bool nut_Roots_init(nut_Roots *roots, uint64_t reserve){
63 1 : reserve = reserve ?: 4;
64 1 : roots->roots = malloc(reserve*sizeof(int64_t));
65 1 : if(!roots->roots){
66 : return false;
67 : }
68 1 : roots->len = 0;
69 1 : roots->cap = reserve;
70 1 : return true;
71 : }
72 :
73 1 : void nut_Roots_destroy(nut_Roots *roots){
74 1 : free(roots->roots);
75 1 : *roots = (nut_Roots){};
76 1 : }
77 :
78 25554069 : int64_t nut_Poly_eval_modn(const nut_Poly *f, int64_t x, int64_t n){
79 25554069 : if(!x){
80 5001 : return f->coeffs[0];
81 : }
82 25549068 : int64_t r = 0;
83 255460664 : for(uint64_t i = f->len; i > 0;){
84 229911596 : --i;
85 229911596 : r = nut_i64_mod(r*x + f->coeffs[i], n);
86 : }
87 : return r;
88 : }
89 :
90 580067 : bool nut_Poly_ensure_cap(nut_Poly *f, uint64_t cap){
91 580067 : if(f->cap < cap){
92 15706 : int64_t *tmp = realloc(f->coeffs, cap*sizeof(int64_t));
93 15706 : if(!tmp){
94 : return false;
95 : }
96 15706 : f->coeffs = tmp;
97 15706 : f->cap = cap;
98 : }
99 : return true;
100 : }
101 :
102 2 : bool nut_Poly_zero_extend(nut_Poly *f, uint64_t len){
103 2 : if(!nut_Poly_ensure_cap(f, len)){
104 : return false;
105 : }
106 4 : for(uint64_t i = f->len; i < len; ++i){
107 2 : f->coeffs[i] = false;
108 : }
109 2 : if(len > f->len){
110 1 : f->len = len;
111 : }
112 : return true;
113 : }
114 :
115 523 : bool nut_Roots_ensure_cap(nut_Roots *roots, uint64_t cap){
116 523 : if(roots->cap < cap){
117 0 : int64_t *tmp = realloc(roots->roots, cap*sizeof(int64_t));
118 0 : if(!tmp){
119 : return false;
120 : }
121 0 : roots->roots = tmp;
122 0 : roots->cap = cap;
123 : }
124 : return true;
125 : }
126 :
127 6523 : bool nut_Poly_add_modn(nut_Poly *h, const nut_Poly *f, const nut_Poly *g, int64_t n){
128 6523 : if(!nut_Poly_ensure_cap(h, f->len > g->len ? f->len : g->len)){
129 : return false;
130 : }
131 : uint64_t i;
132 30046 : for(i = 0; i < f->len && i < g->len; ++i){
133 23523 : h->coeffs[i] = nut_i64_mod(f->coeffs[i] + g->coeffs[i], n);
134 : }
135 23909 : for(; i < f->len; ++i){
136 17386 : h->coeffs[i] = f->coeffs[i];
137 : }
138 10659 : for(; i < g->len; ++i){
139 4136 : h->coeffs[i] = g->coeffs[i];
140 : }
141 6523 : int same_len = f->len == g->len;
142 6523 : h->len = f->len > g->len ? f->len : g->len;
143 6523 : if(same_len){
144 1358 : nut_Poly_normalize(h);
145 : }
146 : return true;
147 : }
148 :
149 8001 : bool nut_Poly_sub_modn(nut_Poly *h, const nut_Poly *f, const nut_Poly *g, int64_t n){
150 8001 : if(!nut_Poly_ensure_cap(h, f->len > g->len ? f->len : g->len)){
151 : return false;
152 : }
153 : uint64_t i;
154 36003 : for(i = 0; i < f->len && i < g->len; ++i){
155 28002 : h->coeffs[i] = nut_i64_mod(f->coeffs[i] - g->coeffs[i], n);
156 : }
157 33857 : for(; i < f->len; ++i){
158 25856 : h->coeffs[i] = f->coeffs[i];
159 : }
160 12001 : for(; i < g->len; ++i){
161 4000 : h->coeffs[i] = g->coeffs[i] ? n - g->coeffs[i] : 0;
162 : }
163 8001 : int same_len = f->len == g->len;
164 8001 : h->len = f->len > g->len ? f->len : g->len;
165 8001 : if(same_len){
166 1523 : nut_Poly_normalize(h);
167 : }
168 : return true;
169 : }
170 :
171 0 : bool nut_Poly_dot_modn(nut_Poly *h, const nut_Poly *f, const nut_Poly *g, int64_t n){
172 0 : if(!nut_Poly_ensure_cap(h, f->len < g->len ? f->len : g->len)){
173 : return false;
174 : }
175 : uint64_t i;
176 0 : for(i = 0; i < f->len && i < g->len; ++i){
177 0 : h->coeffs[i] = nut_i64_mod(f->coeffs[i]*g->coeffs[i], n);
178 : }
179 0 : h->len = f->len < g->len ? f->len : g->len;
180 0 : nut_Poly_normalize(h);
181 0 : return true;
182 : }
183 :
184 75522 : bool nut_Poly_copy(nut_Poly *restrict g, const nut_Poly *restrict f){
185 75522 : if(!nut_Poly_ensure_cap(g, f->len)){
186 : return false;
187 : }
188 75522 : memcpy(g->coeffs, f->coeffs, f->len*sizeof(int64_t));
189 75522 : g->len = f->len;
190 75522 : return true;
191 : }
192 :
193 55941 : bool nut_Poly_setconst(nut_Poly *f, int64_t c){
194 55941 : if(!nut_Poly_ensure_cap(f, 1)){
195 : return false;
196 : }
197 55941 : f->coeffs[0] = c;
198 55941 : f->len = 1;
199 55941 : return true;
200 : }
201 :
202 22510 : bool nut_Poly_scale_modn(nut_Poly *g, const nut_Poly *f, int64_t a, int64_t n){
203 22510 : if(!a){
204 27 : return nut_Poly_setconst(g, 0);
205 : }
206 22483 : if(!nut_Poly_ensure_cap(g, f->len)){
207 : return false;
208 : }
209 114536 : for(uint64_t i = 0; i < f->len; ++i){
210 92053 : g->coeffs[i] = nut_i64_mod(a*f->coeffs[i], n);
211 : }
212 22483 : g->len = f->len;
213 22483 : nut_Poly_normalize(g);
214 22483 : return true;
215 : }
216 :
217 130061 : bool nut_Poly_mul_modn(nut_Poly *restrict h, const nut_Poly *f, const nut_Poly *g, int64_t n){
218 130061 : if(f->len == 1){
219 2734 : return nut_Poly_scale_modn(h, g, f->coeffs[0], n);
220 127327 : }else if(g->len == 1){
221 1000 : return nut_Poly_scale_modn(h, f, g->coeffs[0], n);
222 : }
223 126327 : if(!nut_Poly_ensure_cap(h, f->len + g->len - 1)){
224 : return false;
225 : }
226 1441366 : for(uint64_t k = 0; k < f->len + g->len - 1; ++k){
227 1315039 : h->coeffs[k] = 0;
228 258532940 : for(uint64_t i = k >= g->len ? k - g->len + 1 : 0; i < f->len && i <= k; ++i){
229 255902862 : h->coeffs[k] = nut_i64_mod(h->coeffs[k] + f->coeffs[i]*g->coeffs[k - i], n);
230 : }
231 : }
232 126327 : h->len = f->len + g->len - 1;
233 126327 : nut_Poly_normalize(h);
234 126327 : return true;
235 : }
236 :
237 0 : bool nut_Poly_pow_modn(nut_Poly *restrict g, const nut_Poly *f, uint64_t e, int64_t n, uint64_t cn, nut_Poly tmps[restrict static 2]){
238 : //tmps: st, rt
239 0 : if(f->len > cn + 1){
240 : return false;
241 0 : }else if(!e){
242 0 : return nut_Poly_setconst(g, 1); // TODO: deal with 0**0 case
243 0 : }else if(e == 1){
244 0 : return nut_Poly_copy(g, f);
245 : }
246 0 : e = 1 + (e - 1)%cn;
247 0 : if(f->len == 1){
248 0 : return nut_Poly_setconst(g, nut_u64_pow(f->coeffs[0], e));
249 0 : }else if(!nut_Poly_copy(tmps + 0, f) || !nut_Poly_ensure_cap(tmps + 0, 2*cn + 1) || !nut_Poly_ensure_cap(tmps + 1, 2*cn + 1)){
250 0 : return false;
251 : }
252 0 : nut_Poly *t = g, *s = tmps + 0, *r = tmps + 1;
253 0 : while(e%2 == 0){
254 0 : nut_Poly_mul_modn(t, s, s, n);
255 0 : nut_Poly_normalize_exps_modn(t, cn);
256 : {
257 0 : void *tmp = t;
258 0 : t = s;
259 0 : s = tmp;
260 : }//s = s*s
261 0 : e >>= 1;
262 : }
263 0 : if(!nut_Poly_copy(r, s)){
264 : return false;
265 : }
266 0 : while((e >>= 1)){
267 0 : nut_Poly_mul_modn(t, s, s, n);
268 0 : nut_Poly_normalize_exps_modn(t, cn);
269 : {
270 0 : void *tmp = t;
271 0 : t = s;
272 0 : s = tmp;
273 : }//s = s*s
274 0 : if(e%2){
275 0 : nut_Poly_mul_modn(t, r, s, n);
276 0 : nut_Poly_normalize_exps_modn(t, cn);
277 : {
278 0 : void *tmp = t;
279 0 : t = r;
280 0 : r = tmp;
281 : }//r = r*s
282 : }
283 : }
284 0 : if(r != g){
285 0 : if(!nut_Poly_copy(g, r)){
286 : return 0;
287 : }
288 : }
289 : return true;
290 : }
291 :
292 0 : bool nut_Poly_pow_modn_tmptmp(nut_Poly *restrict g, const nut_Poly *f, uint64_t e, int64_t n, uint64_t cn){
293 0 : nut_Poly tmps[2] = {};
294 0 : bool status = true;
295 0 : for(uint64_t i = 0; status && i < 2; ++i){
296 0 : status = nut_Poly_init(tmps + i, 2*cn + 1);
297 : }
298 0 : if(status){
299 0 : status = nut_Poly_pow_modn(g, f, e, n, cn, tmps);
300 : }
301 0 : for(uint64_t i = 0; i < 2; ++i){
302 0 : if(tmps[i].cap){
303 0 : nut_Poly_destroy(tmps + i);
304 : }
305 : }
306 0 : return status;
307 : }
308 :
309 27 : bool nut_Poly_compose_modn(nut_Poly *restrict h, const nut_Poly *f, const nut_Poly *g, int64_t n, uint64_t cn, nut_Poly tmps[restrict static 2]){
310 27 : if(f->len == 1){
311 0 : return nut_Poly_setconst(h, f->coeffs[0]);
312 27 : }else if(g->len == 1){
313 0 : return nut_Poly_setconst(h, nut_Poly_eval_modn(f, g->coeffs[0], n));
314 : }
315 27 : uint64_t h_len = (f->len - 1)*(g->len - 1) + 1;
316 27 : if(h_len > cn + 1){
317 : h_len = cn + 1;
318 : }
319 27 : if(!nut_Poly_ensure_cap(h, h_len) || !nut_Poly_setconst(h, f->coeffs[0])){
320 0 : return false;
321 : }
322 27 : nut_Poly *t = tmps + 0, *p = tmps + 1;
323 27 : if(!nut_Poly_copy(p, g) || !nut_Poly_scale_modn(t, g, f->coeffs[1], n) || !nut_Poly_add_modn(h, h, t, n)){
324 0 : return false;
325 : }
326 54 : for(uint64_t e = 2; e <= cn && e < f->len; ++e){
327 27 : if(!nut_Poly_mul_modn(t, p, g, n)){
328 : return false;
329 : }
330 27 : nut_Poly_normalize_exps_modn(t, cn);
331 27 : void *tmp = t;
332 27 : t = p;
333 27 : p = tmp;
334 27 : if(!nut_Poly_scale_modn(t, p, f->coeffs[e], n) || !nut_Poly_add_modn(h, h, t, n)){
335 0 : return false;
336 : }
337 : }
338 : return true;
339 : }
340 :
341 0 : bool nut_Poly_compose_modn_tmptmp(nut_Poly *restrict h, const nut_Poly *f, const nut_Poly *g, int64_t n, uint64_t cn){
342 0 : nut_Poly tmps[2] = {};
343 0 : bool status = true;
344 0 : for(uint64_t i = 0; status && i < 2; ++i){
345 0 : status = nut_Poly_init(tmps + i, 2*cn + 1);
346 : }
347 0 : if(status){
348 0 : status = nut_Poly_compose_modn(h, f, g, n, cn, tmps);
349 : }
350 0 : for(uint64_t i = 0; i < 2; ++i){
351 0 : if(tmps[i].cap){
352 0 : nut_Poly_destroy(tmps + i);
353 : }
354 : }
355 0 : return status;
356 : }
357 :
358 171422 : bool nut_Poly_quotrem_modn(nut_Poly *restrict q, nut_Poly *restrict r, const nut_Poly *restrict f, const nut_Poly *restrict g, int64_t n){
359 171422 : int64_t a, d = nut_i64_egcd(g->coeffs[g->len - 1], n, &a, NULL);
360 171422 : if(d != 1){
361 : return false;//TODO: set divide by zero error, or set remainder (?)
362 : }
363 171422 : if(g->len == 1){//dividing by a scalar TODO: set divide by zero error if scalar is 0
364 7946 : return g->coeffs[0] && nut_Poly_setconst(r, 0) && nut_Poly_scale_modn(q, f, a, n);
365 : }
366 163476 : if(f->len < g->len){//dividing by a polynomial with higher degree
367 45303 : return nut_Poly_setconst(q, 0) && nut_Poly_copy(r, f);
368 : }
369 :
370 : //begin extended synthetic division
371 : //compute max length of quotient and remainder and extend their buffers if need be
372 118173 : q->len = f->len - g->len + 1;
373 118173 : r->len = g->len - 1;
374 118173 : if(!nut_Poly_ensure_cap(q, q->len) || !nut_Poly_ensure_cap(r, r->len)){
375 0 : return false;
376 : }
377 :
378 : //initialize column sums/coeffs of results
379 118173 : memcpy(r->coeffs, f->coeffs, r->len*sizeof(int64_t));
380 118173 : memcpy(q->coeffs, f->coeffs + r->len, q->len*sizeof(int64_t));
381 :
382 : //loop over quotient columns (coefficients in reverse order) which were initialized
383 : //to the first q->len dividend coefficients
384 633308 : for(uint64_t k = q->len; k > 0;){
385 515135 : --k;
386 : //finish the column by dividing the sum by the leading coefficient of the divisor
387 : //for monic divisors (aka most of the time) a will simply be 1 so we may optimize this
388 515135 : q->coeffs[k] = nut_i64_mod(q->coeffs[k]*a, n);
389 : //subtract the adjusted column sum times the sum of the divisor coefficients from the
390 : //remainder coefficients. the remainder coefficients were initialized to the last r->len
391 : //dividend coefficients. we start q->len - k columns after the current column we just finished.
392 : //if k == 0 we should go from coefficient 0 in the remainder to coefficient r->len - 1
393 : //so in general we should go from k to r->len - 1 (this can easily be an empty interval)
394 127508066 : for(uint64_t i = 0, j = k; j < r->len; ++i, ++j){
395 126992931 : r->coeffs[j] = nut_i64_mod(r->coeffs[j] - q->coeffs[k]*g->coeffs[i], n);
396 : }
397 : //j goes from max(k - g->len + 1, 0) to k - 1 (both inclusive)
398 : //i ends at g->len - 2 (inclusive) and should go through the same number of values
399 : //so i starts at g->len - 2 - (k - 1 - j) = g->len + j - k - 1
400 122472623 : for(uint64_t j = k > g->len - 1 ? k - g->len + 1 : 0, i = g->len + j - k - 1; i < g->len - 1; ++i, ++j){
401 121957488 : q->coeffs[j] = nut_i64_mod(q->coeffs[j] - q->coeffs[k]*g->coeffs[i], n);
402 : }
403 : }
404 118173 : nut_Poly_normalize(r);
405 118173 : return true;
406 : }
407 :
408 291400 : void nut_Poly_normalize(nut_Poly *f){
409 420198 : while(f->len > 1 && !f->coeffs[f->len - 1]){
410 128798 : --f->len;
411 : }
412 291400 : }
413 :
414 30 : void nut_Poly_normalize_modn(nut_Poly *f, int64_t n, bool use_negatives){
415 30 : int64_t offset = use_negatives ? (1-n)/2 : 0;
416 214 : for(uint64_t i = 0; i < f->len; ++i){
417 184 : f->coeffs[i] = offset + nut_i64_mod(f->coeffs[i] - offset, n);
418 : }
419 30 : nut_Poly_normalize(f);
420 30 : }
421 :
422 27 : void nut_Poly_normalize_exps_modn(nut_Poly *f, uint64_t cn){
423 129 : for(uint64_t i = f->len - 1; i > cn; --i){
424 102 : uint64_t j = 1 + (i - 1)%cn;
425 102 : f->coeffs[j] += f->coeffs[i];
426 : }
427 27 : if(f->len > cn + 1){
428 23 : f->len = cn + 1;
429 : }
430 27 : nut_Poly_normalize(f);
431 27 : }
432 :
433 25 : int nut_Poly_fprint(FILE *file, const nut_Poly *f, const char *vname, const char *add, const char *sub, const char *pow, bool descending){
434 25 : int res = 0;
435 176 : for(uint64_t i = 0; i < f->len; ++i){
436 151 : uint64_t j = descending ? f->len - 1 - i : i;
437 151 : int64_t coeff = f->coeffs[j];
438 151 : if(!coeff){
439 76 : continue;
440 : }
441 75 : if(res){
442 50 : res += fprintf(file, "%s", coeff > 0 ? add : sub);
443 50 : coeff = coeff < 0 ? -coeff : coeff;
444 50 : if(coeff != 1 || !j){
445 40 : res += fprintf(file, "%"PRId64, coeff);
446 : }
447 : }else{
448 25 : if(coeff != 1 || !j){
449 9 : if(coeff == -1 && j){
450 0 : res += fprintf(file, "-");
451 : }else{
452 9 : res += fprintf(file, "%"PRId64, coeff);
453 : }
454 : }
455 : }
456 49 : if(j){
457 54 : res += fprintf(file, "%s", vname);
458 : }
459 54 : if(j > 1){
460 45 : res += fprintf(file, "%s%"PRIu64, pow, j);
461 : }
462 : }
463 25 : return res;
464 : }
465 :
466 8 : static inline void skip_whitespace(const char *restrict *restrict _str){
467 10 : while(isspace(**_str)){
468 2 : ++*_str;
469 : }
470 8 : }
471 :
472 2 : static inline int parse_monomial(nut_Poly *restrict f, const char *restrict *restrict _str){
473 2 : const char *str = *_str;
474 2 : int64_t sign = 1, coeff = 0;
475 2 : uint64_t x = 0;
476 4 : while(*str == '+' || *str == '-' || isspace(*str)){
477 2 : if(*str == '-'){
478 0 : sign = -sign;
479 : }
480 2 : ++str;
481 : }
482 : bool need_vpow = true;
483 3 : while(isdigit(*str)){
484 1 : coeff = 10*coeff + *str - '0';
485 1 : need_vpow = false;
486 1 : ++str;
487 : }
488 2 : skip_whitespace(&str);
489 2 : if(need_vpow){
490 1 : coeff = 1;
491 : }
492 2 : if(!need_vpow && *str == '*'){
493 0 : ++str;
494 0 : need_vpow = true;
495 0 : skip_whitespace(&str);
496 : }
497 2 : bool have_vpow = false;
498 2 : if(strncmp(str, "mod", 3) || !isspace(str[3])){
499 2 : while(isalpha(*str)){
500 1 : have_vpow = true;
501 1 : ++str;
502 : }
503 : }
504 1 : if(have_vpow){
505 1 : skip_whitespace(&str);
506 1 : if(*str == '^' || !strncmp(str, "**", 2)){
507 1 : str += *str == '^' ? 1 : 2;
508 1 : skip_whitespace(&str);
509 1 : bool have_pow = false;
510 2 : while(isdigit(*str)){
511 1 : x = 10*x + *str - '0';
512 1 : have_pow = true;
513 1 : ++str;
514 : }
515 1 : if(!have_pow){
516 : return 0;
517 : }
518 : }else{
519 : x = 1;
520 : }
521 : }
522 2 : skip_whitespace(&str);
523 2 : if(!coeff){
524 0 : *_str = str;
525 0 : return 1;
526 : }
527 2 : if(!nut_Poly_zero_extend(f, x + 1)){
528 : return 0;
529 : }
530 2 : f->coeffs[x] += sign*coeff;
531 2 : *_str = str;
532 2 : return 1;
533 : }
534 :
535 1 : int nut_Poly_parse(nut_Poly *restrict f, int64_t *restrict n, const char *restrict str, const char *restrict *restrict end){
536 1 : if(!nut_Poly_setconst(f, 0) || !parse_monomial(f, &str)){
537 0 : return 0;
538 : }
539 2 : while(*str == '+' || *str == '-'){
540 1 : if(!parse_monomial(f, &str)){
541 0 : if(end){
542 0 : skip_whitespace(&str);
543 0 : *end = str;
544 : }
545 0 : nut_Poly_normalize(f);
546 0 : return 1;
547 : }
548 : }
549 1 : if(!strncmp(str, "mod", 3) && isspace(str[3])){
550 1 : str += 4;
551 1 : skip_whitespace(&str);
552 1 : int64_t _n = 0;
553 1 : bool have_mod = false;
554 2 : while(isdigit(*str)){
555 1 : _n = 10*_n + *str - '0';
556 1 : have_mod = true;
557 1 : ++str;
558 : }
559 1 : if(have_mod){
560 1 : *n = _n;
561 1 : if(end){
562 1 : skip_whitespace(&str);
563 1 : *end = str;
564 : }
565 1 : nut_Poly_normalize_modn(f, *n, 0);
566 1 : return 2;
567 : }else{
568 0 : str -= 4;
569 : }
570 : }
571 0 : if(end){
572 0 : skip_whitespace(&str);
573 0 : *end = str;
574 : }
575 0 : nut_Poly_normalize(f);
576 0 : return 1;
577 : }
578 :
579 21479 : bool nut_Poly_rand_modn(nut_Poly *f, uint64_t max_len, int64_t n){
580 21479 : if(!max_len){
581 0 : return nut_Poly_setconst(f, 0);
582 : }
583 21479 : if(!nut_Poly_ensure_cap(f, max_len)){
584 : return false;
585 : }
586 172207 : for(uint64_t i = 0; i < max_len; ++i){
587 150728 : f->coeffs[i] = nut_u64_rand(0, n);
588 : }
589 21479 : f->len = max_len;
590 21479 : nut_Poly_normalize(f);
591 21479 : return true;
592 : }
593 :
594 9950 : bool nut_Poly_gcd_modn(nut_Poly *restrict d, const nut_Poly *restrict f, const nut_Poly *restrict g, int64_t n, nut_Poly tmps[restrict static 3]){
595 : //tmps: qt, r0t, r1t
596 9950 : nut_Poly *tmp, *r0, *r1, *r2;
597 9950 : bool status = true;
598 9950 : if(g->len > f->len){//Ensure the degree of f is >= the degree of g
599 9949 : const nut_Poly *tmp = f;//Shadow the other tmp with a const version
600 9949 : f = g;
601 9949 : g = tmp;
602 : }
603 9950 : if(g->len == 1){//If one of the inputs is a constant, we either have the gcd of something and a unit or something and zero
604 854 : status = g->coeffs[0] ? nut_Poly_setconst(d, 1) : nut_Poly_copy(d, f);
605 854 : goto CLEANUP;
606 : }
607 :
608 9096 : r0 = d, r1 = tmps + 1, r2 = tmps + 2;
609 : //Unroll first two remainder calculations to prevent copying
610 9096 : if(!nut_Poly_quotrem_modn(tmps + 0, r0, f, g, n)){
611 0 : status = false;
612 0 : goto CLEANUP;
613 : }
614 9096 : if(r0->len == 1 && !r0->coeffs[0]){
615 564 : status = nut_Poly_copy(d, g);
616 564 : goto CLEANUP;
617 : }
618 :
619 8532 : if(!nut_Poly_quotrem_modn(tmps + 0, r1, g, r0, n)){
620 0 : status = false;
621 0 : goto CLEANUP;
622 : }
623 8532 : if(r1->len == 1 && !r1->coeffs[0]){
624 1558 : goto CLEANUP;// r0 == d in this case, no need to copy
625 : }
626 :
627 : //Euclidean algorithm: take remainders until we reach 0, then the last nonzero remainder is the gcd
628 22612 : while(1){
629 22612 : if(!nut_Poly_quotrem_modn(tmps + 0, r2, r0, r1, n)){
630 0 : status = false;
631 0 : goto CLEANUP;
632 : }
633 22612 : if(r2->len == 1 && !r2->coeffs[0]){
634 6974 : if(r1 != d){//Make sure the result is in the output and not a temporary
635 5692 : status = nut_Poly_copy(d, r1);
636 : }
637 5692 : goto CLEANUP;
638 : }
639 : tmp = r0;
640 : r0 = r1;
641 : r1 = r2;
642 : r2 = tmp;
643 : }
644 :
645 9950 : CLEANUP:;
646 9950 : if(status){//If the gcd was found, make it monic
647 9950 : int64_t a = d->coeffs[d->len - 1];
648 9950 : if(a == n - 1){//Inverse of -1 is always -1
649 10 : nut_Poly_scale_modn(d, d, a, n);
650 9940 : }else if(a > 1){//If the leading coefficient is 0 or 1 the gcd is already monic
651 7766 : int64_t c, g = nut_i64_egcd(a, n, &c, NULL);//This g shadows the input
652 7766 : if(g != 1){//The leading coefficient cannot be inverted mod n
653 0 : return false;
654 : }
655 7766 : nut_Poly_scale_modn(d, d, c, n);
656 : }
657 : }
658 : return status;
659 : }
660 :
661 2 : bool nut_Poly_gcd_modn_tmptmp(nut_Poly *restrict d, const nut_Poly *restrict f, const nut_Poly *restrict g, int64_t n){
662 2 : nut_Poly tmps[3] = {};
663 2 : uint64_t min_len, quot_len;
664 2 : if(f->len <= g->len){
665 1 : min_len = f->len;
666 1 : quot_len = g->len - f->len + 1;
667 : }else{
668 1 : min_len = g->len;
669 1 : quot_len = f->len - g->len + 1;
670 : }
671 4 : bool status = nut_Poly_init(tmps + 0, quot_len) &&
672 4 : nut_Poly_init(tmps + 1, min_len) &&
673 6 : nut_Poly_init(tmps + 2, min_len) &&
674 2 : nut_Poly_gcd_modn(d, f, g, n, tmps);
675 8 : for(uint64_t i = 0; i < 3; ++i){
676 6 : if(tmps[i].cap){
677 6 : nut_Poly_destroy(tmps + i);
678 : }
679 : }
680 2 : return status;
681 : }
682 :
683 7471 : bool nut_Poly_powmod_modn(nut_Poly *restrict h, const nut_Poly *restrict f, uint64_t e, const nut_Poly *restrict g, int64_t n, nut_Poly tmps[restrict static 3]){
684 : //tmps: qt, st, rt
685 7471 : if(g->len <= f->len){
686 0 : if(!nut_Poly_quotrem_modn(tmps + 0, tmps + 1, f, g, n)){
687 : return false;
688 : }
689 7471 : }else if(!nut_Poly_copy(tmps + 1, f)){
690 : return false;
691 : }
692 7471 : if(tmps[1].len == 1){
693 0 : if(!tmps[1].coeffs[0]){
694 0 : return nut_Poly_setconst(h, 0);
695 0 : }else if(!e){
696 0 : return nut_Poly_setconst(h, 1);
697 : }
698 : }
699 7471 : if(!nut_Poly_ensure_cap(tmps + 1, 2*g->len - 3) || !nut_Poly_ensure_cap(tmps + 2, 2*g->len - 3) || !nut_Poly_ensure_cap(h, 2*g->len - 3)){
700 0 : return false;
701 : }
702 :
703 : nut_Poly *t = h, *s = tmps + 1, *r = tmps + 2;
704 9850 : while(e%2 == 0){
705 2379 : nut_Poly_mul_modn(t, s, s, n);
706 2379 : if(!nut_Poly_quotrem_modn(tmps + 0, s, t, g, n)){
707 : return false;
708 : }//s = s*s%g
709 2379 : e >>= 1;
710 : }
711 7471 : if(!nut_Poly_copy(r, s)){
712 : return false;
713 : }
714 88095 : while((e >>= 1)){
715 80624 : nut_Poly_mul_modn(t, s, s, n);
716 80624 : if(!nut_Poly_quotrem_modn(tmps + 0, s, t, g, n)){
717 : return false;
718 : }//s = s*s%g
719 80624 : if(e%2){
720 43031 : nut_Poly_mul_modn(t, r, s, n);
721 43031 : if(!nut_Poly_quotrem_modn(tmps + 0, r, t, g, n)){
722 : return false;
723 : }//r = r*s%g
724 : }
725 : }
726 :
727 7471 : if(r != h){// TODO: this check is always true, fix buffer juggling so we never need to copy
728 7471 : if(!nut_Poly_copy(h, r)){
729 : return false;
730 : }
731 : }
732 : return true;
733 : }
734 :
735 2 : bool nut_Poly_powmod_modn_tmptmp(nut_Poly *restrict h, const nut_Poly *restrict f, uint64_t e, const nut_Poly *restrict g, int64_t n){
736 2 : nut_Poly tmps[3] = {};
737 2 : bool status = true;
738 8 : for(uint64_t i = 0; status && i < 3; ++i){
739 6 : status = nut_Poly_init(tmps + i, 2*g->len - 1);
740 : }
741 2 : if(status){
742 2 : status = nut_Poly_powmod_modn(h, f, e, g, n, tmps);
743 : }
744 8 : for(uint64_t i = 0; i < 3; ++i){
745 6 : if(tmps[i].cap){
746 6 : nut_Poly_destroy(tmps + i);
747 : }
748 : }
749 2 : return status;
750 : }
751 :
752 5000 : bool nut_Poly_factors_d_modn(nut_Poly *restrict f_d, const nut_Poly *restrict f, uint64_t d, int64_t n, nut_Poly tmps[restrict static 4]){
753 : //tmps: xt, qt, st, rt
754 5000 : if(!nut_Poly_ensure_cap(f_d, f->len)){
755 : return false;
756 : }
757 5000 : f_d->coeffs[0] = 0;
758 5000 : f_d->coeffs[1] = 1;
759 5000 : f_d->len = 2;
760 10000 : return nut_Poly_powmod_modn(tmps + 0, f_d, nut_u64_pow(n, d), f, n, tmps + 1) &&
761 10000 : nut_Poly_sub_modn(tmps + 0, tmps + 0, f_d, n) &&
762 5000 : nut_Poly_gcd_modn(f_d, tmps + 0, f, n, tmps + 1);
763 : }
764 :
765 2148 : bool nut_Poly_factor1_modn(nut_Poly *restrict g, const nut_Poly *restrict f, uint64_t d, int64_t n, nut_Poly tmps[restrict static 4]){
766 : //tmps: xt, qt, st, rt
767 2479 : while(1){
768 2479 : if(!nut_Poly_rand_modn(tmps + 0, f->len - 1, n) || !nut_Poly_gcd_modn(g, tmps + 0, f, n, tmps + 1)){
769 0 : return false;
770 : }
771 2479 : if(g->len > 1 && g->len < f->len){
772 : return true;
773 : }
774 2469 : if(
775 4938 : !nut_Poly_powmod_modn(g, tmps + 0, (nut_u64_pow(n, d)-1)/2, f, n, tmps + 1) ||
776 4938 : !nut_Poly_setconst(tmps + 0, 1) ||
777 4938 : !nut_Poly_add_modn(tmps + 0, g, tmps + 0, n) ||
778 2469 : !nut_Poly_gcd_modn(g, tmps + 0, f, n, tmps + 1)
779 : ){
780 0 : return false;
781 : }
782 2469 : if(g->len > 1 && g->len < f->len){
783 : return true;
784 : }
785 : }
786 : }
787 :
788 : [[gnu::nonnull(1, 3)]]
789 : NUT_ATTR_ACCESS(read_write, 1)
790 2671 : static inline bool roots_polyn_modn_rec(nut_Roots *restrict roots, int64_t n, nut_Poly tmps[restrict static 6]){
791 : //tmps: gt, ft, xt, qt, st, rt
792 2148 : while(1){
793 : //fprintf(stderr, "\e[1;33mFactoring (");
794 : //nut_Poly_fprint(stderr, tmps + 1, "x", " + ", " - ", "**", 0);
795 : //fprintf(stderr, ") mod %"PRId64"\e[0m\n", n);
796 4819 : if(tmps[1].len == 2){//linear factor
797 : //fprintf(stderr, "\e[1;33m polynomial is linear\e[0m\n");
798 1158 : if(tmps[1].coeffs[0]){//nonzero root
799 1158 : if(tmps[1].coeffs[1] == 1){//monic
800 1158 : roots->roots[roots->len++] = n - tmps[1].coeffs[0];
801 : }else{//not monic
802 0 : int64_t a;
803 0 : nut_i64_egcd(tmps[1].coeffs[0], n, &a, NULL);
804 0 : roots->roots[roots->len++] = nut_i64_mod(-tmps[1].coeffs[0]*a, n);
805 : }
806 : }else{//zero root
807 0 : roots->roots[roots->len++] = 0;
808 : }
809 2671 : return true;
810 3661 : }else if(tmps[1].len == 3){//quadratic factor
811 : //fprintf(stderr, "\e[1;33m polynomial is quadratic\e[0m\n");
812 1513 : if(tmps[1].coeffs[2] != 1){//make monic
813 0 : int64_t a;
814 0 : nut_i64_egcd(tmps[1].coeffs[2], n, &a, NULL);
815 0 : nut_Poly_scale_modn(tmps + 1, tmps + 1, a, n);
816 : }
817 1513 : int64_t c = tmps[1].coeffs[0];
818 1513 : int64_t b = tmps[1].coeffs[1];
819 1513 : int64_t r = nut_i64_sqrt_mod(nut_i64_mod(b*b - 4*c, n), n);
820 1513 : roots->roots[roots->len++] = nut_i64_mod((n+1)/2*(-b + r), n);
821 1513 : roots->roots[roots->len++] = nut_i64_mod((n+1)/2*(-b - r), n);
822 1513 : return true;
823 : }
824 2148 : if(!nut_Poly_factor1_modn(tmps + 0, tmps + 1, 1, n, tmps + 2) || !nut_Poly_quotrem_modn(tmps + 3, tmps + 5, tmps + 1, tmps + 0, n)){
825 0 : return false;
826 : }
827 : //tmps[0] and tmps[3] hold the nontrivial factors we found
828 : //we have to do a recursive call on the factor with smaller degree
829 : //if it is linear or quadratic, we do not need to copy the larger factor
830 : //otherwise we do
831 2148 : nut_Poly tmp;
832 2148 : tmp = tmps[0];
833 2148 : tmps[0] = tmps[1];
834 2148 : if(tmp.len <= tmps[3].len){
835 1289 : tmps[1] = tmp;
836 : }else{
837 859 : tmps[1] = tmps[3];
838 859 : tmps[3] = tmp;
839 : }
840 : //now, tmps[1] and tmps[3] hold the nontrivial factors, with tmps[1] having smallest degree
841 : //fprintf(stderr, "\e[1;33m got factors (");
842 : //nut_Poly_fprint(stderr, tmps + 1, "x", " + ", " - ", "**", 0);
843 : //fprintf(stderr, ")(");
844 : //nut_Poly_fprint(stderr, tmps + 3, "x", " + ", " - ", "**", 0);
845 : //fprintf(stderr, ")\e[0m\n");
846 2148 : bool have_small_factor = tmps[1].len <= 3;
847 2148 : if(!have_small_factor){
848 417 : if(!nut_Poly_init(&tmp, tmps[3].len) || !nut_Poly_copy(&tmp, tmps + 3)){
849 0 : return false;
850 : }
851 : }
852 2148 : if(!roots_polyn_modn_rec(roots, n, tmps)){
853 : return false;
854 : }
855 2148 : if(!have_small_factor){
856 417 : if(!nut_Poly_copy(tmps + 3, &tmp)){
857 : return false;
858 : }
859 417 : nut_Poly_destroy(&tmp);
860 : }
861 2148 : tmp = tmps[3];
862 2148 : tmps[3] = tmps[1];
863 2148 : tmps[1] = tmp;
864 : }
865 : }
866 :
867 5000 : bool nut_Poly_roots_modn(const nut_Poly *restrict f, int64_t n, nut_Roots *restrict roots, nut_Poly tmps[restrict static 6]){
868 : //tmps: gt, ft, xt, qt, st, rt
869 5000 : if(!nut_Poly_factors_d_modn(tmps + 1, f, 1, n, tmps + 2)){
870 : return false;
871 : }
872 5000 : roots->len = 0;
873 5000 : if(tmps[1].len == 1){
874 : return true;
875 : }
876 523 : return nut_Roots_ensure_cap(roots, tmps[1].len - 1) && roots_polyn_modn_rec(roots, n, tmps);
877 : }
878 :
879 5000 : bool nut_Poly_roots_modn_tmptmp(const nut_Poly *restrict f, int64_t n, nut_Roots *restrict roots){
880 5000 : nut_Poly tmps[6] = {};
881 5000 : bool status = true;
882 35000 : for(uint64_t i = 0; status && i < 6; ++i){
883 30000 : status = nut_Poly_init(tmps + i, f->len);
884 : }
885 5000 : if(status){
886 5000 : status = nut_Poly_roots_modn(f, n, roots, tmps);
887 : }
888 35000 : for(uint64_t i = 0; i < 6; ++i){
889 30000 : if(tmps[i].cap){
890 30000 : nut_Poly_destroy(tmps + i);
891 : }
892 : }
893 5000 : return status;
894 : }
895 :
|