SHOGUN  6.0.0
LSHKNNSolver.cpp
Go to the documentation of this file.
1 /* This software is distributed under BSD 3-clause license (see LICENSE file).
2  *
3  * Copyright (c) 2012-2013 Sergey Lisitsyn
4  */
5 
9 #include <shogun/lib/Signal.h>
10 
11 using namespace shogun;
12 using namespace Eigen;
13 
14 #ifdef HAVE_CXX11
15 #include <shogun/lib/external/falconn/lsh_nn_table.h>
16 
17 CLSHKNNSolver::CLSHKNNSolver(const int32_t k, const float64_t q, const int32_t num_classes, const int32_t min_label, const SGVector<int32_t> train_labels, const int32_t lsh_l, const int32_t lsh_t):
18 CKNNSolver(k, q, num_classes, min_label, train_labels)
19 {
20  init();
21 
22  m_lsh_l=lsh_l;
23  m_lsh_t=lsh_t;
24 }
25 
26 CMulticlassLabels* CLSHKNNSolver::classify_objects(CDistance* knn_distance, const int32_t num_lab, SGVector<int32_t>& train_lab, SGVector<float64_t>& classes) const
27 {
28  CMulticlassLabels* output=new CMulticlassLabels(num_lab);
29  CDenseFeatures<float64_t>* features = dynamic_cast<CDenseFeatures<float64_t>*>(knn_distance->get_lhs());
30  std::vector<falconn::DenseVector<double>> feats;
31  for(int32_t i=0; i < features->get_num_vectors(); i++)
32  {
33  int32_t len;
34  bool free;
35  float64_t* vec = features->get_feature_vector(i, len, free);
36  falconn::DenseVector<double> temp = Map<VectorXd> (vec, len);
37  feats.push_back(temp);
38  }
39 
40  falconn::LSHConstructionParameters params
41  = falconn::get_default_parameters<falconn::DenseVector<double>>(features->get_num_vectors(),
42  features->get_num_features(),
43  falconn::DistanceFunction::EuclideanSquared,
44  true);
45  SG_UNREF(features);
46  if (m_lsh_l && m_lsh_t)
47  params.l = m_lsh_l;
48 
49  auto lsh_table = falconn::construct_table<falconn::DenseVector<double>>(feats, params);
50  if (m_lsh_t)
51  lsh_table->set_num_probes(m_lsh_t);
52 
53  CDenseFeatures<float64_t>* query_features = dynamic_cast<CDenseFeatures<float64_t>*>(knn_distance->get_rhs());
54  std::vector<falconn::DenseVector<double>> query_feats;
55 
56  SGMatrix<index_t> NN (m_k, query_features->get_num_vectors());
57  for(index_t i=0; i < query_features->get_num_vectors(); i++)
58  {
59  int32_t len;
60  bool free;
61  float64_t* vec = query_features->get_feature_vector(i, len, free);
62  falconn::DenseVector<double> temp = Map<VectorXd> (vec, len);
63  auto indices = new std::vector<int32_t> ();
64  lsh_table->find_k_nearest_neighbors(temp, (int_fast64_t)m_k, indices);
65  sg_memcpy(NN.get_column_vector(i), indices->data(), sizeof(int32_t)*m_k);
66  delete indices;
67  }
68 
69  for (index_t i=0; i<num_lab && (!CSignal::cancel_computations()); i++)
70  {
71  //write the labels of the k nearest neighbors from theirs indices
72  for (index_t j=0; j<m_k; j++)
73  train_lab[j] = m_train_labels[ NN(j,i) ];
74 
75  //get the index of the 'nearest' class
76  index_t out_idx = choose_class(classes.vector, train_lab.vector);
77  //write the label of 'nearest' in the output
78  output->set_label(i, out_idx + m_min_label);
79  }
80  SG_UNREF(query_features);
81 
82  return output;
83 }
84 
85 SGVector<int32_t> CLSHKNNSolver::classify_objects_k(CDistance* d, const int32_t num_lab, SGVector<int32_t>& train_lab, SGVector<int32_t>& classes) const
86 {
88  return 0;
89 }
90 #endif
Class Distance, a base class for all the distances used in the Shogun toolbox.
Definition: Distance.h:87
int32_t index_t
Definition: common.h:72
CFeatures * get_lhs()
Definition: Distance.h:218
Definition: SGMatrix.h:24
#define SG_NOTIMPLEMENTED
Definition: SGIO.h:138
CFeatures * get_rhs()
Definition: Distance.h:224
ST * get_feature_vector(int32_t num, int32_t &len, bool &dofree)
bool set_label(int32_t idx, float64_t label)
Multiclass Labels for multi-class classification.
double float64_t
Definition: common.h:60
static bool cancel_computations()
Definition: Signal.h:111
CDotFeatures * features
int32_t get_num_features() const
#define SG_UNREF(x)
Definition: SGObject.h:53
virtual int32_t get_num_vectors() const
All classes and functions are contained in the shogun namespace.
Definition: class_list.h:18

SHOGUN Machine Learning Toolbox - Documentation