About Docs Source
LCOV - code coverage report
Current view: top level - src/lib/nut - polynomial.c (source / functions) Coverage Total Hit
Test: unnamed Lines: 74.2 % 523 388
Test Date: 2025-10-22 01:14:28 Functions: 86.8 % 38 33

            Line data    Source code
       1              : #include <string.h>
       2              : #include <ctype.h>
       3              : 
       4              : #include <nut/modular_math.h>
       5              : #include <nut/factorization.h>
       6              : #include <nut/polynomial.h>
       7              : 
       8        30438 : bool nut_Poly_init(nut_Poly *f, uint64_t reserve){
       9        30438 :     reserve = reserve ?: 4;
      10        30438 :     f->coeffs = malloc(reserve*sizeof(int64_t));
      11        30438 :     if(!f->coeffs){
      12              :         return false;
      13              :     }
      14        30438 :     f->len = 1;
      15        30438 :     f->coeffs[0] = 0;
      16        30438 :     f->cap = reserve;
      17        30438 :     return true;
      18              : }
      19              : 
      20        30444 : void nut_Poly_destroy(nut_Poly *f){
      21        30444 :     free(f->coeffs);
      22        30444 :     *f = (nut_Poly){};
      23        30444 : }
      24              : 
      25            0 : int nut_Poly_cmp(const nut_Poly *a, const nut_Poly *b){
      26            0 :     uint64_t min_len;
      27            0 :     if(a->len < b->len){
      28            0 :         for(uint64_t i = b->len - 1; i >= a->len; --i){
      29            0 :             if(b->coeffs[i] > 0){
      30              :                 return -1;
      31            0 :             }else if(b->coeffs[i] < 0){
      32              :                 return 1;
      33            0 :             }else if(!i){
      34              :                 return 0;
      35              :             }
      36              :         }
      37            0 :         min_len = a->len;
      38            0 :     }else if(a->len > b->len){
      39            0 :         for(uint64_t i = a->len - 1; i >= b->len; --i){
      40            0 :             if(a->coeffs[i] > 0){
      41              :                 return 1;
      42            0 :             }else if(a->coeffs[i] < 0){
      43              :                 return -1;
      44            0 :             }else if(!i){
      45              :                 return 0;
      46              :             }
      47              :         }
      48            0 :         min_len = b->len;
      49              :     }else{
      50            0 :         min_len = a->len;
      51              :     }
      52            0 :     for(uint64_t i = min_len; i-- > 0;){
      53            0 :         if(a->coeffs[i] < b->coeffs[i]){
      54              :             return -1;
      55            0 :         }else if(a->coeffs[i] < b->coeffs[i]){
      56              :             return 1;
      57              :         }
      58              :     }
      59              :     return 0;
      60              : }
      61              : 
      62            1 : bool nut_Roots_init(nut_Roots *roots, uint64_t reserve){
      63            1 :     reserve = reserve ?: 4;
      64            1 :     roots->roots = malloc(reserve*sizeof(int64_t));
      65            1 :     if(!roots->roots){
      66              :         return false;
      67              :     }
      68            1 :     roots->len = 0;
      69            1 :     roots->cap = reserve;
      70            1 :     return true;
      71              : }
      72              : 
      73            1 : void nut_Roots_destroy(nut_Roots *roots){
      74            1 :     free(roots->roots);
      75            1 :     *roots = (nut_Roots){};
      76            1 : }
      77              : 
      78     25554069 : int64_t nut_Poly_eval_modn(const nut_Poly *f, int64_t x, int64_t n){
      79     25554069 :     if(!x){
      80         5001 :         return f->coeffs[0];
      81              :     }
      82     25549068 :     int64_t r = 0;
      83    255460664 :     for(uint64_t i = f->len; i > 0;){
      84    229911596 :         --i;
      85    229911596 :         r = nut_i64_mod(r*x + f->coeffs[i], n);
      86              :     }
      87              :     return r;
      88              : }
      89              : 
      90       580067 : bool nut_Poly_ensure_cap(nut_Poly *f, uint64_t cap){
      91       580067 :     if(f->cap < cap){
      92        15706 :         int64_t *tmp = realloc(f->coeffs, cap*sizeof(int64_t));
      93        15706 :         if(!tmp){
      94              :             return false;
      95              :         }
      96        15706 :         f->coeffs = tmp;
      97        15706 :         f->cap = cap;
      98              :     }
      99              :     return true;
     100              : }
     101              : 
     102            2 : bool nut_Poly_zero_extend(nut_Poly *f, uint64_t len){
     103            2 :     if(!nut_Poly_ensure_cap(f, len)){
     104              :         return false;
     105              :     }
     106            4 :     for(uint64_t i = f->len; i < len; ++i){
     107            2 :         f->coeffs[i] = false;
     108              :     }
     109            2 :     if(len > f->len){
     110            1 :         f->len = len;
     111              :     }
     112              :     return true;
     113              : }
     114              : 
     115          523 : bool nut_Roots_ensure_cap(nut_Roots *roots, uint64_t cap){
     116          523 :     if(roots->cap < cap){
     117            0 :         int64_t *tmp = realloc(roots->roots, cap*sizeof(int64_t));
     118            0 :         if(!tmp){
     119              :             return false;
     120              :         }
     121            0 :         roots->roots = tmp;
     122            0 :         roots->cap = cap;
     123              :     }
     124              :     return true;
     125              : }
     126              : 
     127         6523 : bool nut_Poly_add_modn(nut_Poly *h, const nut_Poly *f, const nut_Poly *g, int64_t n){
     128         6523 :     if(!nut_Poly_ensure_cap(h, f->len > g->len ? f->len : g->len)){
     129              :         return false;
     130              :     }
     131              :     uint64_t i;
     132        30046 :     for(i = 0; i < f->len && i < g->len; ++i){
     133        23523 :         h->coeffs[i] = nut_i64_mod(f->coeffs[i] + g->coeffs[i], n);
     134              :     }
     135        23909 :     for(; i < f->len; ++i){
     136        17386 :         h->coeffs[i] = f->coeffs[i];
     137              :     }
     138        10659 :     for(; i < g->len; ++i){
     139         4136 :         h->coeffs[i] = g->coeffs[i];
     140              :     }
     141         6523 :     int same_len = f->len == g->len;
     142         6523 :     h->len = f->len > g->len ? f->len : g->len;
     143         6523 :     if(same_len){
     144         1358 :         nut_Poly_normalize(h);
     145              :     }
     146              :     return true;
     147              : }
     148              : 
     149         8001 : bool nut_Poly_sub_modn(nut_Poly *h, const nut_Poly *f, const nut_Poly *g, int64_t n){
     150         8001 :     if(!nut_Poly_ensure_cap(h, f->len > g->len ? f->len : g->len)){
     151              :         return false;
     152              :     }
     153              :     uint64_t i;
     154        36003 :     for(i = 0; i < f->len && i < g->len; ++i){
     155        28002 :         h->coeffs[i] = nut_i64_mod(f->coeffs[i] - g->coeffs[i], n);
     156              :     }
     157        33857 :     for(; i < f->len; ++i){
     158        25856 :         h->coeffs[i] = f->coeffs[i];
     159              :     }
     160        12001 :     for(; i < g->len; ++i){
     161         4000 :         h->coeffs[i] = g->coeffs[i] ? n - g->coeffs[i] : 0;
     162              :     }
     163         8001 :     int same_len = f->len == g->len;
     164         8001 :     h->len = f->len > g->len ? f->len : g->len;
     165         8001 :     if(same_len){
     166         1523 :         nut_Poly_normalize(h);
     167              :     }
     168              :     return true;
     169              : }
     170              : 
     171            0 : bool nut_Poly_dot_modn(nut_Poly *h, const nut_Poly *f, const nut_Poly *g, int64_t n){
     172            0 :     if(!nut_Poly_ensure_cap(h, f->len < g->len ? f->len : g->len)){
     173              :         return false;
     174              :     }
     175              :     uint64_t i;
     176            0 :     for(i = 0; i < f->len && i < g->len; ++i){
     177            0 :         h->coeffs[i] = nut_i64_mod(f->coeffs[i]*g->coeffs[i], n);
     178              :     }
     179            0 :     h->len = f->len < g->len ? f->len : g->len;
     180            0 :     nut_Poly_normalize(h);
     181            0 :     return true;
     182              : }
     183              : 
     184        75522 : bool nut_Poly_copy(nut_Poly *restrict g, const nut_Poly *restrict f){
     185        75522 :     if(!nut_Poly_ensure_cap(g, f->len)){
     186              :         return false;
     187              :     }
     188        75522 :     memcpy(g->coeffs, f->coeffs, f->len*sizeof(int64_t));
     189        75522 :     g->len = f->len;
     190        75522 :     return true;
     191              : }
     192              : 
     193        55941 : bool nut_Poly_setconst(nut_Poly *f, int64_t c){
     194        55941 :     if(!nut_Poly_ensure_cap(f, 1)){
     195              :         return false;
     196              :     }
     197        55941 :     f->coeffs[0] = c;
     198        55941 :     f->len = 1;
     199        55941 :     return true;
     200              : }
     201              : 
     202        22510 : bool nut_Poly_scale_modn(nut_Poly *g, const nut_Poly *f, int64_t a, int64_t n){
     203        22510 :     if(!a){
     204           27 :         return nut_Poly_setconst(g, 0);
     205              :     }
     206        22483 :     if(!nut_Poly_ensure_cap(g, f->len)){
     207              :         return false;
     208              :     }
     209       114536 :     for(uint64_t i = 0; i < f->len; ++i){
     210        92053 :         g->coeffs[i] = nut_i64_mod(a*f->coeffs[i], n);
     211              :     }
     212        22483 :     g->len = f->len;
     213        22483 :     nut_Poly_normalize(g);
     214        22483 :     return true;
     215              : }
     216              : 
     217       130061 : bool nut_Poly_mul_modn(nut_Poly *restrict h, const nut_Poly *f, const nut_Poly *g, int64_t n){
     218       130061 :     if(f->len == 1){
     219         2734 :         return nut_Poly_scale_modn(h, g, f->coeffs[0], n);
     220       127327 :     }else if(g->len == 1){
     221         1000 :         return nut_Poly_scale_modn(h, f, g->coeffs[0], n);
     222              :     }
     223       126327 :     if(!nut_Poly_ensure_cap(h, f->len + g->len - 1)){
     224              :         return false;
     225              :     }
     226      1441366 :     for(uint64_t k = 0; k < f->len + g->len - 1; ++k){
     227      1315039 :         h->coeffs[k] = 0;
     228    258532940 :         for(uint64_t i = k >= g->len ? k - g->len + 1 : 0; i < f->len && i <= k; ++i){
     229    255902862 :             h->coeffs[k] = nut_i64_mod(h->coeffs[k] + f->coeffs[i]*g->coeffs[k - i], n);
     230              :         }
     231              :     }
     232       126327 :     h->len = f->len + g->len - 1;
     233       126327 :     nut_Poly_normalize(h);
     234       126327 :     return true;
     235              : }
     236              : 
     237            0 : bool nut_Poly_pow_modn(nut_Poly *restrict g, const nut_Poly *f, uint64_t e, int64_t n, uint64_t cn, nut_Poly tmps[restrict static 2]){
     238              :     //tmps: st, rt
     239            0 :     if(f->len > cn + 1){
     240              :         return false;
     241            0 :     }else if(!e){
     242            0 :         return nut_Poly_setconst(g, 1); // TODO: deal with 0**0 case
     243            0 :     }else if(e == 1){
     244            0 :         return nut_Poly_copy(g, f);
     245              :     }
     246            0 :     e = 1 + (e - 1)%cn;
     247            0 :     if(f->len == 1){
     248            0 :         return nut_Poly_setconst(g, nut_u64_pow(f->coeffs[0], e));
     249            0 :     }else if(!nut_Poly_copy(tmps + 0, f) || !nut_Poly_ensure_cap(tmps + 0, 2*cn + 1) || !nut_Poly_ensure_cap(tmps + 1, 2*cn + 1)){
     250            0 :         return false;
     251              :     }
     252            0 :     nut_Poly *t = g, *s = tmps + 0, *r = tmps + 1;
     253            0 :     while(e%2 == 0){
     254            0 :         nut_Poly_mul_modn(t, s, s, n);
     255            0 :         nut_Poly_normalize_exps_modn(t, cn);
     256              :         {
     257            0 :             void *tmp = t;
     258            0 :             t = s;
     259            0 :             s = tmp;
     260              :         }//s = s*s
     261            0 :         e >>= 1;
     262              :     }
     263            0 :     if(!nut_Poly_copy(r, s)){
     264              :         return false;
     265              :     }
     266            0 :     while((e >>= 1)){
     267            0 :         nut_Poly_mul_modn(t, s, s, n);
     268            0 :         nut_Poly_normalize_exps_modn(t, cn);
     269              :         {
     270            0 :             void *tmp = t;
     271            0 :             t = s;
     272            0 :             s = tmp;
     273              :         }//s = s*s
     274            0 :         if(e%2){
     275            0 :             nut_Poly_mul_modn(t, r, s, n);
     276            0 :             nut_Poly_normalize_exps_modn(t, cn);
     277              :             {
     278            0 :                 void *tmp = t;
     279            0 :                 t = r;
     280            0 :                 r = tmp;
     281              :             }//r = r*s
     282              :         }
     283              :     }
     284            0 :     if(r != g){
     285            0 :         if(!nut_Poly_copy(g, r)){
     286              :             return 0;
     287              :         }
     288              :     }
     289              :     return true;
     290              : }
     291              : 
     292            0 : bool nut_Poly_pow_modn_tmptmp(nut_Poly *restrict g, const nut_Poly *f, uint64_t e, int64_t n, uint64_t cn){
     293            0 :     nut_Poly tmps[2] = {};
     294            0 :     bool status = true;
     295            0 :     for(uint64_t i = 0; status && i < 2; ++i){
     296            0 :         status = nut_Poly_init(tmps + i, 2*cn + 1);
     297              :     }
     298            0 :     if(status){
     299            0 :         status = nut_Poly_pow_modn(g, f, e, n, cn, tmps);
     300              :     }
     301            0 :     for(uint64_t i = 0; i < 2; ++i){
     302            0 :         if(tmps[i].cap){
     303            0 :             nut_Poly_destroy(tmps + i);
     304              :         }
     305              :     }
     306            0 :     return status;
     307              : }
     308              : 
     309           27 : bool nut_Poly_compose_modn(nut_Poly *restrict h, const nut_Poly *f, const nut_Poly *g, int64_t n, uint64_t cn, nut_Poly tmps[restrict static 2]){
     310           27 :     if(f->len == 1){
     311            0 :         return nut_Poly_setconst(h, f->coeffs[0]);
     312           27 :     }else if(g->len == 1){
     313            0 :         return nut_Poly_setconst(h, nut_Poly_eval_modn(f, g->coeffs[0], n));
     314              :     }
     315           27 :     uint64_t h_len = (f->len - 1)*(g->len - 1) + 1;
     316           27 :     if(h_len > cn + 1){
     317              :         h_len = cn + 1;
     318              :     }
     319           27 :     if(!nut_Poly_ensure_cap(h, h_len) || !nut_Poly_setconst(h, f->coeffs[0])){
     320            0 :         return false;
     321              :     }
     322           27 :     nut_Poly *t = tmps + 0, *p = tmps + 1;
     323           27 :     if(!nut_Poly_copy(p, g) || !nut_Poly_scale_modn(t, g, f->coeffs[1], n) || !nut_Poly_add_modn(h, h, t, n)){
     324            0 :         return false;
     325              :     }
     326           54 :     for(uint64_t e = 2; e <= cn && e < f->len; ++e){
     327           27 :         if(!nut_Poly_mul_modn(t, p, g, n)){
     328              :             return false;
     329              :         }
     330           27 :         nut_Poly_normalize_exps_modn(t, cn);
     331           27 :         void *tmp = t;
     332           27 :         t = p;
     333           27 :         p = tmp;
     334           27 :         if(!nut_Poly_scale_modn(t, p, f->coeffs[e], n) || !nut_Poly_add_modn(h, h, t, n)){
     335            0 :             return false;
     336              :         }
     337              :     }
     338              :     return true;
     339              : }
     340              : 
     341            0 : bool nut_Poly_compose_modn_tmptmp(nut_Poly *restrict h, const nut_Poly *f, const nut_Poly *g, int64_t n, uint64_t cn){
     342            0 :     nut_Poly tmps[2] = {};
     343            0 :     bool status = true;
     344            0 :     for(uint64_t i = 0; status && i < 2; ++i){
     345            0 :         status = nut_Poly_init(tmps + i, 2*cn + 1);
     346              :     }
     347            0 :     if(status){
     348            0 :         status = nut_Poly_compose_modn(h, f, g, n, cn, tmps);
     349              :     }
     350            0 :     for(uint64_t i = 0; i < 2; ++i){
     351            0 :         if(tmps[i].cap){
     352            0 :             nut_Poly_destroy(tmps + i);
     353              :         }
     354              :     }
     355            0 :     return status;
     356              : }
     357              : 
     358       171422 : bool nut_Poly_quotrem_modn(nut_Poly *restrict q, nut_Poly *restrict r, const nut_Poly *restrict f, const nut_Poly *restrict g, int64_t n){
     359       171422 :     int64_t a, d = nut_i64_egcd(g->coeffs[g->len - 1], n, &a, NULL);
     360       171422 :     if(d != 1){
     361              :         return false;//TODO: set divide by zero error, or set remainder (?)
     362              :     }
     363       171422 :     if(g->len == 1){//dividing by a scalar TODO: set divide by zero error if scalar is 0
     364         7946 :         return g->coeffs[0] && nut_Poly_setconst(r, 0) && nut_Poly_scale_modn(q, f, a, n);
     365              :     }
     366       163476 :     if(f->len < g->len){//dividing by a polynomial with higher degree
     367        45303 :         return nut_Poly_setconst(q, 0) && nut_Poly_copy(r, f);
     368              :     }
     369              :     
     370              :     //begin extended synthetic division
     371              :     //compute max length of quotient and remainder and extend their buffers if need be
     372       118173 :     q->len = f->len - g->len + 1;
     373       118173 :     r->len = g->len - 1;
     374       118173 :     if(!nut_Poly_ensure_cap(q, q->len) || !nut_Poly_ensure_cap(r, r->len)){
     375            0 :         return false;
     376              :     }
     377              :     
     378              :     //initialize column sums/coeffs of results
     379       118173 :     memcpy(r->coeffs, f->coeffs, r->len*sizeof(int64_t));
     380       118173 :     memcpy(q->coeffs, f->coeffs + r->len, q->len*sizeof(int64_t));
     381              :     
     382              :     //loop over quotient columns (coefficients in reverse order) which were initialized
     383              :     //to the first q->len dividend coefficients
     384       633308 :     for(uint64_t k = q->len; k > 0;){
     385       515135 :         --k;
     386              :         //finish the column by dividing the sum by the leading coefficient of the divisor
     387              :         //for monic divisors (aka most of the time) a will simply be 1 so we may optimize this
     388       515135 :         q->coeffs[k] = nut_i64_mod(q->coeffs[k]*a, n);
     389              :         //subtract the adjusted column sum times the sum of the divisor coefficients from the
     390              :         //remainder coefficients.  the remainder coefficients were initialized to the last r->len
     391              :         //dividend coefficients.  we start q->len - k columns after the current column we just finished.
     392              :         //if k == 0 we should go from coefficient 0 in the remainder to coefficient r->len - 1
     393              :         //so in general we should go from k to r->len - 1 (this can easily be an empty interval)
     394    127508066 :         for(uint64_t i = 0, j = k; j < r->len; ++i, ++j){
     395    126992931 :             r->coeffs[j] = nut_i64_mod(r->coeffs[j] - q->coeffs[k]*g->coeffs[i], n);
     396              :         }
     397              :         //j goes from max(k - g->len + 1, 0) to k - 1 (both inclusive)
     398              :         //i ends at g->len - 2 (inclusive) and should go through the same number of values
     399              :         //so i starts at g->len - 2 - (k - 1 - j) = g->len + j - k - 1
     400    122472623 :         for(uint64_t j = k > g->len - 1 ? k - g->len + 1 : 0, i = g->len + j - k - 1; i < g->len - 1; ++i, ++j){
     401    121957488 :             q->coeffs[j] = nut_i64_mod(q->coeffs[j] - q->coeffs[k]*g->coeffs[i], n);
     402              :         }
     403              :     }
     404       118173 :     nut_Poly_normalize(r);
     405       118173 :     return true;
     406              : }
     407              : 
     408       291400 : void nut_Poly_normalize(nut_Poly *f){
     409       420198 :     while(f->len > 1 && !f->coeffs[f->len - 1]){
     410       128798 :         --f->len;
     411              :     }
     412       291400 : }
     413              : 
     414           30 : void nut_Poly_normalize_modn(nut_Poly *f, int64_t n, bool use_negatives){
     415           30 :     int64_t offset = use_negatives ? (1-n)/2 : 0;
     416          214 :     for(uint64_t i = 0; i < f->len; ++i){
     417          184 :         f->coeffs[i] = offset + nut_i64_mod(f->coeffs[i] - offset, n);
     418              :     }
     419           30 :     nut_Poly_normalize(f);
     420           30 : }
     421              : 
     422           27 : void nut_Poly_normalize_exps_modn(nut_Poly *f, uint64_t cn){
     423          129 :     for(uint64_t i = f->len - 1; i > cn; --i){
     424          102 :         uint64_t j = 1 + (i - 1)%cn;
     425          102 :         f->coeffs[j] += f->coeffs[i];
     426              :     }
     427           27 :     if(f->len > cn + 1){
     428           23 :         f->len = cn + 1;
     429              :     }
     430           27 :     nut_Poly_normalize(f);
     431           27 : }
     432              : 
     433           25 : int nut_Poly_fprint(FILE *file, const nut_Poly *f, const char *vname, const char *add, const char *sub, const char *pow, bool descending){
     434           25 :     int res = 0;
     435          176 :     for(uint64_t i = 0; i < f->len; ++i){
     436          151 :         uint64_t j = descending ? f->len - 1 - i : i;
     437          151 :         int64_t coeff = f->coeffs[j];
     438          151 :         if(!coeff){
     439           76 :             continue;
     440              :         }
     441           75 :         if(res){
     442           50 :             res += fprintf(file, "%s", coeff > 0 ? add : sub);
     443           50 :             coeff = coeff < 0 ? -coeff : coeff;
     444           50 :             if(coeff != 1 || !j){
     445           40 :                 res += fprintf(file, "%"PRId64, coeff);
     446              :             }
     447              :         }else{
     448           25 :             if(coeff != 1 || !j){
     449            9 :                 if(coeff == -1 && j){
     450            0 :                     res += fprintf(file, "-");
     451              :                 }else{
     452            9 :                     res += fprintf(file, "%"PRId64, coeff);
     453              :                 }
     454              :             }
     455              :         }
     456           49 :         if(j){
     457           54 :             res += fprintf(file, "%s", vname);
     458              :         }
     459           54 :         if(j > 1){
     460           45 :             res += fprintf(file, "%s%"PRIu64, pow, j);
     461              :         }
     462              :     }
     463           25 :     return res;
     464              : }
     465              : 
     466            8 : static inline void skip_whitespace(const char *restrict *restrict _str){
     467           10 :     while(isspace(**_str)){
     468            2 :         ++*_str;
     469              :     }
     470            8 : }
     471              : 
     472            2 : static inline int parse_monomial(nut_Poly *restrict f, const char *restrict *restrict _str){
     473            2 :     const char *str = *_str;
     474            2 :     int64_t sign = 1, coeff = 0;
     475            2 :     uint64_t x = 0;
     476            4 :     while(*str == '+' || *str == '-' || isspace(*str)){
     477            2 :         if(*str == '-'){
     478            0 :             sign = -sign;
     479              :         }
     480            2 :         ++str;
     481              :     }
     482              :     bool need_vpow = true;
     483            3 :     while(isdigit(*str)){
     484            1 :         coeff = 10*coeff + *str - '0';
     485            1 :         need_vpow = false;
     486            1 :         ++str;
     487              :     }
     488            2 :     skip_whitespace(&str);
     489            2 :     if(need_vpow){
     490            1 :         coeff = 1;
     491              :     }
     492            2 :     if(!need_vpow && *str == '*'){
     493            0 :         ++str;
     494            0 :         need_vpow = true;
     495            0 :         skip_whitespace(&str);
     496              :     }
     497            2 :     bool have_vpow = false;
     498            2 :     if(strncmp(str, "mod", 3) || !isspace(str[3])){
     499            2 :         while(isalpha(*str)){
     500            1 :             have_vpow = true;
     501            1 :             ++str;
     502              :         }
     503              :     }
     504            1 :     if(have_vpow){
     505            1 :         skip_whitespace(&str);
     506            1 :         if(*str == '^' || !strncmp(str, "**", 2)){
     507            1 :             str += *str == '^' ? 1 : 2;
     508            1 :             skip_whitespace(&str);
     509            1 :             bool have_pow = false;
     510            2 :             while(isdigit(*str)){
     511            1 :                 x = 10*x + *str - '0';
     512            1 :                 have_pow = true;
     513            1 :                 ++str;
     514              :             }
     515            1 :             if(!have_pow){
     516              :                 return 0;
     517              :             }
     518              :         }else{
     519              :             x = 1;
     520              :         }
     521              :     }
     522            2 :     skip_whitespace(&str);
     523            2 :     if(!coeff){
     524            0 :         *_str = str;
     525            0 :         return 1;
     526              :     }
     527            2 :     if(!nut_Poly_zero_extend(f, x + 1)){
     528              :         return 0;
     529              :     }
     530            2 :     f->coeffs[x] += sign*coeff;
     531            2 :     *_str = str;
     532            2 :     return 1;
     533              : }
     534              : 
     535            1 : int nut_Poly_parse(nut_Poly *restrict f, int64_t *restrict n, const char *restrict str, const char *restrict *restrict end){
     536            1 :     if(!nut_Poly_setconst(f, 0) || !parse_monomial(f, &str)){
     537            0 :         return 0;
     538              :     }
     539            2 :     while(*str == '+' || *str == '-'){
     540            1 :         if(!parse_monomial(f, &str)){
     541            0 :             if(end){
     542            0 :                 skip_whitespace(&str);
     543            0 :                 *end = str;
     544              :             }
     545            0 :             nut_Poly_normalize(f);
     546            0 :             return 1;
     547              :         }
     548              :     }
     549            1 :     if(!strncmp(str, "mod", 3) && isspace(str[3])){
     550            1 :         str += 4;
     551            1 :         skip_whitespace(&str);
     552            1 :         int64_t _n = 0;
     553            1 :         bool have_mod = false;
     554            2 :         while(isdigit(*str)){
     555            1 :             _n = 10*_n + *str - '0';
     556            1 :             have_mod = true;
     557            1 :             ++str;
     558              :         }
     559            1 :         if(have_mod){
     560            1 :             *n = _n;
     561            1 :             if(end){
     562            1 :                 skip_whitespace(&str);
     563            1 :                 *end = str;
     564              :             }
     565            1 :             nut_Poly_normalize_modn(f, *n, 0);
     566            1 :             return 2;
     567              :         }else{
     568            0 :             str -= 4;
     569              :         }
     570              :     }
     571            0 :     if(end){
     572            0 :         skip_whitespace(&str);
     573            0 :         *end = str;
     574              :     }
     575            0 :     nut_Poly_normalize(f);
     576            0 :     return 1;
     577              : }
     578              : 
     579        21479 : bool nut_Poly_rand_modn(nut_Poly *f, uint64_t max_len, int64_t n){
     580        21479 :     if(!max_len){
     581            0 :         return nut_Poly_setconst(f, 0);
     582              :     }
     583        21479 :     if(!nut_Poly_ensure_cap(f, max_len)){
     584              :         return false;
     585              :     }
     586       172207 :     for(uint64_t i = 0; i < max_len; ++i){
     587       150728 :         f->coeffs[i] = nut_u64_rand(0, n);
     588              :     }
     589        21479 :     f->len = max_len;
     590        21479 :     nut_Poly_normalize(f);
     591        21479 :     return true;
     592              : }
     593              : 
     594         9950 : bool nut_Poly_gcd_modn(nut_Poly *restrict d, const nut_Poly *restrict f, const nut_Poly *restrict g, int64_t n, nut_Poly tmps[restrict static 3]){
     595              :     //tmps: qt, r0t, r1t
     596         9950 :     nut_Poly *tmp, *r0, *r1, *r2;
     597         9950 :     bool status = true;
     598         9950 :     if(g->len > f->len){//Ensure the degree of f is >= the degree of g
     599         9949 :         const nut_Poly *tmp = f;//Shadow the other tmp with a const version
     600         9949 :         f = g;
     601         9949 :         g = tmp;
     602              :     }
     603         9950 :     if(g->len == 1){//If one of the inputs is a constant, we either have the gcd of something and a unit or something and zero
     604          854 :         status = g->coeffs[0] ? nut_Poly_setconst(d, 1) : nut_Poly_copy(d, f);
     605          854 :         goto CLEANUP;
     606              :     }
     607              :     
     608         9096 :     r0 = d, r1 = tmps + 1, r2 = tmps + 2;
     609              :     //Unroll first two remainder calculations to prevent copying
     610         9096 :     if(!nut_Poly_quotrem_modn(tmps + 0, r0, f, g, n)){
     611            0 :         status = false;
     612            0 :         goto CLEANUP;
     613              :     }
     614         9096 :     if(r0->len == 1 && !r0->coeffs[0]){
     615          564 :         status = nut_Poly_copy(d, g);
     616          564 :         goto CLEANUP;
     617              :     }
     618              :     
     619         8532 :     if(!nut_Poly_quotrem_modn(tmps + 0, r1, g, r0, n)){
     620            0 :         status = false;
     621            0 :         goto CLEANUP;
     622              :     }
     623         8532 :     if(r1->len == 1 && !r1->coeffs[0]){
     624         1558 :         goto CLEANUP;// r0 == d in this case, no need to copy
     625              :     }
     626              :     
     627              :     //Euclidean algorithm: take remainders until we reach 0, then the last nonzero remainder is the gcd
     628        22612 :     while(1){
     629        22612 :         if(!nut_Poly_quotrem_modn(tmps + 0, r2, r0, r1, n)){
     630            0 :             status = false;
     631            0 :             goto CLEANUP;
     632              :         }
     633        22612 :         if(r2->len == 1 && !r2->coeffs[0]){
     634         6974 :             if(r1 != d){//Make sure the result is in the output and not a temporary
     635         5692 :                 status = nut_Poly_copy(d, r1);
     636              :             }
     637         5692 :             goto CLEANUP;
     638              :         }
     639              :         tmp = r0;
     640              :         r0 = r1;
     641              :         r1 = r2;
     642              :         r2 = tmp;
     643              :     }
     644              :     
     645         9950 :     CLEANUP:;
     646         9950 :     if(status){//If the gcd was found, make it monic
     647         9950 :         int64_t a = d->coeffs[d->len - 1];
     648         9950 :         if(a == n - 1){//Inverse of -1 is always -1
     649           10 :             nut_Poly_scale_modn(d, d, a, n);
     650         9940 :         }else if(a > 1){//If the leading coefficient is 0 or 1 the gcd is already monic
     651         7766 :             int64_t c, g = nut_i64_egcd(a, n, &c, NULL);//This g shadows the input
     652         7766 :             if(g != 1){//The leading coefficient cannot be inverted mod n
     653            0 :                 return false;
     654              :             }
     655         7766 :             nut_Poly_scale_modn(d, d, c, n);
     656              :         }
     657              :     }
     658              :     return status;
     659              : }
     660              : 
     661            2 : bool nut_Poly_gcd_modn_tmptmp(nut_Poly *restrict d, const nut_Poly *restrict f, const nut_Poly *restrict g, int64_t n){
     662            2 :     nut_Poly tmps[3] = {};
     663            2 :     uint64_t min_len, quot_len;
     664            2 :     if(f->len <= g->len){
     665            1 :         min_len = f->len;
     666            1 :         quot_len = g->len - f->len + 1;
     667              :     }else{
     668            1 :         min_len = g->len;
     669            1 :         quot_len = f->len - g->len + 1;
     670              :     }
     671            4 :     bool status = nut_Poly_init(tmps + 0, quot_len) &&
     672            4 :              nut_Poly_init(tmps + 1, min_len) &&
     673            6 :              nut_Poly_init(tmps + 2, min_len) &&
     674            2 :              nut_Poly_gcd_modn(d, f, g, n, tmps);
     675            8 :     for(uint64_t i = 0; i < 3; ++i){
     676            6 :         if(tmps[i].cap){
     677            6 :             nut_Poly_destroy(tmps + i);
     678              :         }
     679              :     }
     680            2 :     return status;
     681              : }
     682              : 
     683         7471 : bool nut_Poly_powmod_modn(nut_Poly *restrict h, const nut_Poly *restrict f, uint64_t e, const nut_Poly *restrict g, int64_t n, nut_Poly tmps[restrict static 3]){
     684              :     //tmps: qt, st, rt
     685         7471 :     if(g->len <= f->len){
     686            0 :         if(!nut_Poly_quotrem_modn(tmps + 0, tmps + 1, f, g, n)){
     687              :             return false;
     688              :         }
     689         7471 :     }else if(!nut_Poly_copy(tmps + 1, f)){
     690              :         return false;
     691              :     }
     692         7471 :     if(tmps[1].len == 1){
     693            0 :         if(!tmps[1].coeffs[0]){
     694            0 :             return nut_Poly_setconst(h, 0);
     695            0 :         }else if(!e){
     696            0 :             return nut_Poly_setconst(h, 1);
     697              :         }
     698              :     }
     699         7471 :     if(!nut_Poly_ensure_cap(tmps + 1, 2*g->len - 3) || !nut_Poly_ensure_cap(tmps + 2, 2*g->len - 3) || !nut_Poly_ensure_cap(h, 2*g->len - 3)){
     700            0 :         return false;
     701              :     }
     702              :     
     703              :     nut_Poly *t = h, *s = tmps + 1, *r = tmps + 2;
     704         9850 :     while(e%2 == 0){
     705         2379 :         nut_Poly_mul_modn(t, s, s, n);
     706         2379 :         if(!nut_Poly_quotrem_modn(tmps + 0, s, t, g, n)){
     707              :             return false;
     708              :         }//s = s*s%g
     709         2379 :         e >>= 1;
     710              :     }
     711         7471 :     if(!nut_Poly_copy(r, s)){
     712              :         return false;
     713              :     }
     714        88095 :     while((e >>= 1)){
     715        80624 :         nut_Poly_mul_modn(t, s, s, n);
     716        80624 :         if(!nut_Poly_quotrem_modn(tmps + 0, s, t, g, n)){
     717              :             return false;
     718              :         }//s = s*s%g
     719        80624 :         if(e%2){
     720        43031 :             nut_Poly_mul_modn(t, r, s, n);
     721        43031 :             if(!nut_Poly_quotrem_modn(tmps + 0, r, t, g, n)){
     722              :                 return false;
     723              :             }//r = r*s%g
     724              :         }
     725              :     }
     726              :     
     727         7471 :     if(r != h){// TODO: this check is always true, fix buffer juggling so we never need to copy
     728         7471 :         if(!nut_Poly_copy(h, r)){
     729              :             return false;
     730              :         }
     731              :     }
     732              :     return true;
     733              : }
     734              : 
     735            2 : bool nut_Poly_powmod_modn_tmptmp(nut_Poly *restrict h, const nut_Poly *restrict f, uint64_t e, const nut_Poly *restrict g, int64_t n){
     736            2 :     nut_Poly tmps[3] = {};
     737            2 :     bool status = true;
     738            8 :     for(uint64_t i = 0; status && i < 3; ++i){
     739            6 :         status = nut_Poly_init(tmps + i, 2*g->len - 1);
     740              :     }
     741            2 :     if(status){
     742            2 :         status = nut_Poly_powmod_modn(h, f, e, g, n, tmps);
     743              :     }
     744            8 :     for(uint64_t i = 0; i < 3; ++i){
     745            6 :         if(tmps[i].cap){
     746            6 :             nut_Poly_destroy(tmps + i);
     747              :         }
     748              :     }
     749            2 :     return status;
     750              : }
     751              : 
     752         5000 : bool nut_Poly_factors_d_modn(nut_Poly *restrict f_d, const nut_Poly *restrict f, uint64_t d, int64_t n, nut_Poly tmps[restrict static 4]){
     753              :     //tmps: xt, qt, st, rt
     754         5000 :     if(!nut_Poly_ensure_cap(f_d, f->len)){
     755              :         return false;
     756              :     }
     757         5000 :     f_d->coeffs[0] = 0;
     758         5000 :     f_d->coeffs[1] = 1;
     759         5000 :     f_d->len = 2;
     760        10000 :     return nut_Poly_powmod_modn(tmps + 0, f_d, nut_u64_pow(n, d), f, n, tmps + 1) &&
     761        10000 :         nut_Poly_sub_modn(tmps + 0, tmps + 0, f_d, n) &&
     762         5000 :         nut_Poly_gcd_modn(f_d, tmps + 0, f, n, tmps + 1);
     763              : }
     764              : 
     765         2148 : bool nut_Poly_factor1_modn(nut_Poly *restrict g, const nut_Poly *restrict f, uint64_t d, int64_t n, nut_Poly tmps[restrict static 4]){
     766              :     //tmps: xt, qt, st, rt
     767         2479 :     while(1){
     768         2479 :         if(!nut_Poly_rand_modn(tmps + 0, f->len - 1, n) || !nut_Poly_gcd_modn(g, tmps + 0, f, n, tmps + 1)){
     769            0 :             return false;
     770              :         }
     771         2479 :         if(g->len > 1 && g->len < f->len){
     772              :             return true;
     773              :         }
     774         2469 :         if(
     775         4938 :             !nut_Poly_powmod_modn(g, tmps + 0, (nut_u64_pow(n, d)-1)/2, f, n, tmps + 1) ||
     776         4938 :             !nut_Poly_setconst(tmps + 0, 1) ||
     777         4938 :             !nut_Poly_add_modn(tmps + 0, g, tmps + 0, n) ||
     778         2469 :             !nut_Poly_gcd_modn(g, tmps + 0, f, n, tmps + 1)
     779              :         ){
     780            0 :             return false;
     781              :         }
     782         2469 :         if(g->len > 1 && g->len < f->len){
     783              :             return true;
     784              :         }
     785              :     }
     786              : }
     787              : 
     788              : [[gnu::nonnull(1, 3)]]
     789              : NUT_ATTR_ACCESS(read_write, 1)
     790         2671 : static inline bool roots_polyn_modn_rec(nut_Roots *restrict roots, int64_t n, nut_Poly tmps[restrict static 6]){
     791              :     //tmps: gt, ft, xt, qt, st, rt
     792         2148 :     while(1){
     793              :         //fprintf(stderr, "\e[1;33mFactoring (");
     794              :         //nut_Poly_fprint(stderr, tmps + 1, "x", " + ", " - ", "**", 0);
     795              :         //fprintf(stderr, ") mod %"PRId64"\e[0m\n", n);
     796         4819 :         if(tmps[1].len == 2){//linear factor
     797              :             //fprintf(stderr, "\e[1;33m polynomial is linear\e[0m\n");
     798         1158 :             if(tmps[1].coeffs[0]){//nonzero root
     799         1158 :                 if(tmps[1].coeffs[1] == 1){//monic
     800         1158 :                     roots->roots[roots->len++] = n - tmps[1].coeffs[0];
     801              :                 }else{//not monic
     802            0 :                     int64_t a;
     803            0 :                     nut_i64_egcd(tmps[1].coeffs[0], n, &a, NULL);
     804            0 :                     roots->roots[roots->len++] = nut_i64_mod(-tmps[1].coeffs[0]*a, n);
     805              :                 }
     806              :             }else{//zero root
     807            0 :                 roots->roots[roots->len++] = 0;
     808              :             }
     809         2671 :             return true;
     810         3661 :         }else if(tmps[1].len == 3){//quadratic factor
     811              :             //fprintf(stderr, "\e[1;33m polynomial is quadratic\e[0m\n");
     812         1513 :             if(tmps[1].coeffs[2] != 1){//make monic
     813            0 :                 int64_t a;
     814            0 :                 nut_i64_egcd(tmps[1].coeffs[2], n, &a, NULL);
     815            0 :                 nut_Poly_scale_modn(tmps + 1, tmps + 1, a, n);
     816              :             }
     817         1513 :             int64_t c = tmps[1].coeffs[0];
     818         1513 :             int64_t b = tmps[1].coeffs[1];
     819         1513 :             int64_t r = nut_i64_sqrt_mod(nut_i64_mod(b*b - 4*c, n), n);
     820         1513 :             roots->roots[roots->len++] = nut_i64_mod((n+1)/2*(-b + r), n);
     821         1513 :             roots->roots[roots->len++] = nut_i64_mod((n+1)/2*(-b - r), n);
     822         1513 :             return true;
     823              :         }
     824         2148 :         if(!nut_Poly_factor1_modn(tmps + 0, tmps + 1, 1, n, tmps + 2) || !nut_Poly_quotrem_modn(tmps + 3, tmps + 5, tmps + 1, tmps + 0, n)){
     825            0 :             return false;
     826              :         }
     827              :         //tmps[0] and tmps[3] hold the nontrivial factors we found
     828              :         //we have to do a recursive call on the factor with smaller degree
     829              :         //if it is linear or quadratic, we do not need to copy the larger factor
     830              :         //otherwise we do
     831         2148 :         nut_Poly tmp;
     832         2148 :         tmp = tmps[0];
     833         2148 :         tmps[0] = tmps[1];
     834         2148 :         if(tmp.len <= tmps[3].len){
     835         1289 :             tmps[1] = tmp;
     836              :         }else{
     837          859 :             tmps[1] = tmps[3];
     838          859 :             tmps[3] = tmp;
     839              :         }
     840              :         //now, tmps[1] and tmps[3] hold the nontrivial factors, with tmps[1] having smallest degree
     841              :         //fprintf(stderr, "\e[1;33m got factors (");
     842              :         //nut_Poly_fprint(stderr, tmps + 1, "x", " + ", " - ", "**", 0);
     843              :         //fprintf(stderr, ")(");
     844              :         //nut_Poly_fprint(stderr, tmps + 3, "x", " + ", " - ", "**", 0);
     845              :         //fprintf(stderr, ")\e[0m\n");
     846         2148 :         bool have_small_factor = tmps[1].len <= 3;
     847         2148 :         if(!have_small_factor){
     848          417 :             if(!nut_Poly_init(&tmp, tmps[3].len) || !nut_Poly_copy(&tmp, tmps + 3)){
     849            0 :                 return false;
     850              :             }
     851              :         }
     852         2148 :         if(!roots_polyn_modn_rec(roots, n, tmps)){
     853              :             return false;
     854              :         }
     855         2148 :         if(!have_small_factor){
     856          417 :             if(!nut_Poly_copy(tmps + 3, &tmp)){
     857              :                 return false;
     858              :             }
     859          417 :             nut_Poly_destroy(&tmp);
     860              :         }
     861         2148 :         tmp = tmps[3];
     862         2148 :         tmps[3] = tmps[1];
     863         2148 :         tmps[1] = tmp;
     864              :     }
     865              : }
     866              : 
     867         5000 : bool nut_Poly_roots_modn(const nut_Poly *restrict f, int64_t n, nut_Roots *restrict roots, nut_Poly tmps[restrict static 6]){
     868              :     //tmps: gt, ft, xt, qt, st, rt
     869         5000 :     if(!nut_Poly_factors_d_modn(tmps + 1, f, 1, n, tmps + 2)){
     870              :         return false;
     871              :     }
     872         5000 :     roots->len = 0;
     873         5000 :     if(tmps[1].len == 1){
     874              :         return true;
     875              :     }
     876          523 :     return nut_Roots_ensure_cap(roots, tmps[1].len - 1) && roots_polyn_modn_rec(roots, n, tmps);
     877              : }
     878              : 
     879         5000 : bool nut_Poly_roots_modn_tmptmp(const nut_Poly *restrict f, int64_t n, nut_Roots *restrict roots){
     880         5000 :     nut_Poly tmps[6] = {};
     881         5000 :     bool status = true;
     882        35000 :     for(uint64_t i = 0; status && i < 6; ++i){
     883        30000 :         status = nut_Poly_init(tmps + i, f->len);
     884              :     }
     885         5000 :     if(status){
     886         5000 :         status = nut_Poly_roots_modn(f, n, roots, tmps);
     887              :     }
     888        35000 :     for(uint64_t i = 0; i < 6; ++i){
     889        30000 :         if(tmps[i].cap){
     890        30000 :             nut_Poly_destroy(tmps + i);
     891              :         }
     892              :     }
     893         5000 :     return status;
     894              : }
     895              : 
        

Generated by: LCOV version 2.0-1