1
use std::{array, cmp::{max, Ordering}};
2

            
3
use num_traits::NumRef;
4
use rand::{distributions::{uniform::SampleUniform, Distribution, Uniform}, Rng};
5

            
6
use crate::{KdPoint, KdRegion};
7

            
8
/// Represents a point in N-dimensional Euclidean space whose coordinates are numeric type T
9
/// Wrapping CuPoint<i64, 3> or CuPoint<f64, 3> is a good way to get started quickly if you
10
/// don't need a topologically exotic implementation
11
#[derive(Clone, Debug, PartialEq)]
12
pub struct CuPoint<T, const N: usize>
13
where T: Ord + Clone + NumRef {
14
    buf: [T; N]
15
}
16

            
17
/// Represents an axis aligned cuboid region in N-dimensional Euclidean space whose
18
/// coordinates are numeric type T
19
#[derive(Clone, Debug)]
20
pub struct CuRegion<T, const N: usize>
21
where T: Ord + Clone + NumRef {
22
    pub start: CuPoint<T, N>,
23
    pub end: CuPoint<T, N>
24
}
25

            
26
impl<T, const N: usize> KdPoint for CuPoint<T, N>
27
where T: Ord + Clone + NumRef {
28
    type Distance = T;
29
283984
    fn cmp(&self, other: &Self, layer: usize) -> Ordering {
30
283984
        let idx = layer%N;
31
283984
        if N == 0 {
32
            Ordering::Equal
33
        } else {
34
283984
            self.buf[idx].cmp(&other.buf[idx])
35
        }
36
283984
    }
37

            
38
31689092
    fn sqdist(&self, other: &Self) -> Self::Distance {
39
31689092
        let mut a = T::zero();
40
126756368
        for i in 0..N {
41
95067276
            let (x, y) = (&self.buf[i], &other.buf[i]);
42
            // compute absolute difference between x and y in a really annoying way because generic
43
            // math is annoying and we can't just call x.abs_diff(y)
44
95067276
            let d = if x > y { x.clone() - y } else { y.clone() - x };
45
95067276
            a = a + d.clone()*&d;
46
        }
47
31689092
        a
48
31689092
    }
49
}
50

            
51
/// Generate a random point in a square/cube/etc.
52
/// Given a Uniform distribution sampling from a range, this adds the ability to
53
/// randomly generate CuPoints whose coordinates are iid (independent and identically distributed)
54
/// from that range
55
impl<T, const N: usize> Distribution<CuPoint<T, N>> for Uniform<T>
56
where T: Ord + Clone + NumRef + SampleUniform {
57
10500
    fn sample<R>(&self, rng: &mut R) -> CuPoint<T, N> where R: Rng + ?Sized {
58
36750
        CuPoint{buf: array::from_fn(|_|self.sample(rng))}
59
10500
    }
60
}
61

            
62
/// Generate a default point (all coordinates zero)
63
impl<T, const N: usize> Default for CuPoint<T, N>
64
where T: Ord + Clone + NumRef {
65
    fn default() -> Self {
66
        Self{buf: array::from_fn(|_|T::zero())}
67
    }
68
}
69

            
70
impl<T: Ord + Copy + NumRef, const N: usize> Copy for CuPoint<T, N> {}
71
impl<T: Ord + Clone + NumRef, const N: usize> Eq for CuPoint<T, N> {}
72

            
73
impl<T, const N: usize> CuPoint<T, N>
74
where T: Ord + Clone + NumRef {
75
    /// make a point with a given value
76
    pub fn make(buf: [T; N]) -> Self {
77
        Self{buf}
78
    }
79

            
80
	/// get readonly access to the buffer
81
40
	pub fn view(&self) -> &[T; N] {
82
40
		&self.buf
83
40
	}
84
	
85
	/// consume the point to mutably access the buffer
86
	pub fn extract(self) -> [T; N] {
87
		self.buf
88
	}
89
}
90

            
91
impl<T, const N: usize> From<[T; N]> for CuPoint<T, N>
92
where T: Ord + Clone + NumRef {
93
    fn from(buf: [T; N]) -> Self {
94
        Self{buf}
95
    }
96
}
97

            
98
impl<T, const N: usize> KdRegion for CuRegion<T, N>
99
where T: Ord + Clone + NumRef {
100
    type Point = CuPoint<T, N>;
101

            
102
500000
    fn split(&self, point: &Self::Point, layer: usize) -> (Self, Self) {
103
500000
        let mut sub0 = self.clone();
104
500000
        let mut sub1 = self.clone();
105
500000
        let idx = layer%N;
106
500000
        let split_coord = &point.buf[idx];
107
500000
        sub0.end.buf[idx].clone_from(split_coord);
108
500000
        sub1.start.buf[idx].clone_from(split_coord);
109
500000
        (sub0, sub1)
110
500000
    }
111

            
112
    fn min_sqdist(&self, point: &Self::Point) -> T {
113
        (&self.start.buf).into_iter().zip(&self.end.buf).zip(&point.buf).fold(T::zero(), |a,((l, r), x)|{
114
            let d = if x < l {
115
                l.clone() - x
116
            } else if r < x {
117
                x.clone() - r
118
            } else {
119
                return a
120
            };
121
            a + d.clone()*d
122
        })
123
    }
124

            
125
    fn max_sqdist(&self, point: &Self::Point) -> Option<T> {
126
        Some((&self.start.buf).into_iter().zip(&self.end.buf).zip(&point.buf).fold(T::zero(), |a,((l, r), x)|{
127
            let d = if x < l {
128
                r.clone() - x
129
            } else if r < x {
130
                x.clone() - l
131
            } else {
132
                max(r.clone() - x, x.clone() - l)
133
            };
134
            a + d.clone()*d
135
        }))
136
    }
137

            
138
    fn might_overlap(&self, other: &Self) -> bool {
139
        (&self.start.buf).into_iter().zip(&self.end.buf).zip((&other.start.buf).into_iter().zip(&other.end.buf)).all(|((a,b),(c,d))|{
140
            !(b < c || d < a)
141
        })
142
    }
143

            
144
    fn is_superset(&self, other: &Self) -> bool {
145
        (&self.start.buf).into_iter().zip(&self.end.buf).zip((&other.start.buf).into_iter().zip(&other.end.buf)).all(|((a,b),(c,d))|{
146
            a <= c && d <= b
147
        })
148
    }
149

            
150
9990
    fn extend(&mut self, point: &Self::Point) {
151
29970
        for (i, x) in (&point.buf).into_iter().enumerate() {
152
29970
            if x < &self.start.buf[i] {
153
176
                self.start.buf[i].clone_from(x);
154
29794
            } else if x > &self.end.buf[i] {
155
202
                self.end.buf[i].clone_from(x);
156
29592
            }
157
        }
158
9990
    }
159
    
160
10
    fn single_point(point: &Self::Point) -> Self {
161
10
        Self{start: point.clone(), end: point.clone()}
162
10
    }
163
}
164

            
165
/// A region is `Copy` whenever its coordinate type is.
impl<T, const N: usize> Copy for CuRegion<T, N> where T: Ord + Copy + NumRef {}
166

            
167
#[cfg(test)]
mod tests {
    use crate::kdtree::{KdTree, QueryOptions};

