Tapkee
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
neighbors/vptree.hpp
Go to the documentation of this file.
1 /* This software is distributed under BSD 3-clause license (see LICENSE file).
2  *
3  * Copyright (c) 2012-2013 Laurens van der Maaten, Sergey Lisitsyn
4  */
5 
6 #ifndef TAPKEE_VPTREE_H_
7 #define TAPKEE_VPTREE_H_
8 
9 /* Tapkee includes */
10 #include <tapkee/defines.hpp>
11 /* End of Tapkee includes */
12 
13 #include <vector>
14 #include <queue>
15 #include <algorithm>
16 #include <limits>
17 
18 namespace tapkee
19 {
20 namespace tapkee_internal
21 {
22 
/**
 * Dispatches the "is a closer to item than b?" comparison on a tag type.
 * Type is expected to be one of the tags KernelType or DistanceType (the
 * two specializations below). The primary template is declared but never
 * defined, so any other tag is rejected at compile time.
 */
template<class Type, class RandomAccessIterator, class DistanceCallback>
struct compare_impl;

/** Tag selecting the kernel-based comparison. */
struct KernelType;

/**
 * Kernel specialization.
 *
 * Assumes callback(x,y) returns a kernel value k(x,y), as the KernelType tag
 * suggests. The kernel-induced squared distance is
 *   ||item - a||^2 = k(item,item) - 2 k(item,a) + k(a,a).
 * The k(item,item) term is identical on both sides of the comparison, so it
 * is dropped. Only -2 k(item,a) + k(a,a) is compared against the same
 * expression for b.
 */
template<class RandomAccessIterator, class DistanceCallback>
struct compare_impl<KernelType,RandomAccessIterator,DistanceCallback>
{
	inline bool operator()(DistanceCallback& callback, const RandomAccessIterator& item,
	                       const RandomAccessIterator& a, const RandomAccessIterator& b)
	{
		return (-2*callback(item,a) + callback(a,a)) < (-2*callback(item,b) + callback(b,b));
	}
};

/** Tag selecting the plain distance comparison. */
struct DistanceType;

/** Distance specialization: callback(x,y) is compared directly as a distance. */
template<class RandomAccessIterator, class DistanceCallback>
struct compare_impl<DistanceType,RandomAccessIterator,DistanceCallback>
{
	inline bool operator()(DistanceCallback& callback, const RandomAccessIterator& item,
	                       const RandomAccessIterator& a, const RandomAccessIterator& b)
	{
		return callback(item,a) < callback(item,b);
	}
};

/**
 * Ordering predicate: returns true when a is closer to the fixed reference
 * point item than b is.
 *
 * The comparison strategy is chosen from the nested typedef
 * DistanceCallback::type, which must name KernelType or DistanceType.
 *
 * @param c  callback (copied) used to evaluate kernel values or distances
 * @param i  iterator to the reference point all comparisons are made against
 */
template<class RandomAccessIterator, class DistanceCallback>
struct DistanceComparator
{
	DistanceCallback callback;
	const RandomAccessIterator item;
	DistanceComparator(const DistanceCallback& c, const RandomAccessIterator& i) :
		callback(c), item(i) {}
	inline bool operator()(const RandomAccessIterator& a, const RandomAccessIterator& b)
	{
		return compare_impl<typename DistanceCallback::type,RandomAccessIterator,DistanceCallback>()
			(callback,item,a,b);
	}
};
63 
64 template<class RandomAccessIterator, class DistanceCallback>
66 {
67 public:
68 
69  // Default constructor
70  VantagePointTree(RandomAccessIterator b, RandomAccessIterator e, DistanceCallback c) :
71  begin(b), items(), callback(c), tau(0.0), root(0)
72  {
73  items.reserve(e-b);
74  for (RandomAccessIterator i=b; i!=e; ++i)
75  items.push_back(i);
76  root = buildFromPoints(0, items.size());
77  }
78 
79  // Destructor
81  {
82  delete root;
83  }
84 
85  // Function that uses the tree to find the k nearest neighbors of target
86  std::vector<IndexType> search(const RandomAccessIterator& target, int k)
87  {
88  std::vector<IndexType> results;
89  // Use a priority queue to store intermediate results on
90  std::priority_queue<HeapItem> heap;
91 
92  // Variable that tracks the distance to the farthest point in our results
93  tau = std::numeric_limits<double>::max();
94 
95  // Perform the searcg
96  search(root, target, k, heap);
97 
98  // Gather final results
99  results.reserve(k);
100  while(!heap.empty()) {
101  results.push_back(items[heap.top().index]-begin);
102  heap.pop();
103  }
104  return results;
105  }
106 
107 private:
108 
111 
112  RandomAccessIterator begin;
113  std::vector<RandomAccessIterator> items;
114  DistanceCallback callback;
115  double tau;
116 
117  struct Node
118  {
119  int index;
120  double threshold;
123 
124  Node() :
125  index(0), threshold(0.),
126  left(0), right(0)
127  {
128  }
129 
130  ~Node()
131  {
132  delete left;
133  delete right;
134  }
135 
136  Node(const Node&);
137  Node& operator=(const Node&);
138 
139  }* root;
140 
141  struct HeapItem {
142  HeapItem(int i, double d) :
143  index(i), distance(d) {}
144  int index;
145  double distance;
146  bool operator<(const HeapItem& o) const {
147  return distance < o.distance;
148  }
149  };
150 
151 
152  Node* buildFromPoints(int lower, int upper)
153  {
154  if (upper == lower)
155  {
156  return NULL;
157  }
158 
159  Node* node = new Node();
160  node->index = lower;
161 
162  if (upper - lower > 1)
163  {
164  int i = (int) (tapkee::uniform_random() * (upper - lower - 1)) + lower;
165  std::swap(items[lower], items[i]);
166 
167  int median = (upper + lower) / 2;
168  std::nth_element(items.begin() + lower + 1, items.begin() + median, items.begin() + upper,
170 
171  node->threshold = callback.distance(items[lower], items[median]);
172  node->index = lower;
173  node->left = buildFromPoints(lower + 1, median);
174  node->right = buildFromPoints(median, upper);
175  }
176 
177  return node;
178  }
179 
180  void search(Node* node, const RandomAccessIterator& target, int k, std::priority_queue<HeapItem>& heap)
181  {
182  if (node == NULL)
183  return;
184 
185  double distance = callback.distance(items[node->index], target);
186 
187  if (distance < tau)
188  {
189  if (heap.size() == static_cast<size_t>(k))
190  heap.pop();
191 
192  heap.push(HeapItem(node->index, distance));
193 
194  if (heap.size() == static_cast<size_t>(k))
195  tau = heap.top().distance;
196  }
197 
198  if (node->left == NULL && node->right == NULL)
199  {
200  return;
201  }
202 
203  if (distance < node->threshold)
204  {
205  if ((distance - tau) <= node->threshold)
206  search(node->left, target, k, heap);
207 
208  if ((distance + tau) >= node->threshold)
209  search(node->right, target, k, heap);
210  }
211  else
212  {
213  if ((distance + tau) >= node->threshold)
214  search(node->right, target, k, heap);
215 
216  if ((distance - tau) <= node->threshold)
217  search(node->left, target, k, heap);
218  }
219  }
220 };
221 
222 }
223 }
224 #endif
ScalarType distance(Callback &cb, const CoverTreePoint< RandomAccessIterator > &l, const CoverTreePoint< RandomAccessIterator > &r, ScalarType upper_bound)
bool operator()(DistanceCallback &callback, const RandomAccessIterator &item, const RandomAccessIterator &a, const RandomAccessIterator &b)
bool operator()(DistanceCallback &callback, const RandomAccessIterator &item, const RandomAccessIterator &a, const RandomAccessIterator &b)
DistanceComparator(const DistanceCallback &c, const RandomAccessIterator &i)
ScalarType uniform_random()
Definition: random.hpp:30
void search(Node *node, const RandomAccessIterator &target, int k, std::priority_queue< HeapItem > &heap)
Node * buildFromPoints(int lower, int upper)
bool operator()(const RandomAccessIterator &a, const RandomAccessIterator &b)
VantagePointTree & operator=(const VantagePointTree &)
std::vector< RandomAccessIterator > items
std::vector< IndexType > search(const RandomAccessIterator &target, int k)
VantagePointTree(RandomAccessIterator b, RandomAccessIterator e, DistanceCallback c)
struct tapkee::tapkee_internal::VantagePointTree::Node * root