About Docs Source
LCOV - code coverage report
Current view: top level - src/lib/nut - dirichlet_powerful.c (source / functions) Coverage Total Hit
Test: unnamed Lines: 76.0 % 154 117
Test Date: 2025-10-22 01:14:28 Functions: 90.9 % 11 10

            Line data    Source code
       1              : #include <stddef.h>
       2              : #include <stdint.h>
       3              : #include <stdlib.h>
       4              : #include <string.h>
       5              : #include <stdio.h>
       6              : 
       7              : #include <nut/debug.h>
       8              : #include <nut/modular_math.h>
       9              : #include <nut/factorization.h>
      10              : #include <nut/dirichlet.h>
      11              : #include <nut/dirichlet_powerful.h>
      12              : #include <nut/sieves.h>
      13              : 
      14           21 : static bool PfIt_init(nut_PfIt *self, uint64_t max, uint64_t modulus, uint64_t small_primes){
      15           21 :     if(!max){
      16            0 :         *self = (nut_PfIt){};
      17            0 :         return true;
      18              :     }
      19           21 :     self->small_primes = small_primes;
      20           21 :     self->cap = 63 - __builtin_clzll(max);// floor(log_2(max))
      21           21 :     self->rt_max = nut_u64_nth_root(max, 2);
      22           21 :     if(!(self->entries = malloc(self->cap*sizeof(*self->entries)))){
      23              :         return false;
      24           21 :     }else if(!(self->primes = nut_sieve_primes(self->rt_max, &self->num_primes))){
      25            0 :         free(self->entries);
      26            0 :         self->entries = NULL;
      27            0 :         return false;
      28              :     }
      29           21 :     self->entries[0] = (nut_PfStackEnt){.n=1, .hn=1, .i=0};
      30           21 :     self->len = 1;
      31           21 :     self->max = max;
      32           21 :     self->modulus = modulus;
      33           21 :     return true;
      34              : }
      35              : 
      36            2 : bool nut_PfIt_init_fn(nut_PfIt *self, uint64_t max, uint64_t modulus, uint64_t small_primes, int64_t (*h_fn)(uint64_t p, uint64_t pp, uint64_t e, uint64_t m)){
      37            2 :     if(!PfIt_init(self, max, modulus, small_primes)){
      38              :         return false;
      39              :     }
      40            2 :     self->h_kind = NUT_DIRI_H_FN;
      41            2 :     self->h_fn = h_fn;
      42            2 :     return true;
      43              : }
      44              : 
      45           18 : bool nut_PfIt_init_hvals(nut_PfIt *restrict self, uint64_t max, uint64_t modulus, uint64_t small_primes, const int64_t *restrict h_vals){
      46           18 :     if(!PfIt_init(self, max, modulus, small_primes)){
      47              :         return false;
      48              :     }
      49           18 :     self->h_kind = NUT_DIRI_H_VALS;
      50           18 :     self->h_vals = h_vals;
      51           18 :     return true;
      52              : }
      53              : 
      54            1 : bool nut_PfIt_init_hseqs(nut_PfIt *restrict self, uint64_t max, uint64_t modulus, uint64_t small_primes,
      55              :     int64_t (*f_fn)(uint64_t p, uint64_t pp, uint64_t e, uint64_t m), int64_t (*g_fn)(uint64_t p, uint64_t pp, uint64_t e, uint64_t m)
      56              : ){
      57            1 :     if(!PfIt_init(self, max, modulus, small_primes)){
      58              :         return false;
      59              :     }
      60            1 :     uint64_t values_cap = 3*self->num_primes;
      61            1 :     self->h_kind = NUT_DIRI_H_SEQS;
      62            1 :     self->h_seqs.offsets = malloc(self->num_primes*sizeof(uint64_t));
      63            1 :     self->h_seqs.values = malloc(values_cap*sizeof(int64_t));
      64            1 :     if(!self->h_seqs.offsets || !self->h_seqs.values){
      65            0 :         nut_PfIt_destroy(self);
      66            0 :         return false;
      67              :     }
      68            1 :     uint64_t base_offset = 0;
      69            1 :     uint64_t curr_max_pow = 63 - __builtin_clzll(max);
      70            2 :     int64_t *f_series [[gnu::cleanup(cleanup_free)]] = malloc((curr_max_pow + 1)*sizeof(int64_t));
      71            2 :     int64_t *g_series [[gnu::cleanup(cleanup_free)]] = malloc((curr_max_pow + 1)*sizeof(int64_t));
      72            2 :     int64_t *h_series [[gnu::cleanup(cleanup_free)]] = malloc((curr_max_pow + 1)*sizeof(int64_t));
      73            1 :     if(!f_series || !g_series || !h_series){
      74            0 :         nut_PfIt_destroy(self);
      75            0 :         return false;
      76              :     }
      77            1 :     f_series[0] = g_series[0] = 1;
      78            1 :     uint64_t curr_max_prime = 2;
      79            3 :     for(uint64_t i = 0; i < self->num_primes; ++i){
      80            2 :         uint64_t p = self->primes[i];
      81              :         // ensure that p^curr_max_pow is still <= max
      82            3 :         while(curr_max_pow > 2 && p > curr_max_prime){
      83              :             // this must be a while loop because for very small primes like 2 it's possible for the max power to drop more than 1
      84              :             // also when the max power is 2 we don't need to check anymore since we've only gathered the primes up to sqrt(max)
      85            1 :             --curr_max_pow;
      86            1 :             curr_max_prime = nut_u64_nth_root(max, curr_max_pow);
      87              :         }
      88              :         // ensure that the values table is large enough
      89            2 :         uint64_t min_pow = i >= self->small_primes ? 2 : 1;
      90            2 :         if(base_offset + curr_max_pow - min_pow >= values_cap){
      91            0 :             uint64_t new_cap = base_offset + 2*(self->num_primes - i);
      92            0 :             int64_t *tmp = realloc(self->h_seqs.values, new_cap*sizeof(int64_t));
      93            0 :             if(!tmp){
      94            0 :                 nut_PfIt_destroy(self);
      95            0 :                 return false;
      96              :             }
      97            0 :             self->h_seqs.values = tmp;
      98            0 :             values_cap = new_cap;
      99              :         }
     100            7 :         for(uint64_t pp = 1, e = 1; e <= curr_max_pow; ++e){
     101            5 :             pp *= p;
     102            5 :             f_series[e] = f_fn(p, pp, e, modulus);
     103            5 :             g_series[e] = g_fn(p, pp, e, modulus);
     104              :         }
     105            2 :         nut_series_div(curr_max_pow + 1, modulus, h_series, f_series, g_series);
     106            2 :         self->h_seqs.offsets[i] = base_offset;
     107            2 :         memcpy(self->h_seqs.values + base_offset, h_series + min_pow, (curr_max_pow + 1 - min_pow)*sizeof(int64_t));
     108            2 :         base_offset += curr_max_pow + 1 - min_pow;
     109              :     }
     110              :     return true;
     111              : }
     112              : 
     113           21 : void nut_PfIt_destroy(nut_PfIt *self){
     114           21 :     free(self->entries);
     115           21 :     free(self->primes);
     116           21 :     if(self->h_kind == NUT_DIRI_H_SEQS){
     117            1 :         free(self->h_seqs.offsets);
     118            1 :         free(self->h_seqs.values);
     119              :     }
     120           21 :     *self = (nut_PfIt){};
     121           21 : }
     122              : 
     123          656 : bool nut_PfStack_push(nut_PfIt *restrict self, const nut_PfStackEnt *restrict ent){
     124          656 :     if(self->len == self->cap){
     125            0 :         uint64_t new_cap = (self->cap << 1) ?: 8;// it shouldn't be possible to get here with cap = 0, but work around it anyway
     126            0 :         void *tmp = realloc(self->entries, new_cap*sizeof(*self->entries));
     127            0 :         if(!tmp){
     128              :             return false;
     129              :         }
     130            0 :         self->cap = new_cap;
     131            0 :         self->entries = tmp;
     132              :     }
     133          656 :     memcpy(self->entries + self->len++, ent, sizeof(*self->entries));
     134          656 :     return true;
     135              : }
     136              : 
     137          677 : bool nut_PfStack_pop(nut_PfIt *restrict self, nut_PfStackEnt *restrict out){
     138          677 :     if(!self->len){
     139              :         return false;
     140              :     }
     141          677 :     memcpy(out, self->entries + --self->len, sizeof(*self->entries));
     142          677 :     return true;
     143              : }
     144              : 
     145          466 : bool nut_PfIt_next(nut_PfIt *restrict self, nut_PfStackEnt *restrict out){
     146              :     // Yields (n, h(n) mod m) where n are the O(sqrt x) powerful numbers
     147              :     // up to x, and h is any multiplicative function.
     148          698 :     while(self->len){
     149          677 :         nut_PfStack_pop(self, out);
     150          677 :         if(out->i >= self->num_primes){
     151              :             return true;
     152              :         }
     153          635 :         uint64_t p = self->primes[out->i];
     154          635 :         uint64_t min_pow = out->i >= self->small_primes ? 2 : 1;
     155          635 :         if(min_pow == 2){
     156          635 :             if(p*p > self->max/out->n){
     157              :                 return true;
     158              :             }
     159              :         }else{
     160            0 :             if(p > self->max/out->n){
     161              :                 return true;
     162              :             }
     163              :         }
     164          232 :         if(!nut_PfStack_push(self, &(nut_PfStackEnt){.n=out->n, .hn=out->hn, .i= out->i + 1})){
     165            0 :             return false;
     166              :         }
     167          888 :         for(uint64_t pp = min_pow == 2 ? p : 1, e = min_pow; !__builtin_mul_overflow(pp, p, &pp) && pp <= self->max/out->n; ++e){
     168          424 :             int64_t v = out->hn;
     169          424 :             if(self->h_kind == NUT_DIRI_H_VALS){
     170          396 :                 v *= self->h_vals[e];
     171           28 :             }else if(self->h_kind == NUT_DIRI_H_FN){
     172           25 :                 v *= self->h_fn(p, pp, e, self->modulus);
     173            3 :             }else if(self->h_kind == NUT_DIRI_H_SEQS){
     174            3 :                 v *= self->h_seqs.values[self->h_seqs.offsets[out->i] + e - min_pow];
     175              :             }else{
     176            0 :                 return false;
     177              :             }
     178          424 :             if(self->modulus){
     179          396 :                 v = nut_i64_mod(v, self->modulus);
     180              :             }
     181          424 :             if(!nut_PfStack_push(self, &(nut_PfStackEnt){.n = out->n*pp, .hn = v, .i = out->i + 1})){
     182            0 :                 return false;
     183              :             }
     184              :         }
     185              :     }
     186              :     return false;
     187              : }
     188              : 
     189           19 : bool nut_Diri_sum_adjusted(int64_t *restrict out, const nut_Diri *restrict g_tbl, nut_PfIt *pf_it){
     190           38 :     int64_t *g_dense [[gnu::cleanup(cleanup_free)]] = malloc((g_tbl->y + 1)*sizeof(int64_t));
     191           19 :     if(!g_dense || pf_it->modulus > INT64_MAX){
     192            0 :         return false;
     193              :     }
     194           19 :     g_dense[0] = 0;
     195           19 :     int64_t m = pf_it->modulus;
     196          285 :     for(int64_t i = 1; i <= g_tbl->y; ++i){
     197          266 :         int64_t term = g_dense[i - 1] + nut_Diri_get_dense(g_tbl, i);
     198          266 :         g_dense[i] = m ? nut_i64_mod(term, m) : term;
     199              :     }
     200           19 :     int64_t res = 0;
     201           19 :     nut_PfStackEnt ent;
     202          456 :     while(nut_PfIt_next(pf_it, &ent)){
     203          437 :         int64_t Gn;
     204          437 :         if(ent.n >= (uint64_t)g_tbl->yinv){
     205          361 :             Gn = g_dense[g_tbl->x/ent.n];
     206              :         }else{
     207           76 :             Gn = nut_Diri_get_sparse(g_tbl, ent.n);
     208              :         }
     209          437 :         res += ent.hn*Gn;
     210          437 :         if(m){
     211          414 :             res = nut_i64_mod(res, m);
     212              :         }
     213              :     }
     214           19 :     *out = res;
     215           19 :     return true;
     216              : }
     217              : 
     218            0 : bool nut_Diri_sum_u_adjusted(int64_t *restrict out, nut_PfIt *pf_it){
     219            0 :     int64_t m = pf_it->modulus;
     220            0 :     int64_t res = 0;
     221            0 :     uint64_t max = pf_it->max;
     222            0 :     nut_PfStackEnt ent;
     223            0 :     while(nut_PfIt_next(pf_it, &ent)){
     224            0 :         res += ent.hn*(max/ent.n);
     225            0 :         if(m){
     226            0 :             res = nut_i64_mod(res, m);
     227              :         }
     228              :     }
     229            0 :     *out = res;
     230            0 :     return true;
     231              : }
     232              : 
     233           20 : void nut_series_div(uint64_t n, int64_t m, int64_t h[restrict static n], int64_t f[restrict static n], int64_t g[restrict static n]){
     234          189 :     for(uint64_t e = 0; e < n; ++e){
     235          169 :         int64_t term = f[e];
     236          826 :         for(uint64_t k = 1; k <= e; ++k){
     237          657 :             term -= g[k]*h[e - k];
     238          657 :             if(m){
     239          648 :                 term %= m;
     240              :             }
     241              :         }
     242          169 :         h[e] = (m && term < 0) ? m + term : term;
     243              :     }
     244           20 : }
     245              : 
        

Generated by: LCOV version 2.0-1