/// Flat storage for all tree nodes. Index 0 is the root.
/// Children are referenced by index (Node::m_children), not by pointer.
25 std::vector<Node> m_nodes;
/// Indices of nodes queued for splitting during insertion.
/// The deferred split pass copies this set into a work stack and clears it.
/// It then splits each queued node (and the children it creates) while shouldSplit() holds.
26 std::set<std::size_t> waitingForSplit;
/// Geometry object used for distance(), distance_to_rectangle() and
/// choose_split_dimension(). It is assigned from the constructor argument.
27 StateSpace state_space;
/// Point type: a Scalar column vector with Dimensions rows.
/// Dimensions may be Eigen::Dynamic, in which case the size is chosen at runtime.
32 using point_t = Eigen::Matrix<Scalar, Dimensions, 1>;
/// Read-only Eigen::Ref view of a point-shaped column vector.
33 using cref_t =
const Eigen::Ref<const Eigen::Matrix<Scalar, Dimensions, 1>> &;
/// Mutable Eigen::Ref view of a point-shaped column vector.
34 using ref_t = Eigen::Ref<Eigen::Matrix<Scalar, Dimensions, 1>>;
45 const StateSpace &t_state_space = StateSpace()) {
46 state_space = t_state_space;
47 if constexpr (Dimensions == Eigen::Dynamic) {
48 assert(runtime_dimension > 0);
52 m_nodes.emplace_back(BucketSize, -1);
/// Number of points in the tree, read from the root node's entry counter (m_nodes[0].m_entries).
/// Assumes the constructor has created the root node, so m_nodes is non-empty.
56 size_t size()
const {
return m_nodes[0].m_entries; }
59 std::size_t addNode = 0;
62 while (m_nodes[addNode].m_splitDimension !=
m_dimensions) {
63 m_nodes[addNode].expandBounds(x);
64 if (x[m_nodes[addNode].m_splitDimension] <
65 m_nodes[addNode].m_splitValue) {
66 addNode = m_nodes[addNode].m_children.first;
68 addNode = m_nodes[addNode].m_children.second;
71 m_nodes[addNode].add(PointId{x,
id});
73 if (m_nodes[addNode].shouldSplit() &&
74 m_nodes[addNode].m_entries % BucketSize == 0) {
78 waitingForSplit.insert(addNode);
84 std::vector<std::size_t> searchStack(waitingForSplit.begin(),
85 waitingForSplit.end());
86 waitingForSplit.clear();
87 while (searchStack.size() > 0) {
88 std::size_t addNode = searchStack.back();
89 searchStack.pop_back();
91 m_nodes[addNode].shouldSplit() && split(addNode)) {
92 searchStack.push_back(m_nodes[addNode].m_children.first);
93 searchStack.push_back(m_nodes[addNode].m_children.second);
107 std::size_t maxPoints)
const {
108 return searcher().
search(x, std::numeric_limits<Scalar>::max(), maxPoints,
114 x, maxRadius, std::numeric_limits<std::size_t>::max(), state_space);
117 std::vector<DistanceId>
119 std::size_t maxPoints)
const {
125 result.
distance = std::numeric_limits<Scalar>::infinity();
127 if (m_nodes[0].m_entries > 0) {
128 std::vector<std::size_t> searchStack;
131 std::size_t(1.5 * std::log2(1 + m_nodes[0].m_entries / BucketSize)));
132 searchStack.push_back(0);
134 while (searchStack.size() > 0) {
135 std::size_t nodeIndex = searchStack.back();
136 searchStack.pop_back();
137 const Node &node = m_nodes[nodeIndex];
138 if (result.
distance > node.distance_to_rectangle(x, state_space)) {
140 for (
const auto &lp : node.m_locationId) {
144 Scalar nodeDist = state_space.distance(x, lp.x);
150 node.queueChildren(x, searchStack);
160 result.
distance = std::numeric_limits<Scalar>::infinity();
163 if (m_nodes[0].m_entries > 0) {
164 std::vector<std::size_t> searchStack;
167 std::size_t(1.5 * std::log2(1 + m_nodes[0].m_entries / BucketSize)));
168 searchStack.push_back(0);
170 while (!found && searchStack.size() > 0) {
171 std::size_t nodeIndex = searchStack.back();
172 searchStack.pop_back();
173 Node &node = m_nodes[nodeIndex];
174 if (result.
distance > node.distance_to_rectangle(x, state_space)) {
176 for (
auto &lp : node.m_locationId) {
180 Scalar nodeDist = state_space.distance(x, lp.x);
191 node.queueChildren(x, searchStack);
208 std::size_t maxPoints,
209 const StateSpace &state_space) {
214 m_searchStack.reserve(
215 1 + std::size_t(1.5 * std::log2(1 + m_tree.m_nodes[0].m_entries /
217 if (m_prioqueueCapacity < maxPoints &&
218 maxPoints < m_tree.m_nodes[0].m_entries) {
219 std::vector<DistanceId> container;
220 container.reserve(maxPoints);
221 m_prioqueue = std::priority_queue<DistanceId, std::vector<DistanceId>>(
222 std::less<DistanceId>(), std::move(container));
223 m_prioqueueCapacity = maxPoints;
227 m_prioqueue, m_results, state_space);
229 m_prioqueueCapacity = std::max(m_prioqueueCapacity, m_results.size());
/// Reusable stack of node indices for tree traversal.
/// Its capacity is reserved per search from the tree size.
236 std::vector<std::size_t> m_searchStack;
/// Reusable candidate heap. It uses the default std::less comparator, so top() is the
/// largest DistanceId. NOTE(review): presumably the farthest candidate, assuming
/// DistanceId orders by distance; confirm against DistanceId's operator<.
237 std::priority_queue<DistanceId, std::vector<DistanceId>> m_prioqueue;
/// Number of elements m_prioqueue's underlying vector is known to have room for.
/// The heap is rebuilt with a larger reserve only when a search requests more
/// points than this (and fewer than the tree holds).
238 std::size_t m_prioqueueCapacity = 0;
/// Result buffer handed to the tree's search routine and kept between calls.
239 std::vector<DistanceId> m_results;
/// Spare PointId buffer kept to reuse allocations.
/// A newly constructed node swaps it in as its bucket.
/// split() hands emptied bucket storage back to it.
251 std::vector<PointId> m_bucketRecycle;
254 const point_t &x, Scalar maxRadius, std::size_t maxPoints,
255 std::vector<std::size_t> &searchStack,
256 std::priority_queue<DistanceId, std::vector<DistanceId>> &prioqueue,
257 std::vector<DistanceId> &results,
const StateSpace &state_space)
const {
258 std::size_t numSearchPoints = std::min(maxPoints, m_nodes[0].m_entries);
260 if (numSearchPoints > 0) {
261 searchStack.push_back(0);
262 while (searchStack.size() > 0) {
263 std::size_t nodeIndex = searchStack.back();
264 searchStack.pop_back();
265 const Node &node = m_nodes[nodeIndex];
266 Scalar minDist = node.distance_to_rectangle(x, state_space);
267 if (maxRadius > minDist && (prioqueue.size() < numSearchPoints ||
268 prioqueue.top().distance > minDist)) {
270 node.searchCapacityLimitedBall(x, maxRadius, numSearchPoints,
271 prioqueue, state_space);
273 node.queueChildren(x, searchStack);
278 results.reserve(prioqueue.size());
279 while (prioqueue.size() > 0) {
280 results.push_back(prioqueue.top());
283 std::reverse(results.begin(), results.end());
287 bool split(std::size_t index) {
288 if (m_nodes.capacity() < m_nodes.size() + 2) {
289 m_nodes.reserve((m_nodes.capacity() + 1) * 2);
291 Node &splitNode = m_nodes[index];
294 state_space.choose_split_dimension(splitNode.m_lb, splitNode.m_ub,
295 splitNode.m_splitDimension, width);
301 std::vector<Scalar> splitDimVals;
302 splitDimVals.reserve(splitNode.m_entries);
303 for (
const auto &lp : splitNode.m_locationId) {
304 splitDimVals.push_back(lp.x[splitNode.m_splitDimension]);
306 std::nth_element(splitDimVals.begin(),
307 splitDimVals.begin() + splitDimVals.size() / 2 + 1,
309 std::nth_element(splitDimVals.begin(),
310 splitDimVals.begin() + splitDimVals.size() / 2,
311 splitDimVals.begin() + splitDimVals.size() / 2 + 1);
312 splitNode.m_splitValue = (splitDimVals[splitDimVals.size() / 2] +
313 splitDimVals[splitDimVals.size() / 2 + 1]) /
316 splitNode.m_children = std::make_pair(m_nodes.size(), m_nodes.size() + 1);
317 std::size_t entries = splitNode.m_entries;
318 m_nodes.emplace_back(m_bucketRecycle, entries,
m_dimensions);
319 Node &leftNode = m_nodes.back();
321 Node &rightNode = m_nodes.back();
323 for (
const auto &lp : splitNode.m_locationId) {
324 if (lp.x[splitNode.m_splitDimension] < splitNode.m_splitValue) {
331 if (leftNode.m_entries ==
334 splitNode.m_splitValue = 0;
336 splitNode.m_children = std::pair<std::size_t, std::size_t>(0, 0);
337 std::swap(rightNode.m_locationId, m_bucketRecycle);
342 splitNode.m_locationId.clear();
346 if (splitNode.m_locationId.capacity() == BucketSize) {
347 std::swap(splitNode.m_locationId, m_bucketRecycle);
349 std::vector<PointId> empty;
350 std::swap(splitNode.m_locationId, empty);
357 Node(std::size_t capacity,
int runtime_dimension = -1) {
358 init(capacity, runtime_dimension);
361 Node(std::vector<PointId> &recycle, std::size_t capacity,
362 int runtime_dimension) {
363 std::swap(m_locationId, recycle);
364 init(capacity, runtime_dimension);
367 void init(std::size_t capacity,
int runtime_dimension) {
369 if constexpr (Dimensions == Eigen::Dynamic) {
370 assert(runtime_dimension > 0);
371 m_lb.resize(runtime_dimension);
372 m_ub.resize(runtime_dimension);
373 m_splitDimension = runtime_dimension;
376 m_lb.setConstant(std::numeric_limits<Scalar>::max());
377 m_ub.setConstant(std::numeric_limits<Scalar>::lowest());
378 m_locationId.reserve(std::max(BucketSize, capacity));
381 void expandBounds(
const point_t &x) {
382 m_lb = m_lb.cwiseMin(x);
383 m_ub = m_ub.cwiseMax(x);
387 void add(
const PointId &lp) {
389 m_locationId.push_back(lp);
/// True once this node holds at least BucketSize entries.
/// Insertion queues a node for splitting only when this holds and the entry count is a
/// multiple of BucketSize. The split pass re-checks it before calling split().
392 bool shouldSplit()
const {
return m_entries >= BucketSize; }
394 void searchCapacityLimitedBall(
const point_t &x, Scalar maxRadius,
396 std::priority_queue<DistanceId> &results,
397 const StateSpace &state_space)
const {
402 for (; results.size() < K && i < m_entries; i++) {
403 const auto &lp = m_locationId[i];
404 Scalar distance = state_space.distance(x, lp.x);
405 if (distance < maxRadius) {
406 results.emplace(DistanceId{distance, lp.id});
411 for (; i < m_entries; i++) {
412 const auto &lp = m_locationId[i];
413 Scalar distance = state_space.distance(x, lp.x);
414 if (distance < maxRadius && distance < results.top().distance) {
416 results.emplace(DistanceId{distance, lp.id});
421 void queueChildren(
const point_t &x,
422 std::vector<std::size_t> &searchStack)
const {
423 if (x[m_splitDimension] < m_splitValue) {
424 searchStack.push_back(m_children.second);
425 searchStack.push_back(m_children.first);
427 searchStack.push_back(m_children.first);
428 searchStack.push_back(m_children.second);
432 Scalar distance_to_rectangle(
const point_t &x,
433 const StateSpace &distance)
const {
434 return distance.distance_to_rectangle(x, m_lb, m_ub);
/// Count of points recorded for this node.
/// Leaf scans iterate m_locationId up to this bound.
437 std::size_t m_entries = 0;
/// Axis this node splits on. A value equal to the tree's dimension count marks a leaf;
/// for dynamic-size trees, init() sets it to the runtime dimension.
439 int m_splitDimension = Dimensions;
/// Split threshold. Points with x[m_splitDimension] < m_splitValue go to the first
/// child; all others go to the second.
440 Scalar m_splitValue = 0;
/// Lower corner of the node's axis-aligned bounding box.
/// init() sets it to max(); expandBounds() shrinks it.
456 Eigen::Matrix<Scalar, Dimensions, 1> m_lb;
/// Upper corner of the node's axis-aligned bounding box.
/// init() sets it to lowest(); expandBounds() grows it.
457 Eigen::Matrix<Scalar, Dimensions, 1> m_ub;
459 std::pair<std::size_t, std::size_t>
/// Points (with their ids) stored in this node's bucket.
/// The bucket is emptied when the node is split.
461 std::vector<PointId> m_locationId;