SHOGUN  6.0.0
CoverTreeKNNSolver.cpp
Go to the documentation of this file.
1 /* This software is distributed under BSD 3-clause license (see LICENSE file).
2  *
3  * Copyright (c) 2012-2013 Sergey Lisitsyn
4  */
5 
8 
9 using namespace shogun;
10 
11 CCoverTreeKNNSolver::CCoverTreeKNNSolver(const int32_t k, const float64_t q, const int32_t num_classes, const int32_t min_label, const SGVector<int32_t> train_labels):
12 CKNNSolver(k, q, num_classes, min_label, train_labels) { /* nothing to do */ }
13 
14 CMulticlassLabels* CCoverTreeKNNSolver::classify_objects(CDistance* knn_distance, const int32_t num_lab, SGVector<int32_t>& train_lab, SGVector<float64_t>& classes) const
15 {
16  CMulticlassLabels* output=new CMulticlassLabels(num_lab);
17 
18  // m_q != 1.0 not supported with cover tree because the neighbors
19  // are not retrieved in increasing order of distance to the query
20  if ( m_q != 1.0 )
21  SG_INFO("q != 1.0 not supported with cover tree, using q = 1\n")
22 
23  // From the sets of features (lhs and rhs) stored in distance,
24  // build arrays of cover tree points
25  v_array< CJLCoverTreePoint > set_of_points =
26  parse_points(knn_distance, FC_LHS);
27  v_array< CJLCoverTreePoint > set_of_queries =
28  parse_points(knn_distance, FC_RHS);
29 
30  // Build the cover trees, one for the test vectors (rhs features)
31  // and another for the training vectors (lhs features)
32  CFeatures* r = knn_distance->replace_rhs( knn_distance->get_lhs() );
33  node< CJLCoverTreePoint > top = batch_create(set_of_points);
34  CFeatures* l = knn_distance->replace_lhs(r);
35  knn_distance->replace_rhs(r);
36  node< CJLCoverTreePoint > top_query = batch_create(set_of_queries);
37 
38  // Get the k nearest neighbors to all the test vectors (batch method)
39  knn_distance->replace_lhs(l);
41  k_nearest_neighbor(top, top_query, res, m_k);
42 
43 if (io->get_loglevel()<= MSG_DEBUG)
44 {
45  SG_DEBUG("\nJL Results:\n")
46  for ( int32_t i = 0 ; i < res.index ; ++i )
47  {
48  for ( int32_t j = 0 ; j < res[i].index ; ++j )
49  {
50  SG_DEBUG("%d ", res[i][j].m_index);
51  }
52  SG_DEBUG("\n");
53  }
54  SG_DEBUG("\n")
55 }
56 
57  for ( index_t i = 0 ; i < res.index ; ++i )
58  {
59  // Translate from indices to labels of the nearest neighbors
60  for ( index_t j = 0; j < m_k; ++j )
61  // The first index in res[i] points to the test vector
62  train_lab[j] = m_train_labels.vector[ res[i][j+1].m_index ];
63 
64  // Get the index of the 'nearest' class
65  index_t out_idx = choose_class(classes.vector, train_lab.vector);
66  output->set_label(res[i][0].m_index, out_idx+m_min_label);
67  }
68 
69 
70  return output;
71 }
72 
74 {
75  SGVector<int32_t> output(m_k*num_lab);
76 
77  //allocation for distances to nearest neighbors
78  SGVector<float64_t> dists(m_k);
79 
80  // From the sets of features (lhs and rhs) stored in distance,
81  // build arrays of cover tree points
82  v_array< CJLCoverTreePoint > set_of_points =
83  parse_points(knn_distance, FC_LHS);
84  v_array< CJLCoverTreePoint > set_of_queries =
85  parse_points(knn_distance, FC_RHS);
86 
87  // Build the cover trees, one for the test vectors (rhs features)
88  // and another for the training vectors (lhs features)
89  CFeatures* r = knn_distance->replace_rhs( knn_distance->get_lhs() );
90  node< CJLCoverTreePoint > top = batch_create(set_of_points);
91  CFeatures* l = knn_distance->replace_lhs(r);
92  knn_distance->replace_rhs(r);
93  node< CJLCoverTreePoint > top_query = batch_create(set_of_queries);
94 
95  // Get the k nearest neighbors to all the test vectors (batch method)
96  knn_distance->replace_lhs(l);
98  k_nearest_neighbor(top, top_query, res, m_k);
99 
100  for ( index_t i = 0 ; i < res.index ; ++i )
101  {
102  // Handle the fact that cover tree doesn't return neighbors
103  // ordered by distance
104 
105  for ( index_t j = 0 ; j < m_k ; ++j )
106  {
107  // The first index in res[i] points to the test vector
108  dists[j] = knn_distance->distance(res[i][j+1].m_index,
109  res[i][0].m_index);
110  train_lab[j] = m_train_labels.vector[
111  res[i][j+1].m_index ];
112  }
113 
114  // Now we get the indices to the neighbors sorted by distance
115  CMath::qsort_index(dists.vector, train_lab.vector, m_k);
116 
117  choose_class_for_multiple_k(output.vector+res[i][0].m_index, classes.vector,
118  train_lab.vector, num_lab);
119  }
120 
121  return output;
122 }
#define SG_INFO(...)
Definition: SGIO.h:117
int32_t m_k
the k parameter in KNN
Definition: KNNSolver.h:94
Class Distance, a base class for all the distances used in the Shogun toolbox.
Definition: Distance.h:87
SGVector< int32_t > m_train_labels
Definition: KNNSolver.h:106
int32_t index_t
Definition: common.h:72
CFeatures * get_lhs()
Definition: Distance.h:218
static void qsort_index(T1 *output, T2 *index, uint32_t size)
Definition: Math.h:2223
node< P > batch_create(v_array< P > points)
Definition: JLCoverTree.h:299
int32_t choose_class(float64_t *classes, const int32_t *train_lab) const
Definition: KNNSolver.cpp:35
virtual CFeatures * replace_lhs(CFeatures *lhs)
Definition: Distance.cpp:165
virtual SGVector< int32_t > classify_objects_k(CDistance *d, const int32_t num_lab, SGVector< int32_t > &train_lab, SGVector< int32_t > &classes) const
Class v_array taken directly from JL's implementation.
bool set_label(int32_t idx, float64_t label)
v_array< CJLCoverTreePoint > parse_points(CDistance *distance, EFeaturesContainer fc)
Multiclass Labels for multi-class classification.
EMessageType get_loglevel() const
Definition: SGIO.cpp:315
void choose_class_for_multiple_k(int32_t *output, int32_t *classes, const int32_t *train_lab, const int32_t step) const
Definition: KNNSolver.cpp:62
double float64_t
Definition: common.h:60
virtual CFeatures * replace_rhs(CFeatures *rhs)
Definition: Distance.cpp:147
int32_t m_min_label
smallest label, i.e. -1
Definition: KNNSolver.h:103
virtual float64_t distance(int32_t idx_a, int32_t idx_b)
Definition: Distance.cpp:183
#define SG_DEBUG(...)
Definition: SGIO.h:106
All classes and functions are contained in the shogun namespace.
Definition: class_list.h:18
The class Features is the base class of all feature objects.
Definition: Features.h:68
void k_nearest_neighbor(const node< P > &top_node, const node< P > &query, v_array< v_array< P > > &results, int k)
Definition: JLCoverTree.h:832
virtual CMulticlassLabels * classify_objects(CDistance *d, const int32_t num_lab, SGVector< int32_t > &train_lab, SGVector< float64_t > &classes) const
float64_t m_q
parameter q of rank weighting
Definition: KNNSolver.h:97

SHOGUN Machine Learning Toolbox - Documentation