00001
00022 #ifndef __MLPACK_CORE_TREE_COVER_TREE_COVER_TREE_HPP
00023 #define __MLPACK_CORE_TREE_COVER_TREE_COVER_TREE_HPP
00024
00025 #include <mlpack/core.hpp>
00026 #include <mlpack/core/metrics/lmetric.hpp>
00027 #include "first_point_is_root.hpp"
00028 #include "../statistic.hpp"
00029
00030 namespace mlpack {
00031 namespace tree {
00032
00100 template<typename MetricType = metric::LMetric<2, true>,
00101 typename RootPointPolicy = FirstPointIsRoot,
00102 typename StatisticType = EmptyStatistic>
00103 class CoverTree
00104 {
00105 public:
00106 typedef arma::mat Mat;
00107
00118 CoverTree(const arma::mat& dataset,
00119 const double base = 2.0,
00120 MetricType* metric = NULL);
00121
00131 CoverTree(const arma::mat& dataset,
00132 MetricType& metric,
00133 const double base = 2.0);
00134
00166 CoverTree(const arma::mat& dataset,
00167 const double base,
00168 const size_t pointIndex,
00169 const int scale,
00170 CoverTree* parent,
00171 const double parentDistance,
00172 arma::Col<size_t>& indices,
00173 arma::vec& distances,
00174 size_t nearSetSize,
00175 size_t& farSetSize,
00176 size_t& usedSetSize,
00177 MetricType& metric = NULL);
00178
00195 CoverTree(const arma::mat& dataset,
00196 const double base,
00197 const size_t pointIndex,
00198 const int scale,
00199 CoverTree* parent,
00200 const double parentDistance,
00201 const double furthestDescendantDistance,
00202 MetricType* metric = NULL);
00203
00210 CoverTree(const CoverTree& other);
00211
00215 ~CoverTree();
00216
00219 template<typename RuleType>
00220 class SingleTreeTraverser;
00221
00223 template<typename RuleType>
00224 class DualTreeTraverser;
00225
00227 const arma::mat& Dataset() const { return dataset; }
00228
00230 size_t Point() const { return point; }
00232 size_t Point(const size_t) const { return point; }
00233
00234 bool IsLeaf() const { return (children.size() == 0); }
00235 size_t NumPoints() const { return 1; }
00236
00238 const CoverTree& Child(const size_t index) const { return *children[index]; }
00240 CoverTree& Child(const size_t index) { return *children[index]; }
00241
00243 size_t NumChildren() const { return children.size(); }
00244
00246 const std::vector<CoverTree*>& Children() const { return children; }
00248 std::vector<CoverTree*>& Children() { return children; }
00249
00251 size_t NumDescendants() const;
00252
00254 size_t Descendant(const size_t index) const;
00255
00257 int Scale() const { return scale; }
00259 int& Scale() { return scale; }
00260
00262 double Base() const { return base; }
00264 double& Base() { return base; }
00265
00267 const StatisticType& Stat() const { return stat; }
00269 StatisticType& Stat() { return stat; }
00270
00272 double MinDistance(const CoverTree* other) const;
00273
00276 double MinDistance(const CoverTree* other, const double distance) const;
00277
00279 double MinDistance(const arma::vec& other) const;
00280
00283 double MinDistance(const arma::vec& other, const double distance) const;
00284
00286 double MaxDistance(const CoverTree* other) const;
00287
00290 double MaxDistance(const CoverTree* other, const double distance) const;
00291
00293 double MaxDistance(const arma::vec& other) const;
00294
00297 double MaxDistance(const arma::vec& other, const double distance) const;
00298
00300 math::Range RangeDistance(const CoverTree* other) const;
00301
00304 math::Range RangeDistance(const CoverTree* other, const double distance)
00305 const;
00306
00308 math::Range RangeDistance(const arma::vec& other) const;
00309
00312 math::Range RangeDistance(const arma::vec& other, const double distance)
00313 const;
00314
00316 static bool HasSelfChildren() { return true; }
00317
00319 CoverTree* Parent() const { return parent; }
00321 CoverTree*& Parent() { return parent; }
00322
00324 double ParentDistance() const { return parentDistance; }
00326 double& ParentDistance() { return parentDistance; }
00327
00329 double FurthestPointDistance() const { return 0.0; }
00330
00332 double FurthestDescendantDistance() const
00333 { return furthestDescendantDistance; }
00336 double& FurthestDescendantDistance() { return furthestDescendantDistance; }
00337
00339 void Centroid(arma::vec& centroid) const { centroid = dataset.col(point); }
00340
00342 MetricType& Metric() const { return *metric; }
00343
00344 private:
00346 const arma::mat& dataset;
00347
00349 size_t point;
00350
00352 std::vector<CoverTree*> children;
00353
00355 int scale;
00356
00358 double base;
00359
00361 StatisticType stat;
00362
00364 size_t numDescendants;
00365
00367 CoverTree* parent;
00368
00370 double parentDistance;
00371
00373 double furthestDescendantDistance;
00374
00376 bool localMetric;
00377
00379 MetricType* metric;
00380
00384 void CreateChildren(arma::Col<size_t>& indices,
00385 arma::vec& distances,
00386 size_t nearSetSize,
00387 size_t& farSetSize,
00388 size_t& usedSetSize);
00389
00401 void ComputeDistances(const size_t pointIndex,
00402 const arma::Col<size_t>& indices,
00403 arma::vec& distances,
00404 const size_t pointSetSize);
00419 size_t SplitNearFar(arma::Col<size_t>& indices,
00420 arma::vec& distances,
00421 const double bound,
00422 const size_t pointSetSize);
00423
00443 size_t SortPointSet(arma::Col<size_t>& indices,
00444 arma::vec& distances,
00445 const size_t childFarSetSize,
00446 const size_t childUsedSetSize,
00447 const size_t farSetSize);
00448
00449 void MoveToUsedSet(arma::Col<size_t>& indices,
00450 arma::vec& distances,
00451 size_t& nearSetSize,
00452 size_t& farSetSize,
00453 size_t& usedSetSize,
00454 arma::Col<size_t>& childIndices,
00455 const size_t childFarSetSize,
00456 const size_t childUsedSetSize);
00457 size_t PruneFarSet(arma::Col<size_t>& indices,
00458 arma::vec& distances,
00459 const double bound,
00460 const size_t nearSetSize,
00461 const size_t pointSetSize);
00462
00467 void RemoveNewImplicitNodes();
00468
00469 public:
00473 std::string ToString() const;
00474
00475 size_t DistanceComps() const { return distanceComps; }
00476 size_t& DistanceComps() { return distanceComps; }
00477
00478 private:
00479 size_t distanceComps;
00480 };
00481
00482 };
00483 };
00484
00485
00486 #include "cover_tree_impl.hpp"
00487
00488 #endif