    use super::*;

    const NUM_POINTS: usize = 1000;
    const BOX_SIZE: i64 = 2000;
    const KCS_SIZE: i64 = 2200;
    const KCS_COUNT: usize = 50;
    const KCS_TRIALS: usize = 50;
    const KD_TRIALS: usize = 5;

    /// Compute the tight bounding box of every point stored in `kdt` by a
    /// direct scan, independent of the region code under test.
    /// Returns `None` when the tree has no points.
    fn get_bounds<const N: usize>(kdt: &KdTree<CuRegion<i64, N>>) -> Option<CuRegion<i64, N>> {
        let mut points = kdt.iter_points();
        let first = points.next()?;
        let mut bounds = CuRegion { start: first.clone(), end: first.clone() };
        for point in points {
            for (i, x) in point.buf.iter().enumerate() {
                if x < &bounds.start.buf[i] {
                    bounds.start.buf[i] = *x;
                } else if x > &bounds.end.buf[i] {
                    bounds.end.buf[i] = *x;
                }
            }
        }
        Some(bounds)
    }

    /// Build random trees, check their bounds and structure, and compare the
    /// k-nearest query against the naive implementation.
    #[test]
    fn pointcloud() {
        let mut rng = rand::thread_rng();
        let box_dist = Uniform::new_inclusive(-BOX_SIZE/2, BOX_SIZE/2);
        let kcs_dist = Uniform::new_inclusive(-KCS_SIZE/2, KCS_SIZE/2);
        for _ in 0..KD_TRIALS {
            eprintln!("Generating {0} random lattice points in [-{1}, {1}]^3", NUM_POINTS, BOX_SIZE/2);
            let mut points: Vec<(CuPoint<i64, 3>, ())> = Vec::with_capacity(NUM_POINTS);
            for _ in 0..NUM_POINTS {
                let sampled: CuPoint<i64, 3> = box_dist.sample(&mut rng);
                points.push((sampled, ()));
            }
            eprintln!("Checking bounds of points");
            let kdt: KdTree<_> = points.into_iter().collect();
            let bounds = get_bounds(&kdt);
            if let (Some(CuRegion { start: lo_scan, end: hi_scan }), Some(CuRegion { start: lo_tree, end: hi_tree })) = (&bounds, &kdt.bounds) {
                assert!(
                    lo_scan.view() == lo_tree.view() && hi_scan.view() == hi_tree.view(),
                    "Bounds did not match!"
                );
            } else {
                panic!("Failed to get bounds!");
            }
            assert!(kdt.check_tree(), "KD Tree built wrong!");
            for _ in 0..KCS_TRIALS {
                let point: CuPoint<i64, 3> = kcs_dist.sample(&mut rng);
                eprintln!("Getting {} closest points to {:?}", KCS_COUNT, &point);
                let mut res: Vec<_> = kdt.k_closest(&point, KCS_COUNT, QueryOptions::ALL_NO_TIES).into();
                let mut res_naive: Vec<_> = kdt.k_closest_naive(&point, KCS_COUNT).into();
                assert!(
                    res.len() == KCS_COUNT && res_naive.len() == KCS_COUNT,
                    "K Closest and/or K Closest naive failed to get {} points!", KCS_COUNT
                );
                res.sort_unstable_by_key(|(p, _)| point.sqdist(p));
                res_naive.sort_unstable_by_key(|(p, _)| point.sqdist(p));
                // Ties may order differently, so compare distances rather than identities.
                let same_distances = res
                    .into_iter()
                    .zip(res_naive)
                    .all(|((o, _), (e, _))| point.sqdist(o) == point.sqdist(e));
                assert!(same_distances, "K Closest and K Closest naive did not get the same sets of points!");
            }
        }
    }
}
236