All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
NearestNeighborsGNAT.h
1 /*********************************************************************
2 * Software License Agreement (BSD License)
3 *
4 * Copyright (c) 2011, Rice University
5 * All rights reserved.
6 *
7 * Redistribution and use in source and binary forms, with or without
8 * modification, are permitted provided that the following conditions
9 * are met:
10 *
11 * * Redistributions of source code must retain the above copyright
12 * notice, this list of conditions and the following disclaimer.
13 * * Redistributions in binary form must reproduce the above
14 * copyright notice, this list of conditions and the following
15 * disclaimer in the documentation and/or other materials provided
16 * with the distribution.
17 * * Neither the name of the Rice University nor the names of its
18 * contributors may be used to endorse or promote products derived
19 * from this software without specific prior written permission.
20 *
21 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
24 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
25 * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
26 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
27 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
28 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
29 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
30 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
31 * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
32 * POSSIBILITY OF SUCH DAMAGE.
33 *********************************************************************/
34 
35 /* Author: Mark Moll */
36 
37 #ifndef OMPL_DATASTRUCTURES_NEAREST_NEIGHBORS_GNAT_
38 #define OMPL_DATASTRUCTURES_NEAREST_NEIGHBORS_GNAT_
39 
40 #include "ompl/datastructures/NearestNeighbors.h"
41 #include "ompl/datastructures/GreedyKCenters.h"
42 #include "ompl/util/Exception.h"
43 #include <boost/unordered_set.hpp>
44 #include <queue>
45 #include <algorithm>
46 
47 namespace ompl
48 {
49 
58  template<typename _T>
60  {
61  protected:
63  // internally, we use a priority queue for nearest neighbors, paired
64  // with their distance to the query point
65  typedef std::pair<const _T*,double> DataDist;
66  struct DataDistCompare
67  {
68  bool operator()(const DataDist& d0, const DataDist& d1)
69  {
70  return d0.second < d1.second;
71  }
72  };
73  typedef std::priority_queue<DataDist, std::vector<DataDist>, DataDistCompare> NearQueue;
74 
75  // another internal data structure is a priority queue of nodes to
76  // check next for possible nearest neighbors
77  class Node;
78  typedef std::pair<Node*,double> NodeDist;
79  struct NodeDistCompare
80  {
81  bool operator()(const NodeDist& n0, const NodeDist& n1) const
82  {
83  return (n0.second - n0.first->maxRadius_) > (n1.second - n1.first->maxRadius_);
84  }
85  };
86  typedef std::priority_queue<NodeDist, std::vector<NodeDist>, NodeDistCompare> NodeQueue;
88 
89  public:
90  NearestNeighborsGNAT(unsigned int degree = 4, unsigned int minDegree = 2,
91  unsigned int maxDegree = 6, unsigned int maxNumPtsPerLeaf = 50,
92  unsigned int removedCacheSize = 50, bool rebalancing = false)
93  : NearestNeighbors<_T>(), tree_(NULL), degree_(degree),
94  minDegree_(std::min(degree,minDegree)), maxDegree_(std::max(maxDegree,degree)),
95  maxNumPtsPerLeaf_(maxNumPtsPerLeaf), size_(0),
96  rebuildSize_(rebalancing ? maxNumPtsPerLeaf*degree : std::numeric_limits<std::size_t>::max()),
97  removedCacheSize_(removedCacheSize)
98  {
99  }
100 
101  virtual ~NearestNeighborsGNAT(void)
102  {
103  if (tree_)
104  delete tree_;
105  }
// Install a new distance function and hand it to the pivot selector used by
// Node::split.
// NOTE(review): this rendering skips source lines 109 and 112, so part of the
// body is not visible: whatever precedes the pivotSelector_ call and the
// statement guarded by `if (tree_)`. Presumably the base class is updated and
// an existing tree is rebuilt, because its cached radii/ranges were computed
// with the old metric. Confirm against the real file.
107  virtual void setDistanceFunction(const typename NearestNeighbors<_T>::DistanceFunction &distFun)
108  {
110  pivotSelector_.setDistanceFunction(distFun);
111  if (tree_)
113  }
114  virtual void clear(void)
115  {
116  if (tree_)
117  {
118  delete tree_;
119  tree_ = NULL;
120  }
121  size_ = 0;
122  removed_.clear();
123  }
124 
// Insert a single element.
// With an existing tree, the element is routed down through Node::add, which
// also increments size_. Otherwise a new root node whose pivot is `data` is
// created and size_ becomes 1.
// NOTE(review): source line 130, the statement guarded by
// `if (isRemoved(data))`, is missing from this rendering. As displayed, the
// `if` appears to guard tree_->add, which is not the real structure.
125  virtual void add(const _T &data)
126  {
127  if (tree_)
128  {
129  if (isRemoved(data))
131  tree_->add(*this, data);
132  }
133  else
134  {
135  tree_ = new Node(degree_, maxNumPtsPerLeaf_, data);
136  size_ = 1;
137  }
138  }
139  virtual void add(const std::vector<_T> &data)
140  {
141  if (tree_)
143  else if (data.size()>0)
144  {
145  tree_ = new Node(degree_, maxNumPtsPerLeaf_, data[0]);
146  for (unsigned int i=1; i<data.size(); ++i)
147  tree_->data_.push_back(data[i]);
148  if (tree_->needToSplit(*this))
149  tree_->split(*this);
150  }
151  size_ += data.size();
152  }
// NOTE(review): the declaration line(s) of this body (source lines 153-154)
// are missing from this rendering. Judging by the calls in Node::add, it is
// presumably rebuildDataStructure(); confirm.
// Rebuild the GNAT from scratch:
// - list() collects every element not marked as removed;
// - clear() deletes the tree and empties removed_;
// - the batch add() builds a fresh tree from the collected elements.
155  {
156  std::vector<_T> lst;
157  list(lst);
158  clear();
159  add(lst);
160  }
// Remove one element equal to `data`.
// The element is located with a 1-nearest-neighbor query. If the closest
// stored element is not equal to `data`, nothing happens and false is
// returned. Otherwise the element is only marked as removed (its address is
// inserted into removed_), size_ is decremented, and true is returned.
// NOTE(review): source line 179, the statement guarded by the rebuild
// condition below, is missing from this rendering. Per the comment, it
// rebuilds the entire GNAT.
166  virtual bool remove(const _T &data)
167  {
168  if (!size_) return false;
169  NearQueue nbhQueue;
170  // find data in tree
171  bool isPivot = nearestKInternal(data, 1, nbhQueue);
172  if (*nbhQueue.top().first != data)
173  return false;
174  removed_.insert(nbhQueue.top().first);
175  size_--;
176  // if we removed a pivot or if the capacity of removed elements
177  // has been reached, we rebuild the entire GNAT
178  if (isPivot || removed_.size()>=removedCacheSize_)
180  return true;
181  }
182 
183  virtual _T nearest(const _T &data) const
184  {
185  if (size_)
186  {
187  std::vector<_T> nbh;
188  nearestK(data, 1, nbh);
189  if (!nbh.empty()) return nbh[0];
190  }
191  throw Exception("No elements found in nearest neighbors data structure");
192  }
193 
194  virtual void nearestK(const _T &data, std::size_t k, std::vector<_T> &nbh) const
195  {
196  nbh.clear();
197  if (k == 0) return;
198  if (size_)
199  {
200  NearQueue nbhQueue;
201  nearestKInternal(data, k, nbhQueue);
202  postprocessNearest(nbhQueue, nbh);
203  }
204  }
205 
206  virtual void nearestR(const _T &data, double radius, std::vector<_T> &nbh) const
207  {
208  nbh.clear();
209  if (size_)
210  {
211  NearQueue nbhQueue;
212  nearestRInternal(data, radius, nbhQueue);
213  postprocessNearest(nbhQueue, nbh);
214  }
215  }
216 
217  virtual std::size_t size(void) const
218  {
219  return size_;
220  }
221 
222  virtual void list(std::vector<_T> &data) const
223  {
224  data.clear();
225  data.reserve(size());
226  if (tree_)
227  tree_->list(*this, data);
228  }
229 
231  friend std::ostream& operator<<(std::ostream& out, const NearestNeighborsGNAT<_T>& gnat)
232  {
233  if (gnat.tree_)
234  {
235  out << *gnat.tree_;
236  if (!gnat.removed_.empty())
237  {
238  out << "Elements marked for removal:\n";
239  for (typename boost::unordered_set<const _T*>::const_iterator it = gnat.removed_.begin();
240  it != gnat.removed_.end(); it++)
241  out << **it << '\t';
242  out << std::endl;
243  }
244  }
245  return out;
246  }
247 
248  // for debugging purposes
249  void integrityCheck()
250  {
251  std::vector<_T> lst;
252  boost::unordered_set<const _T*> tmp;
253  // get all elements, including those marked for removal
254  removed_.swap(tmp);
255  list(lst);
256  // check if every element marked for removal is also in the tree
257  for (typename boost::unordered_set<const _T*>::iterator it=tmp.begin(); it!=tmp.end(); it++)
258  {
259  unsigned int i;
260  for (i=0; i<lst.size(); ++i)
261  if (lst[i]==**it)
262  break;
263  if (i == lst.size())
264  {
265  // an element marked for removal is not actually in the tree
266  std::cout << "***** FAIL!! ******\n" << *this << '\n';
267  for (unsigned int j=0; j<lst.size(); ++j) std::cout<<lst[j]<<'\t';
268  std::cout<<std::endl;
269  }
270  assert(i != lst.size());
271  }
272  // restore
273  removed_.swap(tmp);
274  // get elements in the tree with elements marked for removal purged from the list
275  list(lst);
276  if (lst.size() != size_)
277  std::cout << "#########################################\n" << *this << std::endl;
278  assert(lst.size() == size_);
279  }
280  protected:
// Shorthand for the owning structure; the nested Node methods receive it as
// `GNAT& gnat` / `const GNAT& gnat` to reach distFun_, size_, removed_, etc.
281  typedef NearestNeighborsGNAT<_T> GNAT;
282 
284  bool isRemoved(const _T& data) const
285  {
286  return !removed_.empty() && removed_.find(&data) != removed_.end();
287  }
288 
293  bool nearestKInternal(const _T &data, std::size_t k, NearQueue& nbhQueue) const
294  {
295  bool isPivot;
296  double dist;
297  NodeDist nodeDist;
298  NodeQueue nodeQueue;
299 
300  isPivot = tree_->insertNeighborK(nbhQueue, k, tree_->pivot_, data,
302  tree_->nearestK(*this, data, k, nbhQueue, nodeQueue, isPivot);
303  while (nodeQueue.size() > 0)
304  {
305  dist = nbhQueue.top().second; // note the difference with nearestRInternal
306  nodeDist = nodeQueue.top();
307  nodeQueue.pop();
308  if (nbhQueue.size() == k &&
309  (nodeDist.second > nodeDist.first->maxRadius_ + dist ||
310  nodeDist.second < nodeDist.first->minRadius_ - dist))
311  break;
312  nodeDist.first->nearestK(*this, data, k, nbhQueue, nodeQueue, isPivot);
313  }
314  return isPivot;
315  }
// Best-first range search from the root: collect every non-removed element
// within `radius` of `data` into nbhQueue. Requires tree_ != NULL (callers
// check size_).
// NOTE(review): source line 324, the last (distance) argument of
// insertNeighborR for the root pivot, is missing from this rendering.
// NOTE(review): Node::nearestR pushes a node only after checking these same
// two conditions with the same radius, so the `break` below should never fire.
// If it did, the minRadius_ test would not justify skipping the remaining
// queued nodes (`continue` would be the safe choice).
317  void nearestRInternal(const _T &data, double radius, NearQueue& nbhQueue) const
318  {
319  double dist = radius; // note the difference with nearestKInternal
320  NodeQueue nodeQueue;
321  NodeDist nodeDist;
322 
323  tree_->insertNeighborR(nbhQueue, radius, tree_->pivot_,
325  tree_->nearestR(*this, data, radius, nbhQueue, nodeQueue);
326  while (nodeQueue.size() > 0)
327  {
328  nodeDist = nodeQueue.top();
329  nodeQueue.pop();
330  if (nodeDist.second > nodeDist.first->maxRadius_ + dist ||
331  nodeDist.second < nodeDist.first->minRadius_ - dist)
332  break;
333  nodeDist.first->nearestR(*this, data, radius, nbhQueue, nodeQueue);
334  }
335  }
338  void postprocessNearest(NearQueue& nbhQueue, std::vector<_T> &nbh) const
339  {
340  typename std::vector<_T>::reverse_iterator it;
341  nbh.resize(nbhQueue.size());
342  for (it=nbh.rbegin(); it!=nbh.rend(); it++, nbhQueue.pop())
343  *it = *nbhQueue.top().first;
344  }
345 
347  class Node
348  {
349  public:
// Create a leaf whose pivot is `pivot`, with `degree` entries in each range
// table and room for capacity+1 bucket elements. minRadius_ starts at
// +infinity so the first updateRadius call replaces it.
// NOTE(review): source line 355 is missing from this rendering. It must
// initialize maxRadius_ and minRange_; as displayed, maxRange_ would copy an
// uninitialized maxRadius_. Presumably maxRadius_ starts at -infinity and
// minRange_ at minRadius_; confirm against the real file.
352  Node(int degree, int capacity, const _T& pivot)
353  : degree_(degree), pivot_(pivot),
354  minRadius_(std::numeric_limits<double>::infinity()),
356  maxRange_(degree, maxRadius_)
357  {
358  // The "+1" is needed because we add an element before we check whether to split
359  data_.reserve(capacity+1);
360  }
361 
362  ~Node()
363  {
364  for (unsigned int i=0; i<children_.size(); ++i)
365  delete children_[i];
366  }
367 
370  void updateRadius(double dist)
371  {
372  if (minRadius_ > dist)
373  minRadius_ = dist;
374  if (maxRadius_ < dist)
375  maxRadius_ = dist;
376  }
380  void updateRange(unsigned int i, double dist)
381  {
382  if (minRange_[i] > dist)
383  minRange_[i] = dist;
384  if (maxRange_[i] < dist)
385  maxRange_[i] = dist;
386  }
388  void add(GNAT& gnat, const _T& data)
389  {
390  if (children_.size()==0)
391  {
392  data_.push_back(data);
393  gnat.size_++;
394  if (needToSplit(gnat))
395  {
396  if (gnat.removed_.size() > 0)
397  gnat.rebuildDataStructure();
398  else if (gnat.size_ >= gnat.rebuildSize_)
399  {
400  gnat.rebuildSize_ <<= 1;
401  gnat.rebuildDataStructure();
402  }
403  else
404  split(gnat);
405  }
406  }
407  else
408  {
409  std::vector<double> dist(children_.size());
410  double minDist = dist[0] = gnat.distFun_(data, children_[0]->pivot_);
411  int minInd = 0;
412 
413  for (unsigned int i=1; i<children_.size(); ++i)
414  if ((dist[i] = gnat.distFun_(data, children_[i]->pivot_)) < minDist)
415  {
416  minDist = dist[i];
417  minInd = i;
418  }
419  for (unsigned int i=0; i<children_.size(); ++i)
420  children_[i]->updateRange(minInd, dist[i]);
421  children_[minInd]->updateRadius(minDist);
422  children_[minInd]->add(gnat, data);
423  }
424  }
426  bool needToSplit(const GNAT& gnat) const
427  {
428  unsigned int sz = data_.size();
429  return sz > gnat.maxNumPtsPerLeaf_ && sz > degree_;
430  }
434  void split(GNAT& gnat)
435  {
436  std::vector<std::vector<double> > dists;
437  std::vector<unsigned int> pivots;
438 
439  children_.reserve(degree_);
440  gnat.pivotSelector_.kcenters(data_, degree_, pivots, dists);
441  for(unsigned int i=0; i<pivots.size(); i++)
442  children_.push_back(new Node(degree_, gnat.maxNumPtsPerLeaf_, data_[pivots[i]]));
443  degree_ = pivots.size(); // in case fewer than degree_ pivots were found
444  for (unsigned int j=0; j<data_.size(); ++j)
445  {
446  unsigned int k = 0;
447  for (unsigned int i=1; i<degree_; ++i)
448  if (dists[j][i] < dists[j][k])
449  k = i;
450  Node* child = children_[k];
451  if (j != pivots[k])
452  {
453  child->data_.push_back(data_[j]);
454  child->updateRadius(dists[j][k]);
455  }
456  for (unsigned int i=0; i<degree_; ++i)
457  children_[i]->updateRange(k, dists[j][i]);
458  }
459 
460  for (unsigned int i=0; i<degree_; ++i)
461  {
462  // make sure degree lies between minDegree_ and maxDegree_
463  children_[i]->degree_ = std::min(std::max(
464  degree_ * (unsigned int)(children_[i]->data_.size() / data_.size()),
465  gnat.minDegree_), gnat.maxDegree_);
466  // singleton
467  if (children_[i]->minRadius_ == std::numeric_limits<double>::infinity())
468  children_[i]->minRadius_ = children_[i]->maxRadius_ = 0.;
469  }
470  // this does more than clear(); it also sets capacity to 0 and frees the memory
471  std::vector<_T> tmp;
472  data_.swap(tmp);
473  // check if new leaves need to be split
474  for (unsigned int i=0; i<degree_; ++i)
475  if (children_[i]->needToSplit(gnat))
476  children_[i]->split(gnat);
477  }
478 
480  bool insertNeighborK(NearQueue& nbh, std::size_t k, const _T& data, const _T& key, double dist) const
481  {
482  if (nbh.size() < k)
483  {
484  nbh.push(std::make_pair(&data, dist));
485  return true;
486  }
487  else if (dist < nbh.top().second ||
488  (dist < std::numeric_limits<double>::epsilon() && data==key))
489  {
490  nbh.pop();
491  nbh.push(std::make_pair(&data, dist));
492  return true;
493  }
494  return false;
495  }
496 
// Examine this node during a k-nearest-neighbor search.
// 1) Every non-removed bucket element is offered to nbh; a successful
//    insertion sets isPivot = false.
// 2) Child pivots are visited in random order and offered to nbh; a
//    successful insertion sets isPivot = true. Child pivots are not checked
//    with isRemoved: remove() rebuilds when a pivot is removed.
// 3) Once k candidates are held, child i's range tables prune sibling j
//    (permutation[j] = -1) when, by the triangle inequality, subtree j cannot
//    contain anything within the current k-th distance.
// 4) Surviving children that may still hold a closer element are pushed onto
//    nodeQueue, keyed by the query's distance to their pivot.
// NOTE(review): std::random_shuffle was deprecated in C++14 and removed in
// C++17, so this fails to build on standard libraries that drop it.
// std::shuffle needs an explicit URBG, which would change the visiting order.
502  void nearestK(const GNAT& gnat, const _T &data, std::size_t k,
503  NearQueue& nbh, NodeQueue& nodeQueue, bool& isPivot) const
504  {
505  for (unsigned int i=0; i<data_.size(); ++i)
506  if (!gnat.isRemoved(data_[i]))
507  {
508  if (insertNeighborK(nbh, k, data_[i], data, gnat.distFun_(data, data_[i])))
509  isPivot = false;
510  }
511  if (children_.size() > 0)
512  {
513  double dist;
514  Node* child;
515  std::vector<double> distToPivot(children_.size());
516  std::vector<int> permutation(children_.size());
517 
518  for (unsigned int i=0; i<permutation.size(); ++i)
519  permutation[i] = i;
520  std::random_shuffle(permutation.begin(), permutation.end());
521 
522  for (unsigned int i=0; i<children_.size(); ++i)
523  if (permutation[i] >= 0)
524  {
525  child = children_[permutation[i]];
// distToPivot is indexed by child index (permutation[i]), unlike nearestR
526  distToPivot[permutation[i]] = gnat.distFun_(data, child->pivot_);
527  if (insertNeighborK(nbh, k, child->pivot_, data, distToPivot[permutation[i]]))
528  isPivot = true;
529  if (nbh.size()==k)
530  {
531  dist = nbh.top().second; // note difference with nearestR
532  for (unsigned int j=0; j<children_.size(); ++j)
533  if (permutation[j] >=0 && i != j &&
534  (distToPivot[permutation[i]] - dist > child->maxRange_[permutation[j]] ||
535  distToPivot[permutation[i]] + dist < child->minRange_[permutation[j]]))
536  permutation[j] = -1;
537  }
538  }
539 
// nbh is non-empty here: at least one pivot was offered with k >= 1
540  dist = nbh.top().second;
541  for (unsigned int i=0; i<children_.size(); ++i)
542  if (permutation[i] >= 0)
543  {
544  child = children_[permutation[i]];
545  if (nbh.size()<k ||
546  (distToPivot[permutation[i]] - dist <= child->maxRadius_ &&
547  distToPivot[permutation[i]] + dist >= child->minRadius_))
548  nodeQueue.push(std::make_pair(child, distToPivot[permutation[i]]));
549  }
550  }
551  }
553  void insertNeighborR(NearQueue& nbh, double r, const _T& data, double dist) const
554  {
555  if (dist <= r)
556  nbh.push(std::make_pair(&data, dist));
557  }
// Examine this node during a range search with radius r.
// 1) Every non-removed bucket element within r is added to nbh.
// 2) Child pivots are visited in random order; those within r are added
//    (pivots are not checked with isRemoved).
// 3) Each visited child's range tables prune sibling subtrees that, by the
//    triangle inequality, cannot contain anything within r.
// 4) Surviving children whose radius interval can intersect the search ball
//    are pushed onto nodeQueue.
// Unlike nearestK, distToPivot is indexed by position i in the shuffled
// order, not by child index; each loop below uses it consistently that way.
// NOTE(review): std::random_shuffle was deprecated in C++14 and removed in
// C++17; see the same note on nearestK.
561  void nearestR(const GNAT& gnat, const _T &data, double r, NearQueue& nbh, NodeQueue& nodeQueue) const
562  {
563  double dist = r; //note difference with nearestK
564 
565  for (unsigned int i=0; i<data_.size(); ++i)
566  if (!gnat.isRemoved(data_[i]))
567  insertNeighborR(nbh, r, data_[i], gnat.distFun_(data, data_[i]));
568  if (children_.size() > 0)
569  {
570  Node* child;
571  std::vector<double> distToPivot(children_.size());
572  std::vector<int> permutation(children_.size());
573 
574  for (unsigned int i=0; i<permutation.size(); ++i)
575  permutation[i] = i;
576  std::random_shuffle(permutation.begin(), permutation.end());
577 
578  for (unsigned int i=0; i<children_.size(); ++i)
579  if (permutation[i] >= 0)
580  {
581  child = children_[permutation[i]];
582  distToPivot[i] = gnat.distFun_(data, child->pivot_);
583  insertNeighborR(nbh, r, child->pivot_, distToPivot[i]);
584  for (unsigned int j=0; j<children_.size(); ++j)
585  if (permutation[j] >=0 && i != j &&
586  (distToPivot[i] - dist > child->maxRange_[permutation[j]] ||
587  distToPivot[i] + dist < child->minRange_[permutation[j]]))
588  permutation[j] = -1;
589  }
590 
591  for (unsigned int i=0; i<children_.size(); ++i)
592  if (permutation[i] >= 0)
593  {
594  child = children_[permutation[i]];
595  if (distToPivot[i] - dist <= child->maxRadius_ &&
596  distToPivot[i] + dist >= child->minRadius_)
597  nodeQueue.push(std::make_pair(child, distToPivot[i]));
598  }
599  }
600  }
601 
602  void list(const GNAT& gnat, std::vector<_T> &data) const
603  {
604  if (!gnat.isRemoved(pivot_))
605  data.push_back(pivot_);
606  for (unsigned int i=0; i<data_.size(); ++i)
607  if(!gnat.isRemoved(data_[i]))
608  data.push_back(data_[i]);
609  for (unsigned int i=0; i<children_.size(); ++i)
610  children_[i]->list(gnat, data);
611  }
612 
613  friend std::ostream& operator<<(std::ostream& out, const Node& node)
614  {
615  out << "\ndegree:\t" << node.degree_;
616  out << "\nminRadius:\t" << node.minRadius_;
617  out << "\nmaxRadius:\t" << node.maxRadius_;
618  out << "\nminRange:\t";
619  for (unsigned int i=0; i<node.minRange_.size(); ++i)
620  out << node.minRange_[i] << '\t';
621  out << "\nmaxRange: ";
622  for (unsigned int i=0; i<node.maxRange_.size(); ++i)
623  out << node.maxRange_[i] << '\t';
624  out << "\npivot:\t" << node.pivot_;
625  out << "\ndata: ";
626  for (unsigned int i=0; i<node.data_.size(); ++i)
627  out << node.data_[i] << '\t';
628  out << "\nthis:\t" << &node;
629  out << "\nchildren:\n";
630  for (unsigned int i=0; i<node.children_.size(); ++i)
631  out << node.children_[i] << '\t';
632  out << '\n';
633  for (unsigned int i=0; i<node.children_.size(); ++i)
634  out << *node.children_[i] << '\n';
635  return out;
636  }
637 
// Number of pivots requested when this node splits. Overwritten with the
// number of pivots actually found once it does.
639  unsigned int degree_;
// Element representing this node. It is reported by list() but never stored
// in this node's data_.
641  const _T pivot_;
// Smallest / largest distance from pivot_ to an element routed into this node
// (see updateRadius). Both are 0 for a child that received only its pivot.
643  double minRadius_;
645  double maxRadius_;
// minRange_[j] / maxRange_[j]: smallest / largest distance from pivot_ to an
// element routed into sibling subtree j (see updateRange).
648  std::vector<double> minRange_;
651  std::vector<double> maxRange_;
// Bucket of elements stored at this node; released when the node splits.
654  std::vector<_T> data_;
// Child nodes, owned by this node (deleted in ~Node).
657  std::vector<Node*> children_;
658  };
659 
// NOTE(review): tree_ and pivotSelector_ are used throughout the class but are
// not among the visible declarations. This rendering skips several source
// line numbers here (e.g. 659-662), which presumably held them.
// Degree given to a newly created root node.
663  unsigned int degree_;
// Bounds applied to child degrees in Node::split.
668  unsigned int minDegree_;
673  unsigned int maxDegree_;
// Leaf capacity checked by Node::needToSplit.
676  unsigned int maxNumPtsPerLeaf_;
// Number of stored elements not marked for removal.
678  std::size_t size_;
// When an insertion that needs a split finds size_ >= rebuildSize_, the whole
// structure is rebuilt and this threshold doubles. It is numeric_limits::max()
// when rebalancing is disabled.
681  std::size_t rebuildSize_;
// remove() rebuilds once removed_ holds this many elements.
685  std::size_t removedCacheSize_;
// Addresses of stored elements that are marked as removed but still
// physically present in the tree.
689  boost::unordered_set<const _T*> removed_;
690  };
691 
692 }
693 
694 #endif