MLPACK  1.0.11
fastmks_rules.hpp
Go to the documentation of this file.
1 
22 #ifndef __MLPACK_METHODS_FASTMKS_FASTMKS_RULES_HPP
23 #define __MLPACK_METHODS_FASTMKS_FASTMKS_RULES_HPP
24 
25 #include <mlpack/core.hpp>
27 
28 #include "../neighbor_search/ns_traversal_info.hpp"
29 
30 namespace mlpack {
31 namespace fastmks {
32 
36 template<typename KernelType, typename TreeType>
38 {
39  public:
40  FastMKSRules(const arma::mat& referenceSet,
41  const arma::mat& querySet,
42  arma::Mat<size_t>& indices,
43  arma::mat& products,
44  KernelType& kernel);
45 
47  double BaseCase(const size_t queryIndex, const size_t referenceIndex);
48 
57  double Score(const size_t queryIndex, TreeType& referenceNode);
58 
67  double Score(TreeType& queryNode, TreeType& referenceNode);
68 
80  double Rescore(const size_t queryIndex,
81  TreeType& referenceNode,
82  const double oldScore) const;
83 
95  double Rescore(TreeType& queryNode,
96  TreeType& referenceNode,
97  const double oldScore) const;
98 
100  size_t BaseCases() const { return baseCases; }
102  size_t& BaseCases() { return baseCases; }
103 
105  size_t Scores() const { return scores; }
107  size_t& Scores() { return scores; }
108 
110 
111  const TraversalInfoType& TraversalInfo() const { return traversalInfo; }
112  TraversalInfoType& TraversalInfo() { return traversalInfo; }
113 
114  private:
116  const arma::mat& referenceSet;
118  const arma::mat& querySet;
119 
121  arma::Mat<size_t>& indices;
123  arma::mat& products;
124 
126  arma::vec queryKernels;
128  arma::vec referenceKernels;
129 
131  KernelType& kernel;
132 
138  double lastKernel;
139 
141  double CalculateBound(TreeType& queryNode) const;
142 
144  void InsertNeighbor(const size_t queryIndex,
145  const size_t pos,
146  const size_t neighbor,
147  const double distance);
148 
150  size_t baseCases;
152  size_t scores;
153 
154  TraversalInfoType traversalInfo;
155 };
156 
157 }; // namespace fastmks
158 }; // namespace mlpack
159 
160 // Include implementation.
161 #include "fastmks_rules_impl.hpp"
162 
163 #endif
double BaseCase(const size_t queryIndex, const size_t referenceIndex)
Compute the base case (kernel value) between two points.
const TraversalInfoType & TraversalInfo() const
size_t Scores() const
Get the number of times Score() was called.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: load.hpp:31
const arma::mat & querySet
The query dataset.
arma::vec queryKernels
Cached query set self-kernels (|| q || for each q).
size_t & BaseCases()
Modify the number of times BaseCase() was called.
double lastKernel
The last kernel evaluation resulting from BaseCase().
Traversal information for NeighborSearch.
size_t & Scores()
Modify the number of times Score() was called.
arma::vec referenceKernels
Cached reference set self-kernels (|| r || for each r).
double Score(const size_t queryIndex, TreeType &referenceNode)
Get the score for recursion order.
TraversalInfoType & TraversalInfo()
double Rescore(const size_t queryIndex, TreeType &referenceNode, const double oldScore) const
Re-evaluate the score for recursion order.
size_t baseCases
For benchmarking.
The base case and pruning rules for FastMKS (fast max-kernel search).
arma::Mat< size_t > & indices
The indices of the maximum kernel results.
size_t lastReferenceIndex
The last reference index BaseCase() was called with.
arma::mat & products
The maximum kernels.
neighbor::NeighborSearchTraversalInfo< TreeType > TraversalInfoType
FastMKSRules(const arma::mat &referenceSet, const arma::mat &querySet, arma::Mat< size_t > &indices, arma::mat &products, KernelType &kernel)
size_t lastQueryIndex
The last query index BaseCase() was called with.
const arma::mat & referenceSet
The reference dataset.
size_t BaseCases() const
Get the number of times BaseCase() was called.
void InsertNeighbor(const size_t queryIndex, const size_t pos, const size_t neighbor, const double distance)
Utility function to insert neighbor into list of results.
KernelType & kernel
The instantiated kernel.
double CalculateBound(TreeType &queryNode) const
Calculate the bound for a given query node.
size_t scores
For benchmarking.