About Docs Source
LCOV - code coverage report
Current view: top level - src/lib/nut - modular_math.c (source / functions) Coverage Total Hit
Test: unnamed Lines: 76.0 % 233 177
Test Date: 2025-10-22 01:14:28 Functions: 59.4 % 32 19

            Line data    Source code
       1              : #include <stddef.h>
       2              : #if __has_include(<linux/version.h>)
       3              : #include <linux/version.h>
       4              : #if LINUX_VERSION_CODE >= KERNEL_VERSION(3,17,0)
       5              : #if __has_include(<gnu/libc-version.h>)
       6              : #include <gnu/libc-version.h>
       7              : #include <sys/random.h>
       8              : #if (__GLIBC__ == 2 && __GLIBC_MINOR__ >= 25) || __GLIBC__ > 2
       9              : #define _NUT_RAND_USE_GETRANDOM
      10              : #else
      11              : #include <sys/syscall.h>
      12              : #include <unistd.h>
      13              : #define _NUT_RAND_USE_SYSCALL
      14              : #endif
      15              : #else
      16              : #include <stdio.h>
      17              : #define _NUT_RAND_USE_URANDOM
      18              : #endif
      19              : #endif
      20              : #elifdef __MINGW32__
      21              : #define _CRT_RAND_S
      22              : #include <stdlib.h>
      23              : #define _NUT_RAND_USE_RAND_S
      24              : #endif
      25              : #include <string.h>
      26              : #include <nut/debug.h>
      27              : #include <nut/modular_math.h>
      28              : 
      29    133369638 : uint64_t nut_u64_pow(uint64_t b, uint64_t e){
      30    133369638 :     if(!e){
      31              :         return 1;
      32              :     }
      33              :     uint64_t r = 1;
      34    723877974 :     while(1){
      35    426347146 :         if(e&1){
      36    305971482 :             r = r*b;
      37              :         }
      38    426347146 :         if(!(e >>= 1)){
      39              :             return r;
      40              :         }
      41    297530828 :         b *= b;
      42              :     }
      43              : }
      44              : 
      45    592251918 : uint128_t nut_u128_pow(uint128_t b, uint64_t e){
      46    592251918 :     if(!e){
      47              :         return 1;
      48              :     }
      49              :     uint128_t r = 1;
      50   3496096260 :     while(1){
      51   2044174089 :         if(e&1){
      52   1142281367 :             r = r*b;
      53              :         }
      54   2044174089 :         if(!(e >>= 1)){
      55              :             return r;
      56              :         }
      57   1451922171 :         b *= b;
      58              :     }
      59              : }
      60              : 
      61    134302355 : uint64_t nut_u64_powmod(uint64_t b, uint64_t e, uint64_t n){
      62    134302355 :     if(!e){
      63              :         return 1;
      64              :     }
      65    134302355 :     uint64_t r = 1;
      66    134302355 :     b %= n;
      67   2227680669 :     while(1){
      68   1180991512 :         if(e&1){
      69    659613646 :             r = (uint128_t)r*b%n;
      70              :         }
      71   1180991512 :         if(!(e >>= 1)){
      72              :             return r;
      73              :         }
      74   1046689157 :         b = (uint128_t)b*b%n;
      75              :     }
      76              : }
      77              : 
      78      8573706 : uint64_t nut_u64_binom(uint64_t n, uint64_t k){
      79      8573706 :     uint64_t res = 1;
      80      8573706 :     if(n - k < k){
      81              :         k = n - k;
      82              :     }
      83     18523448 :     for(uint64_t i = 0; i < k; ++i){
      84      9949742 :         res = res*(n - i)/(1 + i);
      85              :     }
      86      8573706 :     return res;
      87              : }
      88              : 
      89         1891 : uint64_t nut_u64_binom_next(uint64_t n, uint64_t k, uint64_t prev){
      90         1891 :     return prev*(n - k + 1)/k;
      91              : }
      92              : 
      93              : #if defined(_NUT_RAND_USE_GETRANDOM) || defined(_NUT_RAND_USE_SYSCALL) || defined(_NUT_RAND_USE_URANDOM)
      94              : #ifdef _NUT_RAND_USE_URANDOM
      95              : __thread FILE *nut_urandom = NULL;
      96              : #endif
      97      1519156 : uint64_t nut_u64_rand(uint64_t a, uint64_t b){
      98              : #ifdef _NUT_RAND_USE_URANDOM
      99              :     if(!nut_urandom){
     100              :         nut_urandom = fopen("/dev/urandom", "rb");
     101              :     }
     102              : #endif
     103      1519156 :     uint64_t l = b - a, r = 0, bytes = (71 - __builtin_clzll(l))/8;
     104      1519156 :     uint64_t ub;
     105      1519156 :     if(bytes == 8){
     106         1000 :         ub = ~0ull%l + 1;
     107      1519156 :         ub = (ub == l) ? 0 : -ub;
     108              :     }else{
     109      1518156 :         ub = 1ull << (bytes*8);
     110      1518156 :         ub -= ub%l;
     111              :     }
     112      1541571 :     do{
     113              : #pragma GCC diagnostic push
     114              : #pragma GCC diagnostic ignored "-Wunused-result"
     115              : #ifdef _NUT_RAND_USE_GETRANDOM
     116      1541571 :         getrandom(&r, bytes, 0);
     117              : #elifdef _NUT_RAND_USE_SYSCALL
     118              :         syscall(SYS_getrandom, &r, bytes, 0);
     119              : #else
     120              :         fread(&r, 8, 1, nut_urandom);
     121              : #endif
     122              : #pragma GCC diagnostic pop
     123      1541571 :     }while(ub && r >= ub);
     124      1519156 :     return r%l + a;
     125              : }
     126              : #endif
     127              : 
     128              : #if defined(_NUT_RAND_USE_RAND_S)
     129              : uint64_t nut_u64_rand(uint64_t a, uint64_t b){
     130              :     uint64_t l = b - a, r = 0;
     131              :     uint64_t ub;
     132              :     if(l > (1ull << 32)){
     133              :         ub = ~0ull%l + 1;
     134              :         ub = (ub == l) ? 0 : -ub;
     135              :         do{
     136              :             rand_s((uint32_t*)&r);
     137              :             rand_s((uint32_t*)&r + 1);
     138              :         }while(ub && r >= ub);
     139              :     }else{
     140              :         ub = 1ull << 32;
     141              :         ub -= ub%l;
     142              :         do{
     143              :             rand_s((uint32_t*)&r);
     144              :         }while(ub && r >= ub);
     145              :     }
     146              :     return r%l + a;
     147              : }
     148              : #endif
     149              : 
     150       114111 : uint64_t nut_u64_prand(uint64_t a, uint64_t b){
     151       114111 :     return nut_u64_rand(a, b);//TODO: test if this is a bottleneck, we don't need calls to this function to be secure
     152              : }
     153              : 
     154      5098048 : int64_t nut_i64_egcd(int64_t a, int64_t b, int64_t *restrict _t, int64_t *restrict _s){
     155      5098048 :     int64_t r0 = b, r1 = a;
     156      5098048 :     int64_t s0 = 1, s1 = 0;
     157      5098048 :     int64_t t0 = 0, t1 = 1;
     158     25325014 :     while(r1){
     159     20226966 :         int64_t q = r0/r1, t;
     160     20226966 :         t = r1;
     161     20226966 :         r1 = r0 - q*r1;
     162     20226966 :         r0 = t;
     163     20226966 :         t = s1;
     164     20226966 :         s1 = s0 - q*s1;
     165     20226966 :         s0 = t;
     166     20226966 :         t = t1;
     167     20226966 :         t1 = t0 - q*t1;
     168     20226966 :         t0 = t;
     169              :     }
     170      5098048 :     if(_t){
     171       805221 :         *_t = t0;
     172              :     }
     173      5098048 :     if(_s){
     174           15 :         *_s = s0;
     175              :     }
     176      5098048 :     return r0;
     177              : }
     178              : 
     179           17 : int64_t nut_i64_modinv(int64_t a, int64_t b){
     180           17 :     int64_t ainv;
     181           17 :     nut_i64_egcd(a, b, &ainv, NULL);
     182           17 :     return ainv < 0 ? ainv + b : ainv;
     183              : }
     184              : 
     185              : // This uses a hensel/newton like iterative algorithm, described here
     186              : // https://crypto.stackexchange.com/a/47496
     187              : // Basically, we use a lookup table to get the inverse mod 2**8, and then
     188              : // use the fact that ax = 1 mod 2**k --> ax(2-ax) = 1 mod 2**(2k) to lift the inverse to
     189              : // mod 2**16, mode 2**32, etc as needed.  This does cap out at 2**64.
     190              : NUT_ATTR_NO_SAN("unsigned-shift-base")
     191              : NUT_ATTR_NO_SAN("unsigned-integer-overflow")
     192      1985500 : uint64_t nut_u64_modinv_2t(uint64_t a, uint64_t t){
     193      1985500 :     static const uint8_t modinv_256_tbl[] = {
     194              :         1, 171, 205, 183, 57, 163, 197, 239, 241, 27, 61, 167, 41, 19, 53, 223,
     195              :         225, 139, 173, 151, 25, 131, 165, 207, 209, 251, 29, 135, 9, 243, 21, 191,
     196              :         193, 107, 141, 119, 249, 99, 133, 175, 177, 219, 253, 103, 233, 211, 245, 159,
     197              :         161, 75, 109, 87, 217, 67, 101, 143, 145, 187, 221, 71, 201, 179, 213, 127,
     198              :         129, 43, 77, 55, 185, 35, 69, 111, 113, 155, 189, 39, 169, 147, 181, 95,
     199              :         97, 11, 45, 23, 153, 3, 37, 79, 81, 123, 157, 7, 137, 115, 149, 63,
     200              :         65, 235, 13, 247, 121, 227, 5, 47, 49, 91, 125, 231, 105, 83, 117, 31,
     201              :         33, 203, 237, 215, 89, 195, 229, 15, 17, 59, 93, 199, 73, 51, 85, 255
     202              :     };
     203      1985500 :     assert(t <= 64 && a&1);
     204      1985500 :     uint64_t t1 = 8;
     205      1985500 :     uint64_t x = modinv_256_tbl[(a >> 1)&0x7F];
     206      6269500 :     while(t1 < t){
     207      4284000 :         t1 *= 2;
     208      4284000 :         x = (x*(2 - a*x));
     209      4284000 :         if(t1 != 64){
     210      3276000 :             x &= (1ull << t1) - 1;
     211              :         }
     212              :     }
     213      1985500 :     return t < 64 ? x&((1ull << t) - 1) : x;
     214              : }
     215              : 
     216    738231103 : int64_t nut_i64_mod(int64_t a, int64_t n){
     217    738231103 :     int64_t r = a%n;
     218    738231103 :     if(r < 0){
     219      1857924 :         r += n;
     220              :     }
     221    738231103 :     return r;
     222              : }
     223              : 
     224            0 : int64_t nut_i64_crt(int64_t a, int64_t p, int64_t b, int64_t q){
     225            0 :     int64_t x, y;
     226            0 :     nut_i64_egcd(p, q, &x, &y);
     227            0 :     return nut_i64_mod(b*p%(p*q)*x + a*q%(p*q)*y, p*q);
     228              : }
     229              : 
     230           15 : int128_t nut_i128_crt(int64_t a, int64_t p, int64_t b, int64_t q){
     231           15 :     int64_t _x, _y;
     232           15 :     nut_i64_egcd(p, q, &_x, &_y);
     233           15 :     int128_t x = _x, y = _y;
     234           15 :     x = b*p%(p*q)*x + a*q%(p*q)*y;
     235           15 :     x %= p*q;
     236           15 :     return x < 0 ? x + p*q : x;
     237              : }
     238              : 
     239      4226730 : int64_t nut_i64_lcm(int64_t a, int64_t b){
     240      4226730 :     return a*b/nut_i64_egcd(a, b, NULL, NULL);
     241              : }
     242              : 
     243              : NUT_ATTR_NO_SAN("unsigned-shift-base")
     244              : NUT_ATTR_NO_SAN("unsigned-integer-overflow")
     245      1921500 : uint64_t nut_u64_binom_next_mod_2t(uint64_t n, uint64_t k, uint64_t t, uint64_t *restrict v2, uint64_t *restrict p2){
     246      1921500 :     uint64_t num = n - k + 1;
     247      1921500 :     uint64_t num_v2 = __builtin_ctz(num);
     248      1921500 :     uint64_t denom_v2 = __builtin_ctz(k);
     249      1921500 :     *v2 = *v2 + num_v2 - denom_v2;
     250      1921500 :     uint64_t denom_pinv = nut_u64_modinv_2t(k >> denom_v2, t);
     251      1921500 :     uint64_t mask = t < 64 ? ((1ull << t) - 1) : ~0ull;
     252      1921500 :     *p2 = (*p2 * (num >> num_v2) * denom_pinv) & mask;
     253      1921500 :     return (*p2 << *v2) & mask;
     254              : }
     255              : 
     256      1244317 : int64_t nut_i64_jacobi(int64_t n, int64_t k){
     257      1244317 :     if(n%k == 0){
     258              :         return 0;
     259              :     }
     260              :     int64_t j = 1;
     261     10019139 :     while(1){
     262      5631728 :         int64_t s = __builtin_ctzll(n);
     263      5631728 :         int64_t q = n >> s;
     264      5631728 :         if((s&1) && ((k&7) == 3 || ((k&7) == 5))){
     265      1371410 :             j = -j;
     266              :         }
     267      5631728 :         if(q == 1){
     268              :             return j;
     269      4387411 :         }else if(q == k - 1){
     270            0 :             return (k&3)==1 ? j : -j;
     271      4387411 :         }else if((q&2) && (k&2)){
     272      1305965 :             j = -j;
     273              :         }
     274      4387411 :         n = k%q;
     275      4387411 :         k = q;
     276              :     }
     277              : }
     278              : 
     279            0 : uint64_t *nut_u64_make_jacobi_tbl(uint64_t p, int64_t partial_sums[restrict static p]){
     280            0 :     uint64_t *is_qr = calloc((p + 63)/64, sizeof(uint64_t));
     281            0 :     if(!is_qr){
     282              :         return NULL;
     283              :     }
     284            0 :     for(uint64_t n = 1, nn = 1; n <= p/2; ++n){
     285            0 :         is_qr[nn/64] |= 1ull << (nn%64);
     286            0 :         nn += 2*n + 1;
     287            0 :         if(nn >= p){
     288            0 :             nn -= p;
     289              :         }
     290              :     }
     291            0 :     partial_sums[0] = 0;
     292            0 :     int64_t acc = 0;
     293            0 :     for(uint64_t n = 1; n < p; ++n){
     294            0 :         bool qr = is_qr[n/64] & (1ull << (n%64));
     295            0 :         acc += qr ? 1 : -1;
     296            0 :         partial_sums[n] = acc;
     297              :     }
     298              :     return is_qr;
     299              : }
     300              : 
     301            0 : int64_t nut_u64_jacobi_tbl_get(uint64_t n, uint64_t p, const uint64_t is_qr[restrict static (p + 63)/64]){
     302            0 :     uint64_t r = n%p;
     303            0 :     if(!r){
     304              :         return 0;
     305              :     }
     306            0 :     return (is_qr[r/64] & (1ull << (r%64))) ? 1 : -1;
     307              : }
     308              : 
     309       122284 : int64_t nut_i64_rand_nr(int64_t p){
     310       244303 :     while(1){
     311       244303 :         int64_t z = nut_u64_rand(2, p);
     312       244303 :         if(nut_i64_jacobi(z, p) == -1){
     313       122284 :             return z;
     314              :         }
     315              :     }
     316              : }
     317              : 
     318       122284 : int64_t nut_i64_sqrt_shanks(int64_t n, int64_t p){
     319       122284 :     int64_t s = __builtin_ctzll(p-1);
     320       122284 :     int64_t q = p >> s;//p-1 = q*2^s
     321       122284 :     int64_t z = nut_i64_rand_nr(p);
     322              :     //printf("trying \"nonresidue\" %"PRIu64"\n", z);
     323       122284 :     int64_t m = s;
     324       122284 :     int64_t c = nut_u64_powmod(z, q, p);
     325       122284 :     int64_t t = nut_u64_powmod(n, q, p);
     326       122284 :     int64_t r = nut_u64_powmod(n, (q + 1) >> 1, p);
     327       304835 :     while(t != 1){
     328       182551 :         int64_t i = 1;
     329       424352 :         for(int64_t s = (int128_t)t*t%p; s != 1; s = (int128_t)s*s%p, ++i);
     330              :         int64_t b = c;
     331       264224 :         for(int64_t j = 0; j < m - i - 1; ++j){
     332        81673 :             b = (int128_t)b*b%p;
     333              :         }
     334       182551 :         m = i;
     335       182551 :         c = (int128_t)b*b%p;
     336       182551 :         t = (int128_t)t*c%p;
     337       182551 :         r = (int128_t)r*b%p;
     338              :     }
     339       122284 :     return r;
     340              : }
     341              : 
     342            6 : int64_t nut_i64_sqrt_cipolla(int64_t n, int64_t p){
     343           14 :     int64_t a, w;
     344           14 :     do{
     345           14 :         a = nut_u64_rand(2, p);
     346           14 :         w = nut_i64_mod((int128_t)a*a%p - n, p);
     347           14 :     }while(nut_i64_jacobi(w, p) != -1);
     348            6 :     int64_t u_s = a, w_s = 1;
     349            6 :     int64_t u_r = 1, w_r = 0;
     350          165 :     for(int64_t k = (p + 1) >> 1; k; k >>= 1){
     351          159 :         if(k&1){
     352           50 :             int64_t _w_r = (int128_t)u_r*w_s%p;
     353           50 :             _w_r = (_w_r + (int128_t)w_r*u_s)%p;
     354           50 :             u_r = (int128_t)u_r*u_s%p;
     355           50 :             w_r = (int128_t)w_r*w_s%p;
     356           50 :             w_r = (int128_t)w_r*w%p;
     357           50 :             u_r = ((int128_t)u_r + w_r)%p;
     358           50 :             w_r = _w_r;
     359              :         }
     360          159 :         int64_t _w_s = (int128_t)2*u_s*w_s%p;
     361          159 :         u_s = (int128_t)u_s*u_s%p;
     362          159 :         w_s = (int128_t)w_s*w_s%p;
     363          159 :         w_s = (int128_t)w*w_s%p;
     364          159 :         u_s = ((int128_t)u_s + w_s)%p;
     365          159 :         w_s = _w_s;
     366              :     }
     367            6 :     return u_r;
     368              : }
     369              : 
     370       490436 : int64_t nut_i64_sqrt_mod(int64_t n, int64_t p){
     371       490436 :     int64_t r;
     372       490436 :     if((p&3) == 3){
     373       245512 :         r = nut_u64_powmod(n, (p + 1) >> 2, p);
     374       244924 :     }else if((p&7) == 5){
     375       122634 :         r = nut_u64_powmod(n, (p + 3) >> 3, p);
     376       122634 :         if(r*r%p != n){
     377        61485 :             r = (int128_t)r*nut_u64_powmod(2, (p - 1) >> 2, p)%p;
     378              :         }//can add 9 mod 16 case
     379              :     }else{
     380       122290 :         int64_t m = 64 - __builtin_clzll(p);
     381       122290 :         int64_t s = __builtin_ctzll(p - 1);
     382       122290 :         if(8*m + 20 >= s*(s - 1)){
     383       122284 :             r = nut_i64_sqrt_shanks(n, p);
     384              :         }else{
     385            6 :             r = nut_i64_sqrt_cipolla(n, p);
     386              :         }
     387              :     }
     388       490436 :     return r;
     389              : }
     390              : 
     391              : 
     392              : /// See https://arxiv.org/pdf/1902.01961.pdf
     393            0 : uint64_t nut_i32_fastmod_init(uint32_t pd){
     394            0 :     return ~0ull/pd + 1 + (__builtin_popcount(pd) == 1);
     395              : }
     396              : 
     397            0 : uint64_t nut_u32_fastmod_init(uint32_t d){
     398            0 :     return ~0ull/d + 1;
     399              : }
     400              : 
     401            0 : uint128_t nut_i64_fastmod_init(uint64_t pd){
     402            0 :     return ~(uint128_t)0/pd + 1 + (__builtin_popcount(pd) == 1);
     403              : }
     404              : 
     405            0 : uint128_t nut_u64_fastmod_init(uint64_t d){
     406            0 :     return ~(uint128_t)0/d + 1;
     407              : }
     408              : 
     409              : 
     410            0 : int32_t nut_i32_fastmod_trunc(int32_t n, uint32_t pd, uint64_t c){
     411            0 :     uint64_t cn = c*n;
     412            0 :     int32_t cnd = ((uint128_t)cn*pd) >> 64;
     413            0 :     return cnd - ((pd - 1) & (n >> 31));
     414              : }
     415              : 
     416            0 : int32_t nut_i32_fastmod_floor(int32_t n, uint32_t pd, uint64_t c){
     417            0 :     uint64_t cn = c*n;
     418            0 :     int32_t cnd = ((uint128_t)cn*pd) >> 64;
     419            0 :     return cnd - ((pd - 1) && (n >> 31));
     420              : }
     421              : 
     422            0 : uint32_t nut_u32_fastmod(uint32_t n, uint32_t d, uint64_t c){
     423            0 :     uint64_t cn = c*n;
     424            0 :     return ((uint128_t)cn*d) >> 64;
     425              : }
     426              : 
     427            0 : int64_t nut_i64_fastmod_trunc(int64_t n, uint64_t pd, uint128_t c){
     428            0 :     uint128_t cn = c*n;
     429            0 :     uint64_t cn_hi = cn >> 64;
     430            0 :     int64_t cnd_hi = ((uint128_t)cn_hi*pd) >> 64;
     431              :     // c*n*pd >> 128 does not actually depend on the low bits of cn
     432            0 :     return cnd_hi - ((pd - 1) & (n >> 63));
     433              : }
     434              : 
     435            0 : int64_t nut_i64_fastmod_floor(int64_t n, uint64_t pd, uint128_t c){
     436            0 :     uint128_t cn = c*n;
     437            0 :     uint64_t cn_hi = cn >> 64;
     438            0 :     int64_t cnd_hi = ((uint128_t)cn_hi*pd) >> 64;
     439              :     // c*n*pd >> 128 does not actually depend on the low bits of cn
     440            0 :     return cnd_hi - ((pd - 1) && (n >> 63));
     441              : }
     442              : 
     443            0 : uint64_t nut_u64_fastmod(uint64_t n, uint64_t d, uint128_t c){
     444            0 :     uint128_t cn = c*n;
     445            0 :     uint64_t cn_hi = cn >> 64;
     446            0 :     return ((uint128_t)cn_hi*d) >> 64;
     447              : }
     448              : 
     449              : 
        

Generated by: LCOV version 2.0-1