Dynotree
Loading...
Searching...
No Matches
KDTree.h
Go to the documentation of this file.
1#pragma once
2
3#include <algorithm>
4#include <cmath>
5#include <cwchar>
6#include <limits>
7#include <queue>
8#include <set>
9#include <vector>
10
11#include <eigen3/Eigen/Core>
12#include <eigen3/Eigen/Dense>
13
14#include "StateSpace.h"
16
17namespace dynotree {
18
19template <class Id, int Dimensions, std::size_t BucketSize = 32,
20 typename Scalar = double,
21 typename StateSpace = Rn<Scalar, Dimensions>>
22class KDTree {
23private:
24 struct Node;
25 std::vector<Node> m_nodes;
26 std::set<std::size_t> waitingForSplit;
27 StateSpace state_space;
28
29public:
30 using scalar_t = Scalar;
31 using id_t = Id;
32 using point_t = Eigen::Matrix<Scalar, Dimensions, 1>;
33 using cref_t = const Eigen::Ref<const Eigen::Matrix<Scalar, Dimensions, 1>> &;
34 using ref_t = Eigen::Ref<Eigen::Matrix<Scalar, Dimensions, 1>>;
35 using state_space_t = StateSpace;
36 int m_dimensions = Dimensions;
37 static const std::size_t bucketSize = BucketSize;
39
40 StateSpace &getStateSpace() { return state_space; }
41
42 KDTree() = default;
43
44 void init_tree(int runtime_dimension = -1,
45 const StateSpace &t_state_space = StateSpace()) {
46 state_space = t_state_space;
47 if constexpr (Dimensions == Eigen::Dynamic) {
48 assert(runtime_dimension > 0);
49 m_dimensions = runtime_dimension;
50 m_nodes.emplace_back(BucketSize, m_dimensions);
51 } else {
52 m_nodes.emplace_back(BucketSize, -1);
53 }
54 }
55
56 size_t size() const { return m_nodes[0].m_entries; }
57
58 void addPoint(const point_t &x, const Id &id, bool autosplit = true) {
59 std::size_t addNode = 0;
60
61 assert(m_dimensions > 0);
62 while (m_nodes[addNode].m_splitDimension != m_dimensions) {
63 m_nodes[addNode].expandBounds(x);
64 if (x[m_nodes[addNode].m_splitDimension] <
65 m_nodes[addNode].m_splitValue) {
66 addNode = m_nodes[addNode].m_children.first;
67 } else {
68 addNode = m_nodes[addNode].m_children.second;
69 }
70 }
71 m_nodes[addNode].add(PointId{x, id});
72
73 if (m_nodes[addNode].shouldSplit() &&
74 m_nodes[addNode].m_entries % BucketSize == 0) {
75 if (autosplit) {
76 split(addNode);
77 } else {
78 waitingForSplit.insert(addNode);
79 }
80 }
81 }
82
84 std::vector<std::size_t> searchStack(waitingForSplit.begin(),
85 waitingForSplit.end());
86 waitingForSplit.clear();
87 while (searchStack.size() > 0) {
88 std::size_t addNode = searchStack.back();
89 searchStack.pop_back();
90 if (m_nodes[addNode].m_splitDimension == m_dimensions &&
91 m_nodes[addNode].shouldSplit() && split(addNode)) {
92 searchStack.push_back(m_nodes[addNode].m_children.first);
93 searchStack.push_back(m_nodes[addNode].m_children.second);
94 }
95 }
96 }
97
98 struct DistanceId {
99 Scalar distance;
100 Id id;
101 inline bool operator<(const DistanceId &dp) const {
102 return distance < dp.distance;
103 }
104 };
105
106 std::vector<DistanceId> searchKnn(const point_t &x,
107 std::size_t maxPoints) const {
108 return searcher().search(x, std::numeric_limits<Scalar>::max(), maxPoints,
109 state_space);
110 }
111
112 std::vector<DistanceId> searchBall(const point_t &x, Scalar maxRadius) const {
113 return searcher().search(
114 x, maxRadius, std::numeric_limits<std::size_t>::max(), state_space);
115 }
116
117 std::vector<DistanceId>
118 searchCapacityLimitedBall(const point_t &x, Scalar maxRadius,
119 std::size_t maxPoints) const {
120 return searcher().search(x, maxRadius, maxPoints, state_space);
121 }
122
123 DistanceId search(const point_t &x) const {
124 DistanceId result;
125 result.distance = std::numeric_limits<Scalar>::infinity();
126
127 if (m_nodes[0].m_entries > 0) {
128 std::vector<std::size_t> searchStack;
129 searchStack.reserve(
130 1 +
131 std::size_t(1.5 * std::log2(1 + m_nodes[0].m_entries / BucketSize)));
132 searchStack.push_back(0);
133
134 while (searchStack.size() > 0) {
135 std::size_t nodeIndex = searchStack.back();
136 searchStack.pop_back();
137 const Node &node = m_nodes[nodeIndex];
138 if (result.distance > node.distance_to_rectangle(x, state_space)) {
139 if (node.m_splitDimension == m_dimensions) {
140 for (const auto &lp : node.m_locationId) {
141 // Allow to have inactive nodes in the tree
142 if (!lp.active)
143 continue;
144 Scalar nodeDist = state_space.distance(x, lp.x);
145 if (nodeDist < result.distance) {
146 result = DistanceId{nodeDist, lp.id};
147 }
148 }
149 } else {
150 node.queueChildren(x, searchStack);
151 }
152 }
153 }
154 }
155 return result;
156 }
157
158 void set_inactive(const point_t &x) {
159 DistanceId result;
160 result.distance = std::numeric_limits<Scalar>::infinity();
161
162 bool found = false;
163 if (m_nodes[0].m_entries > 0) {
164 std::vector<std::size_t> searchStack;
165 searchStack.reserve(
166 1 +
167 std::size_t(1.5 * std::log2(1 + m_nodes[0].m_entries / BucketSize)));
168 searchStack.push_back(0);
169
170 while (!found && searchStack.size() > 0) {
171 std::size_t nodeIndex = searchStack.back();
172 searchStack.pop_back();
173 Node &node = m_nodes[nodeIndex];
174 if (result.distance > node.distance_to_rectangle(x, state_space)) {
175 if (node.m_splitDimension == m_dimensions) {
176 for (auto &lp : node.m_locationId) {
177 // Allow to have inactive nodes in the tree
178 if (!lp.active)
179 continue;
180 Scalar nodeDist = state_space.distance(x, lp.x);
181 if (nodeDist < result.distance) {
182 result = DistanceId{nodeDist, lp.id};
183 if (result.distance < 1e-8) {
184 found = true;
185 lp.active = false;
186 break;
187 }
188 }
189 }
190 } else {
191 node.queueChildren(x, searchStack);
192 }
193 }
194 }
195 }
197 // return result;
198 }
199
200 class Searcher {
201 public:
202 Searcher(const tree_t &tree) : m_tree(tree) {}
203 Searcher(const Searcher &searcher) : m_tree(searcher.m_tree) {}
204
205 // NB! this method is not const. Do not call this on same instance from
206 // different threads simultaneously.
207 const std::vector<DistanceId> &search(const point_t &x, Scalar maxRadius,
208 std::size_t maxPoints,
209 const StateSpace &state_space) {
210 // clear results from last time
211 m_results.clear();
212
213 // reserve capacities
214 m_searchStack.reserve(
215 1 + std::size_t(1.5 * std::log2(1 + m_tree.m_nodes[0].m_entries /
216 BucketSize)));
217 if (m_prioqueueCapacity < maxPoints &&
218 maxPoints < m_tree.m_nodes[0].m_entries) {
219 std::vector<DistanceId> container;
220 container.reserve(maxPoints);
221 m_prioqueue = std::priority_queue<DistanceId, std::vector<DistanceId>>(
222 std::less<DistanceId>(), std::move(container));
223 m_prioqueueCapacity = maxPoints;
224 }
225
226 m_tree.searchCapacityLimitedBall(x, maxRadius, maxPoints, m_searchStack,
227 m_prioqueue, m_results, state_space);
228
229 m_prioqueueCapacity = std::max(m_prioqueueCapacity, m_results.size());
230 return m_results;
231 }
232
233 private:
234 const tree_t &m_tree;
235
236 std::vector<std::size_t> m_searchStack;
237 std::priority_queue<DistanceId, std::vector<DistanceId>> m_prioqueue;
238 std::size_t m_prioqueueCapacity = 0;
239 std::vector<DistanceId> m_results;
240 };
241
242 // NB! returned class has no const methods. Get one instance per thread!
243 Searcher searcher() const { return Searcher(*this); }
244
245private:
246 struct PointId {
247 point_t x;
248 Id id;
249 bool active = true;
250 };
251 std::vector<PointId> m_bucketRecycle;
252
254 const point_t &x, Scalar maxRadius, std::size_t maxPoints,
255 std::vector<std::size_t> &searchStack,
256 std::priority_queue<DistanceId, std::vector<DistanceId>> &prioqueue,
257 std::vector<DistanceId> &results, const StateSpace &state_space) const {
258 std::size_t numSearchPoints = std::min(maxPoints, m_nodes[0].m_entries);
259
260 if (numSearchPoints > 0) {
261 searchStack.push_back(0);
262 while (searchStack.size() > 0) {
263 std::size_t nodeIndex = searchStack.back();
264 searchStack.pop_back();
265 const Node &node = m_nodes[nodeIndex];
266 Scalar minDist = node.distance_to_rectangle(x, state_space);
267 if (maxRadius > minDist && (prioqueue.size() < numSearchPoints ||
268 prioqueue.top().distance > minDist)) {
269 if (node.m_splitDimension == m_dimensions) {
270 node.searchCapacityLimitedBall(x, maxRadius, numSearchPoints,
271 prioqueue, state_space);
272 } else {
273 node.queueChildren(x, searchStack);
274 }
275 }
276 }
277
278 results.reserve(prioqueue.size());
279 while (prioqueue.size() > 0) {
280 results.push_back(prioqueue.top());
281 prioqueue.pop();
282 }
283 std::reverse(results.begin(), results.end());
284 }
285 }
286
287 bool split(std::size_t index) {
288 if (m_nodes.capacity() < m_nodes.size() + 2) {
289 m_nodes.reserve((m_nodes.capacity() + 1) * 2);
290 }
291 Node &splitNode = m_nodes[index];
292 splitNode.m_splitDimension = m_dimensions;
293 Scalar width(0);
294 state_space.choose_split_dimension(splitNode.m_lb, splitNode.m_ub,
295 splitNode.m_splitDimension, width);
296
297 if (splitNode.m_splitDimension == m_dimensions) {
298 return false;
299 }
300
301 std::vector<Scalar> splitDimVals;
302 splitDimVals.reserve(splitNode.m_entries);
303 for (const auto &lp : splitNode.m_locationId) {
304 splitDimVals.push_back(lp.x[splitNode.m_splitDimension]);
305 }
306 std::nth_element(splitDimVals.begin(),
307 splitDimVals.begin() + splitDimVals.size() / 2 + 1,
308 splitDimVals.end());
309 std::nth_element(splitDimVals.begin(),
310 splitDimVals.begin() + splitDimVals.size() / 2,
311 splitDimVals.begin() + splitDimVals.size() / 2 + 1);
312 splitNode.m_splitValue = (splitDimVals[splitDimVals.size() / 2] +
313 splitDimVals[splitDimVals.size() / 2 + 1]) /
314 Scalar(2);
315
316 splitNode.m_children = std::make_pair(m_nodes.size(), m_nodes.size() + 1);
317 std::size_t entries = splitNode.m_entries;
318 m_nodes.emplace_back(m_bucketRecycle, entries, m_dimensions);
319 Node &leftNode = m_nodes.back();
320 m_nodes.emplace_back(entries, m_dimensions);
321 Node &rightNode = m_nodes.back();
322
323 for (const auto &lp : splitNode.m_locationId) {
324 if (lp.x[splitNode.m_splitDimension] < splitNode.m_splitValue) {
325 leftNode.add(lp);
326 } else {
327 rightNode.add(lp);
328 }
329 }
330
331 if (leftNode.m_entries ==
332 0) // points with equality to splitValue go in rightNode
333 {
334 splitNode.m_splitValue = 0;
335 splitNode.m_splitDimension = m_dimensions;
336 splitNode.m_children = std::pair<std::size_t, std::size_t>(0, 0);
337 std::swap(rightNode.m_locationId, m_bucketRecycle);
338 m_nodes.pop_back();
339 m_nodes.pop_back();
340 return false;
341 } else {
342 splitNode.m_locationId.clear();
343 // if it was a standard sized bucket, recycle the memory to reduce
344 // allocator pressure otherwise clear the memory used by the bucket
345 // since it is a branch not a leaf anymore
346 if (splitNode.m_locationId.capacity() == BucketSize) {
347 std::swap(splitNode.m_locationId, m_bucketRecycle);
348 } else {
349 std::vector<PointId> empty;
350 std::swap(splitNode.m_locationId, empty);
351 }
352 return true;
353 }
354 }
355
356 struct Node {
357 Node(std::size_t capacity, int runtime_dimension = -1) {
358 init(capacity, runtime_dimension);
359 }
360
361 Node(std::vector<PointId> &recycle, std::size_t capacity,
362 int runtime_dimension) {
363 std::swap(m_locationId, recycle);
364 init(capacity, runtime_dimension);
365 }
366
367 void init(std::size_t capacity, int runtime_dimension) {
368
369 if constexpr (Dimensions == Eigen::Dynamic) {
370 assert(runtime_dimension > 0);
371 m_lb.resize(runtime_dimension);
372 m_ub.resize(runtime_dimension);
373 m_splitDimension = runtime_dimension;
374 }
375
376 m_lb.setConstant(std::numeric_limits<Scalar>::max());
377 m_ub.setConstant(std::numeric_limits<Scalar>::lowest());
378 m_locationId.reserve(std::max(BucketSize, capacity));
379 }
380
381 void expandBounds(const point_t &x) {
382 m_lb = m_lb.cwiseMin(x);
383 m_ub = m_ub.cwiseMax(x);
384 m_entries++;
385 }
386
387 void add(const PointId &lp) {
388 expandBounds(lp.x);
389 m_locationId.push_back(lp);
390 }
391
392 bool shouldSplit() const { return m_entries >= BucketSize; }
393
394 void searchCapacityLimitedBall(const point_t &x, Scalar maxRadius,
395 std::size_t K,
396 std::priority_queue<DistanceId> &results,
397 const StateSpace &state_space) const {
398
399 std::size_t i = 0;
400
401 // this fills up the queue if it isn't full yet
402 for (; results.size() < K && i < m_entries; i++) {
403 const auto &lp = m_locationId[i];
404 Scalar distance = state_space.distance(x, lp.x);
405 if (distance < maxRadius) {
406 results.emplace(DistanceId{distance, lp.id});
407 }
408 }
409
410 // this adds new things to the queue once it is full
411 for (; i < m_entries; i++) {
412 const auto &lp = m_locationId[i];
413 Scalar distance = state_space.distance(x, lp.x);
414 if (distance < maxRadius && distance < results.top().distance) {
415 results.pop();
416 results.emplace(DistanceId{distance, lp.id});
417 }
418 }
419 }
420
421 void queueChildren(const point_t &x,
422 std::vector<std::size_t> &searchStack) const {
423 if (x[m_splitDimension] < m_splitValue) {
424 searchStack.push_back(m_children.second);
425 searchStack.push_back(m_children.first); // left is popped first
426 } else {
427 searchStack.push_back(m_children.first);
428 searchStack.push_back(m_children.second); // right is popped first
429 }
430 }
431
432 Scalar distance_to_rectangle(const point_t &x,
433 const StateSpace &distance) const {
434 return distance.distance_to_rectangle(x, m_lb, m_ub);
435 }
436
437 std::size_t m_entries = 0;
438
439 int m_splitDimension = Dimensions;
440 Scalar m_splitValue = 0;
441
442 // struct Range {
443 // Scalar min, max;
444 // };
445
446 // std::array<Range, Dimensions> m_bounds; /// bounding box of this node
456 Eigen::Matrix<Scalar, Dimensions, 1> m_lb;
457 Eigen::Matrix<Scalar, Dimensions, 1> m_ub;
458
459 std::pair<std::size_t, std::size_t>
460 m_children;
461 std::vector<PointId> m_locationId;
462 };
463};
464
465} // namespace dynotree
class Searcher
Definition KDTree.h:200
Searcher(const Searcher &searcher)
Definition KDTree.h:203
const std::vector< DistanceId > & search(const point_t &x, Scalar maxRadius, std::size_t maxPoints, const StateSpace &state_space)
Definition KDTree.h:207
Searcher(const tree_t &tree)
Definition KDTree.h:202
class KDTree
Definition KDTree.h:22
DistanceId search(const point_t &x) const
Definition KDTree.h:123
void set_inactive(const point_t &x)
Definition KDTree.h:158
size_t size() const
Definition KDTree.h:56
StateSpace state_space_t
Definition KDTree.h:35
void addPoint(const point_t &x, const Id &id, bool autosplit=true)
Definition KDTree.h:58
Eigen::Matrix< Scalar, Dimensions, 1 > point_t
Definition KDTree.h:32
void init_tree(int runtime_dimension=-1, const StateSpace &t_state_space=StateSpace())
Definition KDTree.h:44
Scalar scalar_t
Definition KDTree.h:30
Id id_t
Definition KDTree.h:31
Searcher searcher() const
Definition KDTree.h:243
const Eigen::Ref< const Eigen::Matrix< Scalar, Dimensions, 1 > > & cref_t
Definition KDTree.h:33
KDTree()=default
std::vector< DistanceId > searchCapacityLimitedBall(const point_t &x, Scalar maxRadius, std::size_t maxPoints) const
Definition KDTree.h:118
void splitOutstanding()
Definition KDTree.h:83
static const std::size_t bucketSize
Definition KDTree.h:37
int m_dimensions
Definition KDTree.h:36
StateSpace & getStateSpace()
Definition KDTree.h:40
std::vector< DistanceId > searchKnn(const point_t &x, std::size_t maxPoints) const
Definition KDTree.h:106
Eigen::Ref< Eigen::Matrix< Scalar, Dimensions, 1 > > ref_t
Definition KDTree.h:34
std::vector< DistanceId > searchBall(const point_t &x, Scalar maxRadius) const
Definition KDTree.h:112
#define CHECK_PRETTY_DYNOTREE__(condition)
Definition dynotree_macros.h:41
Definition dynotree_macros.h:9
struct DistanceId
Definition KDTree.h:98
Scalar distance
Definition KDTree.h:99
bool operator<(const DistanceId &dp) const
Definition KDTree.h:101
Id id
Definition KDTree.h:100