MLPACK  1.0.4
cover_tree.hpp
Go to the documentation of this file.
00001 
00022 #ifndef __MLPACK_CORE_TREE_COVER_TREE_COVER_TREE_HPP
00023 #define __MLPACK_CORE_TREE_COVER_TREE_COVER_TREE_HPP
00024 
00025 #include <mlpack/core.hpp>
00026 #include <mlpack/core/metrics/lmetric.hpp>
00027 #include "first_point_is_root.hpp"
00028 #include "../statistic.hpp"
00029 
00030 namespace mlpack {
00031 namespace tree {
00032 
00100 template<typename MetricType = metric::LMetric<2, true>,
00101          typename RootPointPolicy = FirstPointIsRoot,
00102          typename StatisticType = EmptyStatistic>
00103 class CoverTree
00104 {
00105  public:
00106   typedef arma::mat Mat;
00107 
00116   CoverTree(const arma::mat& dataset,
00117             const double base = 2.0,
00118             MetricType* metric = NULL);
00119 
00149   CoverTree(const arma::mat& dataset,
00150             const double base,
00151             const size_t pointIndex,
00152             const int scale,
00153             const double parentDistance,
00154             arma::Col<size_t>& indices,
00155             arma::vec& distances,
00156             size_t nearSetSize,
00157             size_t& farSetSize,
00158             size_t& usedSetSize,
00159             MetricType& metric = NULL);
00160 
00175   CoverTree(const arma::mat& dataset,
00176             const double base,
00177             const size_t pointIndex,
00178             const int scale,
00179             const double parentDistance,
00180             const double furthestDescendantDistance);
00181 
00185   ~CoverTree();
00186 
00189   template<typename RuleType>
00190   class SingleTreeTraverser;
00191 
00193   template<typename RuleType>
00194   class DualTreeTraverser;
00195 
00197   const arma::mat& Dataset() const { return dataset; }
00198 
00200   size_t Point() const { return point; }
00202   size_t Point(const size_t) const { return point; }
00203 
00204   // Fake
00205   CoverTree* Left() const { return NULL; }
00206   CoverTree* Right() const { return NULL; }
00207   size_t Begin() const { return 0; }
00208   size_t Count() const { return 0; }
00209   size_t End() const { return 0; }
00210   bool IsLeaf() const { return (children.size() == 0); }
00211   size_t NumPoints() const { return 1; }
00212 
00214   const CoverTree& Child(const size_t index) const { return *children[index]; }
00216   CoverTree& Child(const size_t index) { return *children[index]; }
00217 
00219   size_t NumChildren() const { return children.size(); }
00220 
00222   const std::vector<CoverTree*>& Children() const { return children; }
00224   std::vector<CoverTree*>& Children() { return children; }
00225 
00227   int Scale() const { return scale; }
00229   int& Scale() { return scale; }
00230 
00232   double Base() const { return base; }
00234   double& Base() { return base; }
00235 
00237   const StatisticType& Stat() const { return stat; }
00239   StatisticType& Stat() { return stat; }
00240 
00242   double MinDistance(const CoverTree* other) const;
00243 
00246   double MinDistance(const CoverTree* other, const double distance) const;
00247 
00249   double MinDistance(const arma::vec& other) const;
00250 
00253   double MinDistance(const arma::vec& other, const double distance) const;
00254 
00256   double MaxDistance(const CoverTree* other) const;
00257 
00260   double MaxDistance(const CoverTree* other, const double distance) const;
00261 
00263   double MaxDistance(const arma::vec& other) const;
00264 
00267   double MaxDistance(const arma::vec& other, const double distance) const;
00268 
00270   static bool HasSelfChildren() { return true; }
00271 
00273   double ParentDistance() const { return parentDistance; }
00275   double& ParentDistance() { return parentDistance; }
00276 
00278   double FurthestDescendantDistance() const
00279   { return furthestDescendantDistance; }
00281   double& FurthestDescendantDistance() { return furthestDescendantDistance; }
00282 
00283  private:
00285   const arma::mat& dataset;
00286 
00288   size_t point;
00289 
00291   std::vector<CoverTree*> children;
00292 
00294   int scale;
00295 
00297   double base;
00298 
00300   StatisticType stat;
00301 
00303   double parentDistance;
00304 
00306   double furthestDescendantDistance;
00307 
00319   void ComputeDistances(const size_t pointIndex,
00320                         const arma::Col<size_t>& indices,
00321                         arma::vec& distances,
00322                         const size_t pointSetSize,
00323                         MetricType& metric);
00338   size_t SplitNearFar(arma::Col<size_t>& indices,
00339                       arma::vec& distances,
00340                       const double bound,
00341                       const size_t pointSetSize);
00342 
00362   size_t SortPointSet(arma::Col<size_t>& indices,
00363                       arma::vec& distances,
00364                       const size_t childFarSetSize,
00365                       const size_t childUsedSetSize,
00366                       const size_t farSetSize);
00367 
00368   void MoveToUsedSet(arma::Col<size_t>& indices,
00369                      arma::vec& distances,
00370                      size_t& nearSetSize,
00371                      size_t& farSetSize,
00372                      size_t& usedSetSize,
00373                      arma::Col<size_t>& childIndices,
00374                      const size_t childFarSetSize,
00375                      const size_t childUsedSetSize);
00376   size_t PruneFarSet(arma::Col<size_t>& indices,
00377                      arma::vec& distances,
00378                      const double bound,
00379                      const size_t nearSetSize,
00380                      const size_t pointSetSize);
00381 };
00382 
00383 }; // namespace tree
00384 }; // namespace mlpack
00385 
00386 // Include implementation.
00387 #include "cover_tree_impl.hpp"
00388 
00389 #endif