About Docs Source
LCOV - code coverage report
Current view: top level - src/lib/crater - kd_tree.c (source / functions) Hit Total Coverage
Test: unnamed Lines: 118 213 55.4 %
Date: 2024-02-13 04:57:17 Functions: 13 20 65.0 %

          Line data    Source code
       1             : #include <stdlib.h>
       2             : #include <string.h>
       3             : #include <inttypes.h>
       4             : #include <limits.h>
       5             : #include <math.h>
       6             : 
       7             : #include <crater/kd_tree.h>
       8             : #include <crater/kd_check.h>
       9             : #include <crater/minmax_heap.h>
      10             : 
      11      845103 : static int cmp_depth_i64cu(const cr8r_base_ft *_ft, const void *_a, const void *_b){
      12      845103 :     const cr8r_kd_ft *ft = (cr8r_kd_ft*)_ft;
      13      845103 :     const int64_t *a = _a;
      14      845103 :     const int64_t *b = _b;
      15      845103 :     uint64_t depth = (uint64_t)ft->super.base.data;
      16      845103 :     uint64_t idx = depth%ft->dim;
      17      845103 :     int64_t key_a = a[idx];
      18      845103 :     int64_t key_b = b[idx];
      19      845103 :     if(key_a < key_b){
      20             :         return -1;
      21      342695 :     }else if(key_a > key_b){
      22      328480 :         return 1;
      23             :     }
      24             :     return 0;
      25             : }
      26             : 
      27           0 : static int cmp_depth_i64sp(const cr8r_base_ft *_ft, const void *_a, const void *_b){
      28           0 :     const cr8r_kd_ft *ft = (cr8r_kd_ft*)_ft;
      29           0 :     const int64_t *a = _a;
      30           0 :     const int64_t *b = _b;
      31           0 :     uint64_t depth = (uint64_t)ft->super.base.data;
      32           0 :     uint64_t idx = depth%ft->dim;
      33           0 :     if(idx == ft->dim - 1){
      34           0 :         int64_t a_x = a[0], a_y = a[1];
      35           0 :         int64_t b_x = b[0], b_y = b[1];
      36             :         // if a and b are in opposite half planes
      37           0 :         if((a_y >= 0) != (b_y >= 0)){
      38           0 :             if(a_y < 0){
      39             :                 #ifdef CR8R_KDSP_ORIGIN_IS_INFIMUM
      40             :                     return 1;
      41             :                 #else // otherwise (0, 0) == any point
      42           0 :                     return b_x || b_y;// return 0 (==) if b is (0, 0), else 1 (>)
      43             :                 #endif
      44             :             }
      45             :             #ifdef CR8R_KDSP_ORIGIN_IS_INFIMUM
      46             :                 return -1;
      47             :             #else // otherwise (0, 0) == any point
      48           0 :                 return -(a_x || a_y);// return 0 (==) if a is (0, 0), else -1 (<)
      49             :             #endif
      50             :         }
      51           0 :         int64_t crossp = a_x*b_y - b_x*a_y;
      52           0 :         if(crossp > 0){
      53             :             return -1;
      54           0 :         }else if(crossp < 0){
      55             :             return 1;
      56             :         }
      57             :         #ifdef CR8R_KDSP_ORIGIN_IS_INFIMUM
      58             :             if(!a_x && !a_y){// if a is (0, 0)
      59             :                 return -(b_x || b_y);// return 0 (==) if b is (0, 0), else -1 (<)
      60             :             }
      61             :             return !(b_x || b_y);// return 1 (>) if b is (0, 0), else 0 (==)
      62             :         #else
      63           0 :             return 0;
      64             :         #endif
      65             :     }
      66           0 :     int64_t key_a = a[2 + idx];
      67           0 :     int64_t key_b = b[2 + idx];
      68           0 :     if(key_a < key_b){
      69             :         return -1;
      70           0 :     }else if(key_a > key_b){
      71           0 :         return 1;
      72             :     }
      73             :     return 0;
      74             : }
      75             : 
      76           0 : static void split_i64sp(const cr8r_kd_ft *ft, const void *_self, const void *_root_pt, void *_o1, void *_o2){
      77           0 :     const cr8r_kdwin_s2i64 *self = _self;
      78           0 :     const int64_t *root = _root_pt;
      79           0 :     cr8r_kdwin_s2i64 *o1 = _o1;
      80           0 :     cr8r_kdwin_s2i64 *o2 = _o2;
      81           0 :     uint64_t idx = (uint64_t)ft->super.base.data%ft->dim;
      82           0 :     memcpy(o1, self, sizeof(cr8r_kdwin_s2i64));
      83           0 :     memcpy(o2, self, sizeof(cr8r_kdwin_s2i64));
      84           0 :     if(idx == ft->dim - 1){
      85           0 :         if(!root[0] || !root[1]){
      86           0 :             memcpy(o1->tr, root, 2*sizeof(int64_t));
      87           0 :             memcpy(o2->bl, root, 2*sizeof(int64_t));
      88             :         }
      89             :     }else{
      90           0 :         o1->tr[2 + idx] = root[2 + idx];
      91           0 :         o2->bl[2 + idx] = root[2 + idx];
      92             :     }
      93           0 : }
      94             : 
      95           0 : static void update_i64sp(const cr8r_kd_ft *_ft, void *_self, const void *_pt){
      96           0 :     cr8r_kdwin_s2i64 *self = _self;
      97           0 :     const int64_t *pt = _pt;
      98           0 :     cr8r_kd_ft ft = *_ft;
      99           0 :     for(uint64_t i = 0; i < ft.dim - 1; ++i){
     100           0 :         ft.super.base.data = (void*)i;
     101           0 :         if(ft.super.cmp(&ft.super.base, pt, self->bl) < 0){
     102           0 :             self->bl[2 + i] = pt[2 + i];
     103           0 :         }else if(ft.super.cmp(&ft.super.base, pt, self->tr) > 0){
     104           0 :             self->tr[2 + i] = pt[2 + i];
     105             :         }
     106             :     }
     107           0 :     ft.super.base.data = (void*)(ft.dim - 1);
     108           0 :     if(ft.super.cmp(&ft.super.base, pt, self->bl) < 0){
     109           0 :         memcpy(self->bl, pt, 2*sizeof(int64_t));
     110           0 :     }else if(ft.super.cmp(&ft.super.base, pt, self->tr) > 0){
     111           0 :         memcpy(self->tr, pt, 2*sizeof(int64_t));
     112             :     }
     113           0 : }
     114             : 
     115           0 : static double min_sqdist_i64sp(const cr8r_kd_ft *_ft, const void *_self, const void *_pt){
     116             :     /* The min sqdist on a sphere is the same in the last n-2 dimensions,
     117             :     but the first 2 dimensions which are treated as an angle must have their
     118             :     sqdist computed differently
     119             :     The distance between x1, x2 and l1, l2 can be found by the law of cosines
     120             :     sqdist = 2*x1**2 + 2*x2**2 - 2*(x1**2 + x2**2)*cos
     121             :     Here, x1**2 + x2**2 is the square of the radius of the circle in the hypersphere
     122             :     where x1, x2 live.  The triangle in question in the law of cosines
     123             :     is in the circle with x1, x2 and is an isosceles triangle through the center
     124             :     and two points on the circumference
     125             :     Thus the two known sides are the radius of that circle and the cos can be
     126             :     found through the dot product
     127             :     sqdist = 2*(x1**2 + x2**2 - (x1*l1 + x2*l2)/(l1**2 + l2**2))
     128             :     Remember, this is just the component of the sqdist in the first two dimensions
     129             :     Also notice that this is not necessarily an integer, so we will use floor division
     130             :     and accept that we explore some windows that are actually too far to contribute to the
     131             :     k closest points
     132             :     
     133             :     Now we just have to find whether x1, x2 is closer to l1, l2, r1, r2, or is between them
     134             :     */
     135           0 :     const cr8r_kdwin_s2i64 *self = _self;
     136           0 :     const int64_t *pt = _pt;
     137           0 :     cr8r_kd_ft ft = *_ft;
     138           0 :     double sqdist = -1;
     139           0 :     ft.super.base.data = (void*)(ft.dim - 1);
     140           0 :     if(ft.super.cmp(&ft.super.base, self->bl, self->tr) > 0){
     141           0 :         if(ft.super.cmp(&ft.super.base, self->bl, pt) <= 0){
     142           0 :             sqdist = 0;
     143             :         }
     144           0 :         if(ft.super.cmp(&ft.super.base, pt, self->tr) <= 0){
     145             :             sqdist = 0;
     146             :         }
     147           0 :     }else if(!self->bl[1] && !self->tr[1]){
     148             :         sqdist = 0;
     149             :     }else{
     150           0 :         if(ft.super.cmp(&ft.super.base, self->bl, pt) <= 0 && ft.super.cmp(&ft.super.base, pt, self->tr) <= 0){
     151             :             sqdist = 0;
     152             :         }
     153             :     }
     154           0 :     if(sqdist == -1){
     155           0 :         int64_t lcrs = self->bl[0]*self->bl[0] + self->bl[1]*self->bl[1];
     156           0 :         int64_t rcrs = self->tr[0]*self->tr[0] + self->tr[1]*self->tr[1];
     157           0 :         int64_t acrs = pt[0]*pt[0] + pt[1]*pt[1];
     158           0 :         int64_t lda = pt[0]*self->bl[0] + pt[1]*self->bl[1];
     159           0 :         int64_t rda = pt[0]*self->tr[0] + pt[1]*self->tr[1];
     160           0 :         if(lcrs*rda < rcrs*lda){
     161           0 :             sqdist = 2*(acrs - (double)rda/rcrs);
     162             :         }else{
     163           0 :             sqdist = 2*(acrs - (double)lda/lcrs);
     164             :         }
     165             :     }
     166           0 :     for(uint64_t i = 2; i <= ft.dim; ++i){
     167           0 :         int64_t axdist = self->bl[i] - pt[i];
     168           0 :         if(pt[i] - self->tr[i] > axdist){
     169           0 :             axdist = pt[i] - self->tr[i];
     170             :         }
     171           0 :         sqdist += axdist > 0 ? axdist*axdist : 0;
     172             :     }
     173           0 :     return sqdist;
     174             : }
     175             : 
     176           0 : static double sqdist_i64sp(const cr8r_kd_ft *_ft, const void *_a, const void *_b){
     177           0 :     double res = 0;
     178           0 :     const int64_t *a = _a, *b = _b;
     179           0 :     for(uint64_t i = 0; i < _ft->dim + 1; ++i){
     180           0 :         res += (a[i] - b[i])*(a[i] - b[i]);
     181             :     }
     182           0 :     return res;
     183             : }
     184             : 
     185             : 
     186             : 
     187      250000 : static void split_i64cu(const cr8r_kd_ft *ft, const void *_self, const void *_root_pt, void *_o1, void *_o2){
     188      250000 :     const cr8r_kdwin_s2i64 *self = _self;
     189      250000 :     const int64_t *root = _root_pt;
     190      250000 :     cr8r_kdwin_s2i64 *o1 = _o1;
     191      250000 :     cr8r_kdwin_s2i64 *o2 = _o2;
     192      250000 :     uint64_t idx = (uint64_t)ft->super.base.data%ft->dim;
     193      250000 :     memcpy(o1, self, sizeof(cr8r_kdwin_s2i64));
     194      250000 :     memcpy(o2, self, sizeof(cr8r_kdwin_s2i64));
     195      250000 :     o1->tr[idx] = root[idx];
     196      250000 :     o2->bl[idx] = root[idx];
     197      250000 : }
     198             : 
     199        4995 : static void update_i64cu(const cr8r_kd_ft *_ft, void *_self, const void *_pt){
     200        4995 :     cr8r_kdwin_s2i64 *self = _self;
     201        4995 :     const int64_t *pt = _pt;
     202        4995 :     cr8r_kd_ft ft = *_ft;
     203       19980 :     for(uint64_t i = 0; i < ft.dim; ++i){
     204       14985 :         ft.super.base.data = (void*)i;
     205       14985 :         if(ft.super.cmp(&ft.super.base, pt, self->bl) < 0){
     206          99 :             self->bl[i] = pt[i];
     207       14886 :         }else if(ft.super.cmp(&ft.super.base, pt, self->tr) > 0){
     208          80 :             self->tr[i] = pt[i];
     209             :         }
     210             :     }
     211        4995 : }
     212             : 
     213      475000 : static double min_sqdist_i64cu(const cr8r_kd_ft *_ft, const void *_self, const void *_pt){
     214      475000 :     const cr8r_kdwin_s2i64 *self = _self;
     215      475000 :     const int64_t *pt = _pt;
     216      475000 :     double sqdist = 0;
     217     1900000 :     for(uint64_t i = 0; i < _ft->dim; ++i){
     218     1425000 :         int64_t axdist = self->bl[i] - pt[i];
     219     1425000 :         if(pt[i] - self->tr[i] > axdist){
     220      726750 :             axdist = pt[i] - self->tr[i];
     221             :         }
     222     1425000 :         sqdist += axdist > 0 ? axdist*axdist : 0;
     223             :     }
     224      475000 :     return sqdist;
     225             : }
     226             : 
     227    17133372 : static double sqdist_i64cu(const cr8r_kd_ft *_ft, const void *_a, const void *_b){
     228    17133372 :     double res = 0;
     229    17133372 :     const int64_t *a = _a, *b = _b;
     230    68533488 :     for(uint64_t i = 0; i < _ft->dim; ++i){
     231    51400116 :         res += (a[i] - b[i])*(a[i] - b[i]);
     232             :     }
     233    17133372 :     return res;
     234             : }
     235             : 
     236             : 
     237             : 
     238           0 : bool cr8r_kdwin_init_i64sp(cr8r_kdwin_s2i64 *self, const int64_t bl[3], const int64_t tr[3]){
     239           0 :     memcpy(self->bl, bl, 3*sizeof(int64_t));
     240           0 :     memcpy(self->tr, tr, 3*sizeof(int64_t));
     241           0 :     return 1;
     242             : }
     243             : 
     244           0 : bool cr8r_kdwin_init_i64cu(cr8r_kdwin_s2i64 *self, const int64_t bl[3], const int64_t tr[3]){
     245           0 :     memcpy(self->bl, bl, 3*sizeof(int64_t));
     246           0 :     memcpy(self->tr, tr, 3*sizeof(int64_t));
     247           0 :     return 1;
     248             : }
     249             : 
     250           5 : bool cr8r_kdwin_bounding_i64x3(cr8r_kdwin_s2i64 *self, const cr8r_vec *ents, const cr8r_kd_ft *ft){
     251           5 :     if(!ents->len){
     252             :         return 0;
     253             :     }
     254           5 :     memcpy(self->bl, ents->buf, ft->super.base.size);
     255           5 :     memcpy(self->tr, ents->buf, ft->super.base.size);
     256        5000 :     for(uint64_t i = 1; i < ents->len; ++i){
     257        4995 :         const void *ent = ents->buf + i*ft->super.base.size;
     258        4995 :         ft->update(ft, self, ent);
     259             :     }
     260             :     return 1;
     261             : }
     262             : 
     263        5005 : bool cr8r_kd_ify(cr8r_vec *self, cr8r_kd_ft *_ft, uint64_t a, uint64_t b){
     264        5005 :     cr8r_kd_ft ft = *_ft;
     265       10005 :     while(b > a){
     266        5000 :         uint64_t mid_idx = (a + b)/2;
     267        5000 :         void *piv = cr8r_vec_ith(self, &ft.super, a, b, mid_idx - a);
     268        5000 :         if(!piv){
     269             :             return 0;
     270             :         }
     271        5000 :         piv = cr8r_vec_partition_with_median(self, &ft.super, a, b, piv);
     272             : #ifdef DEBUG
     273        5000 :         if(!piv || !cr8r_kd_check_layer(self, &ft, a, b)){
     274           0 :             __builtin_trap();
     275             :         }
     276             : #endif
     277             :         // increment depth
     278        5000 :         ++*(uint64_t*)&ft.super.base.data;
     279        5000 :         if(!cr8r_kd_ify(self, &ft, mid_idx + 1, b)){
     280             :             return 0;
     281             :         }
     282             :         b = mid_idx;
     283             :     }
     284             :     return 1;
     285             : }
     286             : 
     287      250250 : cr8r_walk_decision cr8r_kd_walk_r(cr8r_vec *self, const cr8r_kd_ft *_ft, void *bounds, cr8r_kdvisitor visitor, void *data, uint64_t a, uint64_t b){
     288      250250 :     cr8r_kd_ft ft = *_ft;
     289      250250 :     char sub0[ft.bounds_size];
     290      250250 :     char sub1[ft.bounds_size];
     291      500250 :     while(b > a){
     292      250000 :         uint64_t mid_idx = (a + b)/2;
     293      250000 :         void *ent =  self->buf + mid_idx*ft.super.base.size;
     294      250000 :         cr8r_walk_decision decision = visitor(&ft, bounds, ent, data);
     295      250000 :         if(decision == CR8R_WALK_STOP){
     296             :             return decision;
     297      250000 :         }else if(decision == CR8R_WALK_SKIP_CHILDREN){
     298             :             return CR8R_WALK_CONTINUE;
     299             :         }
     300      250000 :         ft.split(&ft, bounds, ent, sub0, sub1);
     301             :         // increment depth
     302      250000 :         ++*(uint64_t*)&ft.super.base.data;
     303      250000 :         decision = cr8r_kd_walk_r(self, &ft, sub0, visitor, data, a, mid_idx);
     304      250000 :         if(decision == CR8R_WALK_STOP){
     305             :             return decision;
     306             :         }
     307      250000 :         a = mid_idx + 1;
     308      250000 :         memcpy(bounds, sub1, ft.bounds_size);
     309             :     }
     310             :     return CR8R_WALK_CONTINUE;
     311             : }
     312             : 
     313         250 : void cr8r_kd_walk(cr8r_vec *self, const cr8r_kd_ft *ft, const void *_bounds, cr8r_kdvisitor visitor, void *data){
     314         250 :     char bounds[ft->bounds_size];
     315         250 :     memcpy(bounds, _bounds, ft->bounds_size);
     316         250 :     cr8r_kd_walk_r(self, ft, bounds, visitor, data, 0, self->len);
     317         250 : }
     318             : 
     319      500000 : inline static cr8r_walk_decision k_closest_visitor(cr8r_kd_ft *ft, const void *bounds, void *ent, void *_data){
     320      500000 :     cr8r_kd_k_closest_state *data = _data;
     321      500000 :     char tmp[ft->super.base.size];
     322      500000 :     if(data->ents->len < data->k){
     323       25000 :         cr8r_mmheap_push(data->ents, &data->ft.super, ent);
     324             :     }else{
     325      475000 :         cr8r_mmheap_pushpop_max(data->ents, &data->ft.super, ent, tmp);
     326      475000 :         data->max_sqdist = ft->sqdist(ft, data->pt, cr8r_mmheap_peek_max(data->ents, &data->ft.super));
     327             :     }
     328      500000 :     if(isinf(data->max_sqdist) || data->max_sqdist > ft->min_sqdist(ft, bounds, ent)){
     329      500000 :         return CR8R_WALK_CONTINUE;
     330             :     }
     331             :     return CR8R_WALK_SKIP_CHILDREN;
     332             : }
     333             : 
     334     8316686 : int cr8r_default_cmp_kd_kcs_pt_dist(const cr8r_base_ft *_ft, const void *a, const void *b){
     335     8316686 :     const cr8r_kd_ft *ft = (const cr8r_kd_ft*)_ft;
     336     8316686 :     const cr8r_kd_k_closest_state *data = CR8R_OUTER(ft, cr8r_kd_k_closest_state, ft);
     337     8316686 :     double a_sqdist = ft->sqdist(ft, data->pt, a);
     338     8316686 :     double b_sqdist = ft->sqdist(ft, data->pt, b);
     339     8316686 :     if(a_sqdist < b_sqdist){
     340             :         return -1;
     341     3722987 :     }else if(a_sqdist > b_sqdist){
     342     3722961 :         return 1;
     343             :     }
     344             :     return 0;
     345             : }
     346             : 
     347         250 : bool cr8r_kd_k_closest(cr8r_vec *self, cr8r_kd_ft *ft, const void *bounds, const void *pt, uint64_t k, cr8r_vec *out){
     348         250 :     cr8r_vec_clear(out, &ft->super);
     349         250 :     if(!cr8r_vec_ensure_cap(out, &ft->super, k + 1)){
     350             :         return false;
     351             :     }
     352         250 :     cr8r_kd_k_closest_state data = {
     353             :         .ents = out,
     354             :         .ft = *ft,
     355             :         .pt = pt,
     356             :         .k = k,
     357             :         .max_sqdist = INFINITY
     358             :     };
     359         250 :     data.ft.super.cmp = cr8r_default_cmp_kd_kcs_pt_dist;
     360         250 :     cr8r_kd_walk(self, ft, bounds, k_closest_visitor, &data);
     361         250 :     return true;
     362             : }
     363             : 
     364         250 : bool cr8r_kd_k_closest_naive(cr8r_vec *self, cr8r_kd_ft *ft, const void *bounds, const void *pt, uint64_t k, cr8r_vec *out){
     365         250 :     cr8r_vec_clear(out, &ft->super);
     366         250 :     if(!cr8r_vec_ensure_cap(out, &ft->super, k + 1)){
     367             :         return false;
     368             :     }
     369         250 :     cr8r_kd_k_closest_state data = {
     370             :         .ents = out,
     371             :         .ft = *ft,
     372             :         .pt = pt,
     373             :         .k = k,
     374             :         .max_sqdist = INFINITY
     375             :     };
     376         250 :     data.ft.super.cmp = cr8r_default_cmp_kd_kcs_pt_dist;
     377      250250 :     for(uint64_t i = 0; i < self->len; ++i){
     378      250000 :         k_closest_visitor(ft, bounds, self->buf + i*ft->super.base.size, &data);
     379             :     }
     380             :     return true;
     381             : }
     382             : 
     383             : cr8r_kd_ft cr8r_kdft_s2i64 = {
     384             :     .super.base.size = 3*sizeof(int64_t),
     385             :     .super.new_size = cr8r_default_new_size,
     386             :     .super.resize = cr8r_default_resize,
     387             :     .super.cmp = cmp_depth_i64sp,
     388             :     .super.swap = cr8r_default_swap,
     389             :     .dim = 2,
     390             :     .bounds_size = 6*sizeof(int64_t),
     391             :     .split = split_i64sp,
     392             :     .update = update_i64sp,
     393             :     .min_sqdist = min_sqdist_i64sp,
     394             :     .sqdist = sqdist_i64sp
     395             : };
     396             : 
     397             : cr8r_kd_ft cr8r_kdft_c3i64 = {
     398             :     .super.base.size = 3*sizeof(int64_t),
     399             :     .super.new_size = cr8r_default_new_size,
     400             :     .super.resize = cr8r_default_resize,
     401             :     .super.cmp = cmp_depth_i64cu,
     402             :     .super.swap = cr8r_default_swap,
     403             :     .dim = 3,
     404             :     .bounds_size = 6*sizeof(int64_t),
     405             :     .split = split_i64cu,
     406             :     .update = update_i64cu,
     407             :     .min_sqdist = min_sqdist_i64cu,
     408             :     .sqdist = sqdist_i64cu
     409             : };
     410             : 

Generated by: LCOV version 1.14