Tapkee
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
neighbors.hpp
Go to the documentation of this file.
1 /* This software is distributed under BSD 3-clause license (see LICENSE file).
2  *
3  * Copyright (c) 2012-2013 Sergey Lisitsyn, Fernando J. Iglesias Garcia
4  */
5 
6 #ifndef TAPKEE_NEIGHBORS_H_
7 #define TAPKEE_NEIGHBORS_H_
8 
9 /* Tapkee includes */
10 #include <tapkee/defines.hpp>
11 #ifdef TAPKEE_USE_LGPL_COVERTREE
13 #endif
16 /* End of Tapkee includes */
17 
18 #include <vector>
19 #include <utility>
20 #include <algorithm>
21 
22 namespace tapkee
23 {
24 namespace tapkee_internal
25 {
26 
//! Comparison functor for distance records: a record `l` is ordered before
//! a record `r` when `l.second < r.second`; the `first` member is ignored.
//! In this file the brute-force search builds its records as
//! std::pair<RandomAccessIterator, ScalarType>, so `second` is the distance there.
//! NOTE(review): the line that declares this functor's name (listing line 28)
//! is absent from this listing -- restore it from the upstream header.
27 template <class DistanceRecord>
29 {
30  inline bool operator()(const DistanceRecord& l, const DistanceRecord& r) const
31  {
32  return (l.second < r.second);
33  }
34 };
35 
//! Empty tag type; KernelDistance exposes it as its nested `type` typedef.
36 struct KernelType
37 {
38 };
39 
40 template <class RandomAccessIterator, class Callback>
42 {
43  KernelDistance(const Callback& cb) : callback(cb) { }
44  inline ScalarType operator()(const RandomAccessIterator& l, const RandomAccessIterator& r)
45  {
46  return callback.kernel(*l,*r);
47  }
48  inline ScalarType distance(const RandomAccessIterator& l, const RandomAccessIterator& r)
49  {
50  return sqrt(callback.kernel(*l,*l) - 2*callback.kernel(*l,*r) + callback.kernel(*r,*r));
51  }
52  typedef KernelType type;
53  Callback callback;
54 };
55 
//! Empty tag type; PlainDistance exposes it as its nested `type` typedef.
struct DistanceType
{
};
59 
60 template <class RandomAccessIterator, class Callback>
62 {
63  PlainDistance(const Callback& cb) : callback(cb) { }
64  inline ScalarType operator()(const RandomAccessIterator& l, const RandomAccessIterator& r)
65  {
66  return callback.distance(*l,*r);
67  }
68  inline ScalarType distance(const RandomAccessIterator& l, const RandomAccessIterator& r)
69  {
70  return callback.distance(*l,*r);
71  }
72  typedef DistanceType type;
73  Callback callback;
74 };
75 
76 #ifdef TAPKEE_USE_LGPL_COVERTREE
77 template <class RandomAccessIterator, class Callback>
78 Neighbors find_neighbors_covertree_impl(RandomAccessIterator begin, RandomAccessIterator end,
79  Callback callback, IndexType k)
80 {
81  timed_context context("Covertree-based neighbors search");
82 
83  typedef CoverTreePoint<RandomAccessIterator> TreePoint;
84  v_array<TreePoint> points;
85  for (RandomAccessIterator iter=begin; iter!=end; ++iter)
86  push(points, TreePoint(iter, callback(iter,iter)));
87 
88  node<TreePoint> ct = batch_create(callback, points);
89 
91  ++k; // because one of the neighbors will be the actual query point
92  k_nearest_neighbor(callback,ct,ct,res,k);
93 
94  Neighbors neighbors;
95  neighbors.resize(end-begin);
96  assert(end-begin==res.index);
97  for (int i=0; i<res.index; ++i)
98  {
99  LocalNeighbors local_neighbors;
100  local_neighbors.reserve(k);
101 
102  for (IndexType j=1; j<=k; ++j) // j=0 is the query point
103  {
104  // The actual query point is found as a neighbor, just ignore it
105  if (res[i][j].iter_-begin==res[i][0].iter_-begin)
106  continue;
107  local_neighbors.push_back(res[i][j].iter_-begin);
108  }
109  neighbors[res[i][0].iter_-begin] = local_neighbors;
110  free(res[i].elements);
111  };
112  free(res.elements);
113  free_children(ct);
114  free(points.elements);
115  return neighbors;
116 }
117 #endif
118 
119 template <class RandomAccessIterator, class Callback>
120 Neighbors find_neighbors_bruteforce_impl(const RandomAccessIterator& begin, const RandomAccessIterator& end,
121  Callback callback, IndexType k)
122 {
123  timed_context context("Distance sorting based neighbors search");
124  typedef std::pair<RandomAccessIterator, ScalarType> DistanceRecord;
125  typedef std::vector<DistanceRecord> Distances;
126 
127  Neighbors neighbors;
128  neighbors.reserve(end-begin);
129  for (RandomAccessIterator iter=begin; iter!=end; ++iter)
130  {
131  Distances distances;
132  for (RandomAccessIterator around_iter=begin; around_iter!=end; ++around_iter)
133  distances.push_back(std::make_pair(around_iter, callback.distance(iter,around_iter)));
134 
135  std::nth_element(distances.begin(),distances.begin()+k+1,distances.end(),
137 
138  LocalNeighbors local_neighbors;
139  local_neighbors.reserve(k);
140  for (typename Distances::const_iterator neighbors_iter=distances.begin();
141  neighbors_iter!=distances.begin()+k+1; ++neighbors_iter)
142  {
143  if (neighbors_iter->first != iter)
144  local_neighbors.push_back(neighbors_iter->first - begin);
145  }
146  neighbors.push_back(local_neighbors);
147  }
148  return neighbors;
149 }
150 
151 template <class RandomAccessIterator, class Callback>
152 Neighbors find_neighbors_vptree_impl(const RandomAccessIterator& begin, const RandomAccessIterator& end,
153  Callback callback, IndexType k)
154 {
155  timed_context context("VP-Tree based neighbors search");
156 
157  Neighbors neighbors;
158  neighbors.reserve(end-begin);
159 
160  VantagePointTree<RandomAccessIterator,Callback> tree(begin,end,callback);
161 
162  for (RandomAccessIterator i=begin; i!=end; ++i)
163  {
164  LocalNeighbors local_neighbors = tree.search(i,k+1);
165  std::remove(local_neighbors.begin(),local_neighbors.end(),i-begin);
166  neighbors.push_back(local_neighbors);
167  }
168 
169  return neighbors;
170 }
171 
172 template <class RandomAccessIterator, class Callback>
173 Neighbors find_neighbors(NeighborsMethod method, const RandomAccessIterator& begin,
174  const RandomAccessIterator& end, const Callback& callback,
175  IndexType k, bool check_connectivity)
176 {
177  if (k > static_cast<IndexType>(end-begin-1))
178  {
179  LoggingSingleton::instance().message_warning("Number of neighbors is greater than number of objects to embed. "
180  "Using greatest possible number of neighbors.");
181  k = static_cast<IndexType>(end-begin-1);
182  }
183  LoggingSingleton::instance().message_info("Using the " + get_neighbors_method_name(method) + " neighbors computation method.");
184  Neighbors neighbors;
185  switch (method)
186  {
187  case Brute: neighbors = find_neighbors_bruteforce_impl(begin,end,callback,k); break;
188  case VpTree: neighbors = find_neighbors_vptree_impl(begin,end,callback,k); break;
189 #ifdef TAPKEE_USE_LGPL_COVERTREE
190  case CoverTree: neighbors = find_neighbors_covertree_impl(begin,end,callback,k); break;
191 #endif
192  default: break;
193  }
194 
195  if (check_connectivity)
196  {
197  if (!is_connected(begin,end,neighbors))
198  LoggingSingleton::instance().message_warning("The neighborhood graph is not connected.");
199  }
200  return neighbors;
201 }
202 
203 } // End of namespace tapkee_internal
204 } // End of namespace tapkee
205 
206 #endif
ScalarType distance(const RandomAccessIterator &l, const RandomAccessIterator &r)
Definition: neighbors.hpp:68
Neighbors find_neighbors(NeighborsMethod method, const RandomAccessIterator &begin, const RandomAccessIterator &end, const Callback &callback, IndexType k, bool check_connectivity)
Definition: neighbors.hpp:173
void k_nearest_neighbor(DistanceCallback &dcb, const node< P > &top_node, const node< P > &query, v_array< v_array< P > > &results, int k)
Definition: covertree.hpp:828
ScalarType operator()(const RandomAccessIterator &l, const RandomAccessIterator &r)
Definition: neighbors.hpp:44
node< P > batch_create(DistanceCallback &dcb, v_array< P > points)
Definition: covertree.hpp:299
bool operator()(const DistanceRecord &l, const DistanceRecord &r) const
Definition: neighbors.hpp:30
ScalarType distance(const RandomAccessIterator &l, const RandomAccessIterator &r)
Definition: neighbors.hpp:48
Class v_array taken directly from JL's implementation.
Neighbors find_neighbors_covertree_impl(RandomAccessIterator begin, RandomAccessIterator end, Callback callback, IndexType k)
Definition: neighbors.hpp:78
double ScalarType
default scalar value (can be overridden with the TAPKEE_CUSTOM_INTERNAL_NUMTYPE define) ...
Definition: types.hpp:15
std::string get_neighbors_method_name(NeighborsMethod m)
Definition: naming.hpp:42
void free_children(const node< P > &n)
Definition: covertree.hpp:69
TAPKEE_INTERNAL_VECTOR< tapkee::IndexType > LocalNeighbors
Definition: synonyms.hpp:39
int IndexType
indexing type (non-overridable) set to int for compatibility with OpenMP 2.0
Definition: types.hpp:19
ScalarType operator()(const RandomAccessIterator &l, const RandomAccessIterator &r)
Definition: neighbors.hpp:64
void push(v_array< T > &v, const T &new_ele)
void message_info(const std::string &msg)
Definition: logging.hpp:113
Covertree-based method with approximate $O(\log N)$ time complexity. Recommended to be used as a default method...
TAPKEE_INTERNAL_VECTOR< tapkee::tapkee_internal::LocalNeighbors > Neighbors
Definition: synonyms.hpp:40
Neighbors find_neighbors_vptree_impl(const RandomAccessIterator &begin, const RandomAccessIterator &end, Callback callback, IndexType k)
Definition: neighbors.hpp:152
bool is_connected(RandomAccessIterator begin, RandomAccessIterator end, const Neighbors &neighbors)
Definition: connected.hpp:18
void message_warning(const std::string &msg)
Definition: logging.hpp:114
static LoggingSingleton & instance()
Definition: logging.hpp:100
std::vector< IndexType > search(const RandomAccessIterator &target, int k)
Class Point to use with John Langford's CoverTree. This class must have some associated functions def...
Brute-force method with no less than $O(N^2)$ time complexity. Recommended to be used only for debugging purposes...
NeighborsMethod
Neighbors computation methods.
Neighbors find_neighbors_bruteforce_impl(const RandomAccessIterator &begin, const RandomAccessIterator &end, Callback callback, IndexType k)
Definition: neighbors.hpp:120