Line data Source code
1 : #include <stdlib.h>
2 : #include <string.h>
3 : #include <inttypes.h>
4 : #include <limits.h>
5 : #include <math.h>
6 :
7 : #include <crater/kd_tree.h>
8 : #include <crater/kd_check.h>
9 : #include <crater/minmax_heap.h>
10 :
11 845103 : static int cmp_depth_i64cu(const cr8r_base_ft *_ft, const void *_a, const void *_b){
12 845103 : const cr8r_kd_ft *ft = (cr8r_kd_ft*)_ft;
13 845103 : const int64_t *a = _a;
14 845103 : const int64_t *b = _b;
15 845103 : uint64_t depth = (uint64_t)ft->super.base.data;
16 845103 : uint64_t idx = depth%ft->dim;
17 845103 : int64_t key_a = a[idx];
18 845103 : int64_t key_b = b[idx];
19 845103 : if(key_a < key_b){
20 : return -1;
21 342695 : }else if(key_a > key_b){
22 328480 : return 1;
23 : }
24 : return 0;
25 : }
26 :
27 0 : static int cmp_depth_i64sp(const cr8r_base_ft *_ft, const void *_a, const void *_b){
28 0 : const cr8r_kd_ft *ft = (cr8r_kd_ft*)_ft;
29 0 : const int64_t *a = _a;
30 0 : const int64_t *b = _b;
31 0 : uint64_t depth = (uint64_t)ft->super.base.data;
32 0 : uint64_t idx = depth%ft->dim;
33 0 : if(idx == ft->dim - 1){
34 0 : int64_t a_x = a[0], a_y = a[1];
35 0 : int64_t b_x = b[0], b_y = b[1];
36 : // if a and b are in opposite half planes
37 0 : if((a_y >= 0) != (b_y >= 0)){
38 0 : if(a_y < 0){
39 : #ifdef CR8R_KDSP_ORIGIN_IS_INFIMUM
40 : return 1;
41 : #else // otherwise (0, 0) == any point
42 0 : return b_x || b_y;// return 0 (==) if b is (0, 0), else 1 (>)
43 : #endif
44 : }
45 : #ifdef CR8R_KDSP_ORIGIN_IS_INFIMUM
46 : return -1;
47 : #else // otherwise (0, 0) == any point
48 0 : return -(a_x || a_y);// return 0 (==) if a is (0, 0), else -1 (<)
49 : #endif
50 : }
51 0 : int64_t crossp = a_x*b_y - b_x*a_y;
52 0 : if(crossp > 0){
53 : return -1;
54 0 : }else if(crossp < 0){
55 : return 1;
56 : }
57 : #ifdef CR8R_KDSP_ORIGIN_IS_INFIMUM
58 : if(!a_x && !a_y){// if a is (0, 0)
59 : return -(b_x || b_y);// return 0 (==) if b is (0, 0), else -1 (<)
60 : }
61 : return !(b_x || b_y);// return 1 (>) if b is (0, 0), else 0 (==)
62 : #else
63 0 : return 0;
64 : #endif
65 : }
66 0 : int64_t key_a = a[2 + idx];
67 0 : int64_t key_b = b[2 + idx];
68 0 : if(key_a < key_b){
69 : return -1;
70 0 : }else if(key_a > key_b){
71 0 : return 1;
72 : }
73 : return 0;
74 : }
75 :
76 0 : static void split_i64sp(const cr8r_kd_ft *ft, const void *_self, const void *_root_pt, void *_o1, void *_o2){
77 0 : const cr8r_kdwin_s2i64 *self = _self;
78 0 : const int64_t *root = _root_pt;
79 0 : cr8r_kdwin_s2i64 *o1 = _o1;
80 0 : cr8r_kdwin_s2i64 *o2 = _o2;
81 0 : uint64_t idx = (uint64_t)ft->super.base.data%ft->dim;
82 0 : memcpy(o1, self, sizeof(cr8r_kdwin_s2i64));
83 0 : memcpy(o2, self, sizeof(cr8r_kdwin_s2i64));
84 0 : if(idx == ft->dim - 1){
85 0 : if(!root[0] || !root[1]){
86 0 : memcpy(o1->tr, root, 2*sizeof(int64_t));
87 0 : memcpy(o2->bl, root, 2*sizeof(int64_t));
88 : }
89 : }else{
90 0 : o1->tr[2 + idx] = root[2 + idx];
91 0 : o2->bl[2 + idx] = root[2 + idx];
92 : }
93 0 : }
94 :
95 0 : static void update_i64sp(const cr8r_kd_ft *_ft, void *_self, const void *_pt){
96 0 : cr8r_kdwin_s2i64 *self = _self;
97 0 : const int64_t *pt = _pt;
98 0 : cr8r_kd_ft ft = *_ft;
99 0 : for(uint64_t i = 0; i < ft.dim - 1; ++i){
100 0 : ft.super.base.data = (void*)i;
101 0 : if(ft.super.cmp(&ft.super.base, pt, self->bl) < 0){
102 0 : self->bl[2 + i] = pt[2 + i];
103 0 : }else if(ft.super.cmp(&ft.super.base, pt, self->tr) > 0){
104 0 : self->tr[2 + i] = pt[2 + i];
105 : }
106 : }
107 0 : ft.super.base.data = (void*)(ft.dim - 1);
108 0 : if(ft.super.cmp(&ft.super.base, pt, self->bl) < 0){
109 0 : memcpy(self->bl, pt, 2*sizeof(int64_t));
110 0 : }else if(ft.super.cmp(&ft.super.base, pt, self->tr) > 0){
111 0 : memcpy(self->tr, pt, 2*sizeof(int64_t));
112 : }
113 0 : }
114 :
115 0 : static double min_sqdist_i64sp(const cr8r_kd_ft *_ft, const void *_self, const void *_pt){
116 : /* The min sqdist on a sphere is the same in the last n-2 dimensions,
117 : but the first 2 dimensions which are treated as an angle must have their
118 : sqdist computed differently
119 : The distance beween x1, x2 and l1, l2 can be found by the law of cosines
120 : sqdist = 2*x1**2 + 2*x2**2 - 2*(x1**2 + x2**2)*cos
121 : Here, x1**2 + x2**2 is the square of the radius of the circle in the hypersphere
122 : where x1, x2 lives. The triangle in question in the law of cosines
123 : is in the circle with x1, x2 and is an isosolese triangle through the center
124 : and two points on the circumference
125 : Thus the two known sides are the radius of that circle and the cos can be
126 : found through the dot product
127 : sqdist = 2*(x1**2 + x2**2 - (x1*l1 + x2*l2)/(l1**2 + l2**2))
128 : Remember, this is just the component of the sqdist in the first two dimensions
129 : Also notice that this is not necessarily an integer, so we will use floor division
130 : and accept that we explore some windows that are actually too far to contribute to the
131 : k closest points
132 :
133 : Now we just have to find whether x1, x2 is closer to l1, l2, r1, r2, or is between them
134 : */
135 0 : const cr8r_kdwin_s2i64 *self = _self;
136 0 : const int64_t *pt = _pt;
137 0 : cr8r_kd_ft ft = *_ft;
138 0 : double sqdist = -1;
139 0 : ft.super.base.data = (void*)(ft.dim - 1);
140 0 : if(ft.super.cmp(&ft.super.base, self->bl, self->tr) > 0){
141 0 : if(ft.super.cmp(&ft.super.base, self->bl, pt) <= 0){
142 0 : sqdist = 0;
143 : }
144 0 : if(ft.super.cmp(&ft.super.base, pt, self->tr) <= 0){
145 : sqdist = 0;
146 : }
147 0 : }else if(!self->bl[1] && !self->tr[1]){
148 : sqdist = 0;
149 : }else{
150 0 : if(ft.super.cmp(&ft.super.base, self->bl, pt) <= 0 && ft.super.cmp(&ft.super.base, pt, self->tr) <= 0){
151 : sqdist = 0;
152 : }
153 : }
154 0 : if(sqdist == -1){
155 0 : int64_t lcrs = self->bl[0]*self->bl[0] + self->bl[1]*self->bl[1];
156 0 : int64_t rcrs = self->tr[0]*self->tr[0] + self->tr[1]*self->tr[1];
157 0 : int64_t acrs = pt[0]*pt[0] + pt[1]*pt[1];
158 0 : int64_t lda = pt[0]*self->bl[0] + pt[1]*self->bl[1];
159 0 : int64_t rda = pt[0]*self->tr[0] + pt[1]*self->tr[1];
160 0 : if(lcrs*rda < rcrs*lda){
161 0 : sqdist = 2*(acrs - (double)rda/rcrs);
162 : }else{
163 0 : sqdist = 2*(acrs - (double)lda/lcrs);
164 : }
165 : }
166 0 : for(uint64_t i = 2; i <= ft.dim; ++i){
167 0 : int64_t axdist = self->bl[i] - pt[i];
168 0 : if(pt[i] - self->tr[i] > axdist){
169 0 : axdist = pt[i] - self->tr[i];
170 : }
171 0 : sqdist += axdist > 0 ? axdist*axdist : 0;
172 : }
173 0 : return sqdist;
174 : }
175 :
176 0 : static double sqdist_i64sp(const cr8r_kd_ft *_ft, const void *_a, const void *_b){
177 0 : double res = 0;
178 0 : const int64_t *a = _a, *b = _b;
179 0 : for(uint64_t i = 0; i < _ft->dim + 1; ++i){
180 0 : res += (a[i] - b[i])*(a[i] - b[i]);
181 : }
182 0 : return res;
183 : }
184 :
185 :
186 :
187 250000 : static void split_i64cu(const cr8r_kd_ft *ft, const void *_self, const void *_root_pt, void *_o1, void *_o2){
188 250000 : const cr8r_kdwin_s2i64 *self = _self;
189 250000 : const int64_t *root = _root_pt;
190 250000 : cr8r_kdwin_s2i64 *o1 = _o1;
191 250000 : cr8r_kdwin_s2i64 *o2 = _o2;
192 250000 : uint64_t idx = (uint64_t)ft->super.base.data%ft->dim;
193 250000 : memcpy(o1, self, sizeof(cr8r_kdwin_s2i64));
194 250000 : memcpy(o2, self, sizeof(cr8r_kdwin_s2i64));
195 250000 : o1->tr[idx] = root[idx];
196 250000 : o2->bl[idx] = root[idx];
197 250000 : }
198 :
199 4995 : static void update_i64cu(const cr8r_kd_ft *_ft, void *_self, const void *_pt){
200 4995 : cr8r_kdwin_s2i64 *self = _self;
201 4995 : const int64_t *pt = _pt;
202 4995 : cr8r_kd_ft ft = *_ft;
203 19980 : for(uint64_t i = 0; i < ft.dim; ++i){
204 14985 : ft.super.base.data = (void*)i;
205 14985 : if(ft.super.cmp(&ft.super.base, pt, self->bl) < 0){
206 99 : self->bl[i] = pt[i];
207 14886 : }else if(ft.super.cmp(&ft.super.base, pt, self->tr) > 0){
208 80 : self->tr[i] = pt[i];
209 : }
210 : }
211 4995 : }
212 :
213 475000 : static double min_sqdist_i64cu(const cr8r_kd_ft *_ft, const void *_self, const void *_pt){
214 475000 : const cr8r_kdwin_s2i64 *self = _self;
215 475000 : const int64_t *pt = _pt;
216 475000 : double sqdist = 0;
217 1900000 : for(uint64_t i = 0; i < _ft->dim; ++i){
218 1425000 : int64_t axdist = self->bl[i] - pt[i];
219 1425000 : if(pt[i] - self->tr[i] > axdist){
220 726750 : axdist = pt[i] - self->tr[i];
221 : }
222 1425000 : sqdist += axdist > 0 ? axdist*axdist : 0;
223 : }
224 475000 : return sqdist;
225 : }
226 :
227 17133372 : static double sqdist_i64cu(const cr8r_kd_ft *_ft, const void *_a, const void *_b){
228 17133372 : double res = 0;
229 17133372 : const int64_t *a = _a, *b = _b;
230 68533488 : for(uint64_t i = 0; i < _ft->dim; ++i){
231 51400116 : res += (a[i] - b[i])*(a[i] - b[i]);
232 : }
233 17133372 : return res;
234 : }
235 :
236 :
237 :
238 0 : bool cr8r_kdwin_init_i64sp(cr8r_kdwin_s2i64 *self, const int64_t bl[3], const int64_t tr[3]){
239 0 : memcpy(self->bl, bl, 3*sizeof(int64_t));
240 0 : memcpy(self->tr, tr, 3*sizeof(int64_t));
241 0 : return 1;
242 : }
243 :
244 0 : bool cr8r_kdwin_init_i64cu(cr8r_kdwin_s2i64 *self, const int64_t bl[3], const int64_t tr[3]){
245 0 : memcpy(self->bl, bl, 3*sizeof(int64_t));
246 0 : memcpy(self->tr, tr, 3*sizeof(int64_t));
247 0 : return 1;
248 : }
249 :
250 5 : bool cr8r_kdwin_bounding_i64x3(cr8r_kdwin_s2i64 *self, const cr8r_vec *ents, const cr8r_kd_ft *ft){
251 5 : if(!ents->len){
252 : return 0;
253 : }
254 5 : memcpy(self->bl, ents->buf, ft->super.base.size);
255 5 : memcpy(self->tr, ents->buf, ft->super.base.size);
256 5000 : for(uint64_t i = 1; i < ents->len; ++i){
257 4995 : const void *ent = ents->buf + i*ft->super.base.size;
258 4995 : ft->update(ft, self, ent);
259 : }
260 : return 1;
261 : }
262 :
263 5005 : bool cr8r_kd_ify(cr8r_vec *self, cr8r_kd_ft *_ft, uint64_t a, uint64_t b){
264 5005 : cr8r_kd_ft ft = *_ft;
265 10005 : while(b > a){
266 5000 : uint64_t mid_idx = (a + b)/2;
267 5000 : void *piv = cr8r_vec_ith(self, &ft.super, a, b, mid_idx - a);
268 5000 : if(!piv){
269 : return 0;
270 : }
271 5000 : piv = cr8r_vec_partition_with_median(self, &ft.super, a, b, piv);
272 : #ifdef DEBUG
273 5000 : if(!piv || !cr8r_kd_check_layer(self, &ft, a, b)){
274 0 : __builtin_trap();
275 : }
276 : #endif
277 : // increment depth
278 5000 : ++*(uint64_t*)&ft.super.base.data;
279 5000 : if(!cr8r_kd_ify(self, &ft, mid_idx + 1, b)){
280 : return 0;
281 : }
282 : b = mid_idx;
283 : }
284 : return 1;
285 : }
286 :
287 250250 : cr8r_walk_decision cr8r_kd_walk_r(cr8r_vec *self, const cr8r_kd_ft *_ft, void *bounds, cr8r_kdvisitor visitor, void *data, uint64_t a, uint64_t b){
288 250250 : cr8r_kd_ft ft = *_ft;
289 250250 : char sub0[ft.bounds_size];
290 250250 : char sub1[ft.bounds_size];
291 500250 : while(b > a){
292 250000 : uint64_t mid_idx = (a + b)/2;
293 250000 : void *ent = self->buf + mid_idx*ft.super.base.size;
294 250000 : cr8r_walk_decision decision = visitor(&ft, bounds, ent, data);
295 250000 : if(decision == CR8R_WALK_STOP){
296 : return decision;
297 250000 : }else if(decision == CR8R_WALK_SKIP_CHILDREN){
298 : return CR8R_WALK_CONTINUE;
299 : }
300 250000 : ft.split(&ft, bounds, ent, sub0, sub1);
301 : // increment depth
302 250000 : ++*(uint64_t*)&ft.super.base.data;
303 250000 : decision = cr8r_kd_walk_r(self, &ft, sub0, visitor, data, a, mid_idx);
304 250000 : if(decision == CR8R_WALK_STOP){
305 : return decision;
306 : }
307 250000 : a = mid_idx + 1;
308 250000 : memcpy(bounds, sub1, ft.bounds_size);
309 : }
310 : return CR8R_WALK_CONTINUE;
311 : }
312 :
313 250 : void cr8r_kd_walk(cr8r_vec *self, const cr8r_kd_ft *ft, const void *_bounds, cr8r_kdvisitor visitor, void *data){
314 250 : char bounds[ft->bounds_size];
315 250 : memcpy(bounds, _bounds, ft->bounds_size);
316 250 : cr8r_kd_walk_r(self, ft, bounds, visitor, data, 0, self->len);
317 250 : }
318 :
319 500000 : inline static cr8r_walk_decision k_closest_visitor(cr8r_kd_ft *ft, const void *bounds, void *ent, void *_data){
320 500000 : cr8r_kd_k_closest_state *data = _data;
321 500000 : char tmp[ft->super.base.size];
322 500000 : if(data->ents->len < data->k){
323 25000 : cr8r_mmheap_push(data->ents, &data->ft.super, ent);
324 : }else{
325 475000 : cr8r_mmheap_pushpop_max(data->ents, &data->ft.super, ent, tmp);
326 475000 : data->max_sqdist = ft->sqdist(ft, data->pt, cr8r_mmheap_peek_max(data->ents, &data->ft.super));
327 : }
328 500000 : if(isinf(data->max_sqdist) || data->max_sqdist > ft->min_sqdist(ft, bounds, ent)){
329 500000 : return CR8R_WALK_CONTINUE;
330 : }
331 : return CR8R_WALK_SKIP_CHILDREN;
332 : }
333 :
334 8316686 : int cr8r_default_cmp_kd_kcs_pt_dist(const cr8r_base_ft *_ft, const void *a, const void *b){
335 8316686 : const cr8r_kd_ft *ft = (const cr8r_kd_ft*)_ft;
336 8316686 : const cr8r_kd_k_closest_state *data = CR8R_OUTER(ft, cr8r_kd_k_closest_state, ft);
337 8316686 : double a_sqdist = ft->sqdist(ft, data->pt, a);
338 8316686 : double b_sqdist = ft->sqdist(ft, data->pt, b);
339 8316686 : if(a_sqdist < b_sqdist){
340 : return -1;
341 3722987 : }else if(a_sqdist > b_sqdist){
342 3722961 : return 1;
343 : }
344 : return 0;
345 : }
346 :
347 250 : bool cr8r_kd_k_closest(cr8r_vec *self, cr8r_kd_ft *ft, const void *bounds, const void *pt, uint64_t k, cr8r_vec *out){
348 250 : cr8r_vec_clear(out, &ft->super);
349 250 : if(!cr8r_vec_ensure_cap(out, &ft->super, k + 1)){
350 : return false;
351 : }
352 250 : cr8r_kd_k_closest_state data = {
353 : .ents = out,
354 : .ft = *ft,
355 : .pt = pt,
356 : .k = k,
357 : .max_sqdist = INFINITY
358 : };
359 250 : data.ft.super.cmp = cr8r_default_cmp_kd_kcs_pt_dist;
360 250 : cr8r_kd_walk(self, ft, bounds, k_closest_visitor, &data);
361 250 : return true;
362 : }
363 :
364 250 : bool cr8r_kd_k_closest_naive(cr8r_vec *self, cr8r_kd_ft *ft, const void *bounds, const void *pt, uint64_t k, cr8r_vec *out){
365 250 : cr8r_vec_clear(out, &ft->super);
366 250 : if(!cr8r_vec_ensure_cap(out, &ft->super, k + 1)){
367 : return false;
368 : }
369 250 : cr8r_kd_k_closest_state data = {
370 : .ents = out,
371 : .ft = *ft,
372 : .pt = pt,
373 : .k = k,
374 : .max_sqdist = INFINITY
375 : };
376 250 : data.ft.super.cmp = cr8r_default_cmp_kd_kcs_pt_dist;
377 250250 : for(uint64_t i = 0; i < self->len; ++i){
378 250000 : k_closest_visitor(ft, bounds, self->buf + i*ft->super.base.size, &data);
379 : }
380 : return true;
381 : }
382 :
383 : cr8r_kd_ft cr8r_kdft_s2i64 = {
384 : .super.base.size = 3*sizeof(int64_t),
385 : .super.new_size = cr8r_default_new_size,
386 : .super.resize = cr8r_default_resize,
387 : .super.cmp = cmp_depth_i64sp,
388 : .super.swap = cr8r_default_swap,
389 : .dim = 2,
390 : .bounds_size = 6*sizeof(int64_t),
391 : .split = split_i64sp,
392 : .update = update_i64sp,
393 : .min_sqdist = min_sqdist_i64sp,
394 : .sqdist = sqdist_i64sp
395 : };
396 :
397 : cr8r_kd_ft cr8r_kdft_c3i64 = {
398 : .super.base.size = 3*sizeof(int64_t),
399 : .super.new_size = cr8r_default_new_size,
400 : .super.resize = cr8r_default_resize,
401 : .super.cmp = cmp_depth_i64cu,
402 : .super.swap = cr8r_default_swap,
403 : .dim = 3,
404 : .bounds_size = 6*sizeof(int64_t),
405 : .split = split_i64cu,
406 : .update = update_i64cu,
407 : .min_sqdist = min_sqdist_i64cu,
408 : .sqdist = sqdist_i64cu
409 : };
410 :
|