About Docs Source
LCOV - code coverage report
Current view: top level - src/lib/nut - matrix.c (source / functions) Coverage Total Hit
Test: unnamed Lines: 74.3 % 171 127
Test Date: 2025-10-22 01:14:28 Functions: 77.3 % 22 17

            Line data    Source code
       1              : #include "nut/modular_math.h"
       2              : #include <nut/matrix.h>
       3              : #include <stdint.h>
       4              : #include <stdio.h>
       5              : #include <string.h>
       6              : 
       7            4 : bool nut_i64_Matrix_init(nut_i64_Matrix *self, int64_t rows, int64_t cols){
       8            4 :     size_t nbytes;
       9            4 :     if(rows < 0 || cols < 0 || __builtin_mul_overflow(rows, cols, &nbytes) || __builtin_mul_overflow(nbytes, sizeof(int64_t), &nbytes)){
      10              :         return false;
      11              :     }
      12            4 :     if(!(self->buf = malloc(nbytes))){
      13              :         return false;
      14              :     }
      15            4 :     self->rows = rows;
      16            4 :     self->cols = cols;
      17            4 :     return true;
      18              : }
      19              : 
      20            0 : void nut_i64_Matrix_copy(nut_i64_Matrix *restrict dest, const nut_i64_Matrix *restrict src){
      21            0 :     memcpy(dest->buf, src->buf, src->rows*src->cols*sizeof(int64_t));
      22            0 : }
      23              : 
      24            4 : void nut_i64_Matrix_destroy(nut_i64_Matrix *self){
      25            4 :     free(self->buf);
      26            4 :     self->buf = NULL;
      27            4 : }
      28              : 
      29            0 : void nut_i64_Matrix_fprint(const nut_i64_Matrix *restrict self, FILE *file){
      30            0 :     fprintf(file, "[");
      31            0 :     for(int64_t i = 0; i < self->rows; ++i){
      32            0 :         fprintf(file, "[");
      33            0 :         for(int64_t j = 0; j < self->cols; ++j){
      34            0 :             int64_t a = self->buf[self->cols*i + j];
      35            0 :             if(j + 1 == self->cols){
      36            0 :                 fprintf(file, "%"PRIi64, a);
      37              :             }else{
      38            0 :                 fprintf(file, "%"PRIi64", ", a);
      39              :             }
      40              :         }
      41            0 :         fprintf(file, "]");
      42            0 :         if(i + 1 < self->cols){
      43            0 :             fprintf(file, ",\n");
      44              :         }
      45              :     }
      46            0 :     fprintf(file, "]");
      47            0 : }
      48              : 
      49            1 : void nut_i64_Matrix_mul_vec(const nut_i64_Matrix *restrict self, const int64_t vec[restrict static self->cols], int64_t out[restrict static self->rows]){
      50            1 :     if(self->rows){
      51            1 :         memset(out, 0, self->rows*sizeof(int64_t));
      52              :     }
      53            9 :     for(int64_t row = 0; row < self->rows; ++row){
      54           72 :         for(int64_t col = 0; col < self->cols; ++col){
      55           64 :             out[row] += self->buf[self->cols*row + col]*vec[col];
      56              :         }
      57              :     }
      58            1 : }
      59              : 
      60            2 : bool nut_i64_Matrix_fill_I(nut_i64_Matrix *self){
      61            2 :     if(self->rows != self->cols){
      62              :         return false;
      63              :     }
      64            2 :     if(!self->rows){
      65              :         return true;
      66              :     }
      67            2 :     memset(self->buf, 0, self->rows*self->cols*sizeof(int64_t));
      68           15 :     for(int64_t i = 0; i < self->rows; ++i){
      69           13 :         self->buf[self->rows*i + i] = 1;
      70              :     }
      71              :     return true;
      72              : }
      73              : 
      74            2 : void nut_i64_Matrix_fill_short_pascal(nut_i64_Matrix *self){
      75            2 :     if(self->rows*self->cols == 0){
      76              :         return;
      77              :     }
      78           15 :     for(int64_t row = 0; row < self->rows; ++row){
      79           13 :         self->buf[self->cols*row] = 1;
      80           13 :         if(row){
      81           11 :             self->buf[self->cols*row + 1] = row + 1;
      82              :         }
      83           31 :         for(int64_t col = 2; col < row && col < self->cols; ++col){
      84           18 :             self->buf[self->cols*row + col] = self->buf[self->cols*(row - 1) + col - 1] + self->buf[self->cols*(row - 1) + col];
      85              :         }
      86           13 :         if(row < self->cols){
      87           13 :             self->buf[self->cols*row + row] = row + 1;
      88              :         }
      89           13 :         if(row + 1 < self->cols){
      90           11 :             memset(self->buf + self->cols*row + row + 1, 0, (self->cols - row - 1)*sizeof(int64_t));
      91              :         }
      92              :     }
      93              : }
      94              : 
      95           35 : bool nut_i64_Matrix_scale_row(nut_i64_Matrix *self, int64_t row, int64_t col_start, int64_t a){
      96          262 :     for(int64_t i = col_start; i < self->cols; ++i){
      97          227 :         self->buf[self->cols*row + i] *= a;
      98              :     }
      99           35 :     return true;
     100              : }
     101              : 
     102           38 : bool nut_i64_Matrix_addmul_row(nut_i64_Matrix *self, int64_t i, int64_t j, int64_t a){
     103          312 :     for(int64_t k = 0; k < self->cols; ++k){
     104          274 :         self->buf[self->cols*j + k] += a*self->buf[self->cols*i + k];
     105              :     }
     106           38 :     return true;
     107              : }
     108              : 
     109            2 : int64_t nut_i64_Matrix_invert_ltr(nut_i64_Matrix *restrict self, nut_i64_Matrix *restrict out){
     110            2 :     if(self->rows != self->cols || self->rows != out->rows || self->cols != out->rows){
     111            0 :         return 0;
     112              :     }
     113           15 :     for(int64_t i = 0; i < self->rows; ++i){
     114           13 :         if(!self->buf[self->cols*i + i]){
     115              :             return 0;
     116              :         }
     117              :     }
     118            2 :     int64_t denom = 1;
     119            2 :     nut_i64_Matrix_fill_I(out);
     120           15 :     for(int64_t i = 0; i < self->rows; ++i){
     121           13 :         int64_t a = self->buf[self->cols*i + i];
     122           51 :         for(int64_t j = i + 1; j < self->rows; ++j){
     123           38 :             int64_t b = self->buf[self->cols*j + i];
     124           38 :             int64_t g = nut_i64_egcd(a, b, NULL, NULL);
     125           38 :             if(a != g){
     126           11 :                 nut_i64_Matrix_scale_row(self, j, i, a/g);
     127           11 :                 nut_i64_Matrix_scale_row(out, j, 0, a/g);
     128              :             }
     129           38 :             if(b){
     130           38 :                 nut_i64_Matrix_addmul_row(out, i, j, -b/g);
     131              :             }
     132              :         }
     133              :     }
     134           15 :     for(int64_t i = 0; i < self->rows; ++i){
     135           13 :         denom = nut_i64_lcm(denom, self->buf[self->cols*i + i]);
     136              :     }
     137           15 :     for(int64_t i = 0; i < self->rows; ++i){
     138           13 :         nut_i64_Matrix_scale_row(out, i, 0, denom/self->buf[self->cols*i + i]);
     139              :     }
     140              :     return denom;
     141              : }
     142              : 
     143           32 : void nut_i64_Matrix_fill_vandemond_vec(uint64_t x, uint64_t k, uint64_t m, int64_t out[static k + 1]){
     144           32 :     if(m){
     145            0 :         x %= m;
     146              :     }
     147          352 :     for(uint64_t e = 1, xe = x; e <= k; ++e, xe = m ? xe*x%m : xe*x){
     148          160 :         out[e - 1] = xe;
     149              :     }
     150           32 : }
     151              : 
     152            6 : bool nut_u64_ModMatrix_init(nut_u64_ModMatrix *self, uint64_t rows, uint64_t cols, uint64_t modulus){
     153            6 :     size_t nbytes;
     154            6 :     if(__builtin_mul_overflow(rows, cols, &nbytes) || __builtin_mul_overflow(nbytes, sizeof(uint64_t), &nbytes)){
     155              :         return false;
     156              :     }
     157            6 :     if(!(self->buf = malloc(nbytes))){
     158              :         return false;
     159              :     }
     160            6 :     self->rows = rows;
     161            6 :     self->cols = cols;
     162            6 :     self->modulus = modulus;
     163            6 :     return true;
     164              : }
     165              : 
     166            0 : void nut_u64_ModMatrix_copy(nut_u64_ModMatrix *restrict dest, const nut_u64_ModMatrix *restrict src){
     167            0 :     memcpy(dest->buf, src->buf, src->rows*src->cols*sizeof(uint64_t));
     168            0 : }
     169              : 
     170            6 : void nut_u64_ModMatrix_destroy(nut_u64_ModMatrix *self){
     171            6 :     free(self->buf);
     172            6 :     self->buf = NULL;
     173            6 : }
     174              : 
     175            0 : void nut_u64_ModMatrix_fprint(const nut_u64_ModMatrix *restrict self, FILE *file){
     176            0 :     fprintf(file, "[");
     177            0 :     for(uint64_t i = 0; i < self->rows; ++i){
     178            0 :         fprintf(file, "[");
     179            0 :         for(uint64_t j = 0; j < self->cols; ++j){
     180            0 :             uint64_t a = self->buf[self->cols*i + j];
     181            0 :             if(j + 1 == self->cols){
     182            0 :                 fprintf(file, "%"PRIu64, a);
     183              :             }else{
     184            0 :                 fprintf(file, "%"PRIu64", ", a);
     185              :             }
     186              :         }
     187            0 :         fprintf(file, "]");
     188            0 :         if(i + 1 < self->cols){
     189            0 :             fprintf(file, ",\n");
     190              :         }
     191              :     }
     192            0 :     fprintf(file, "] mod %"PRIu64"", self->modulus);
     193            0 : }
     194              : 
     195            3 : bool nut_u64_ModMatrix_fill_I(nut_u64_ModMatrix *self){
     196            3 :     if(self->rows != self->cols){
     197              :         return false;
     198              :     }
     199            3 :     if(!self->rows){
     200              :         return true;
     201              :     }
     202            3 :     memset(self->buf, 0, self->rows*self->cols*sizeof(uint64_t));
     203           20 :     for(uint64_t i = 0; i < self->rows; ++i){
     204           17 :         self->buf[self->rows*i + i] = 1;
     205              :     }
     206              :     return true;
     207              : }
     208              : 
     209            0 : void nut_u64_ModMatrix_mul_vec(const nut_u64_ModMatrix *restrict self, const uint64_t vec[restrict static self->cols], uint64_t out[restrict static self->rows]){
     210            0 :     if(self->rows){
     211            0 :         memset(out, 0, self->rows*sizeof(uint64_t));
     212              :     }
     213            0 :     for(uint64_t row = 0; row < self->rows; ++row){
     214            0 :         for(uint64_t col = 0; col < self->cols; ++col){
     215            0 :             out[row] = (out[row] + self->buf[self->cols*row + col]*vec[col])%self->modulus;
     216              :         }
     217              :     }
     218            0 : }
     219              : 
     220            3 : void nut_u64_ModMatrix_fill_short_pascal(nut_u64_ModMatrix *self){
     221            3 :     if(self->rows*self->cols == 0){
     222              :         return;
     223              :     }
     224           20 :     for(uint64_t row = 0; row < self->rows; ++row){
     225           17 :         self->buf[self->cols*row] = 1;
     226           17 :         if(row){
     227           14 :             self->buf[self->cols*row + 1] = (row + 1)%self->modulus;
     228              :         }
     229           36 :         for(uint64_t col = 2; col < row && col < self->cols; ++col){
     230           19 :             uint64_t v = self->buf[self->cols*(row - 1) + col - 1] + self->buf[self->cols*(row - 1) + col];
     231           19 :             self->buf[self->cols*row + col] = v >= self->modulus ? v - self->modulus : v;
     232              :         }
     233           17 :         if(row < self->cols){
     234           17 :             self->buf[self->cols*row + row] = (row + 1)%self->modulus;
     235              :         }
     236           17 :         if(row + 1 < self->cols){
     237           14 :             memset(self->buf + self->cols*row + row + 1, 0, (self->cols - row - 1)*sizeof(uint64_t));
     238              :         }
     239              :     }
     240              : }
     241              : 
     242           17 : bool nut_u64_ModMatrix_scale_row(nut_u64_ModMatrix *self, uint64_t row, uint64_t col_start, uint64_t a){
     243          122 :     for(uint64_t i = col_start; i < self->cols; ++i){
     244          105 :         self->buf[self->cols*row + i] *= a;
     245          105 :         self->buf[self->cols*row + i] %= self->modulus;
     246              :     }
     247           17 :     return true;
     248              : }
     249              : 
     250           44 : bool nut_u64_ModMatrix_addmul_row(nut_u64_ModMatrix *self, uint64_t i, uint64_t j, uint64_t a){
     251          342 :     for(uint64_t k = 0; k < self->cols; ++k){
     252          298 :         self->buf[self->cols*j + k] += a*self->buf[self->cols*i + k];
     253          298 :         self->buf[self->cols*j + k] %= self->modulus;
     254              :     }
     255           44 :     return true;
     256              : }
     257              : 
     258            3 : bool nut_u64_ModMatrix_invert_ltr(nut_u64_ModMatrix *restrict self, nut_u64_ModMatrix *restrict out){
     259            3 :     if(self->rows != self->cols || self->rows != out->rows || self->cols != out->rows){
     260            0 :         return 0;
     261              :     }
     262           20 :     for(uint64_t i = 0; i < self->rows; ++i){
     263           17 :         if(!self->buf[self->cols*i + i]){
     264              :             return false;
     265              :         }
     266              :     }
     267            3 :     nut_u64_ModMatrix_fill_I(out);
     268           20 :     for(uint64_t i = 0; i < self->rows; ++i){
     269           17 :         uint64_t a = self->buf[self->cols*i + i];
     270           17 :         int64_t _ainv = nut_i64_modinv(a, self->modulus);
     271           17 :         nut_u64_ModMatrix_scale_row(out, i, 0, (uint64_t)_ainv);
     272           61 :         for(uint64_t j = i + 1; j < self->rows; ++j){
     273           44 :             uint64_t b = self->buf[self->cols*j + i];
     274           44 :             if(b){
     275           44 :                 nut_u64_ModMatrix_addmul_row(out, i, j, self->modulus - b);
     276              :             }
     277              :         }
     278              :     }
     279              :     return true;
     280              : }
     281              : 
     282           30 : void nut_u64_Matrix_fill_vandemond_vec(uint64_t x, uint64_t k, uint64_t m, uint64_t out[static k + 1]){
     283          200 :     for(uint64_t e = 1, xe = x; e <= k; ++e, xe = xe*x%m){
     284          170 :         out[e - 1] = xe;
     285              :     }
     286           30 : }
     287              : 
        

Generated by: LCOV version 2.0-1