MLPACK
1.0.4
|
00001 00022 #ifndef __MLPACK_CORE_TREE_COVER_TREE_COVER_TREE_HPP 00023 #define __MLPACK_CORE_TREE_COVER_TREE_COVER_TREE_HPP 00024 00025 #include <mlpack/core.hpp> 00026 #include <mlpack/core/metrics/lmetric.hpp> 00027 #include "first_point_is_root.hpp" 00028 #include "../statistic.hpp" 00029 00030 namespace mlpack { 00031 namespace tree { 00032 00100 template<typename MetricType = metric::LMetric<2, true>, 00101 typename RootPointPolicy = FirstPointIsRoot, 00102 typename StatisticType = EmptyStatistic> 00103 class CoverTree 00104 { 00105 public: 00106 typedef arma::mat Mat; 00107 00116 CoverTree(const arma::mat& dataset, 00117 const double base = 2.0, 00118 MetricType* metric = NULL); 00119 00149 CoverTree(const arma::mat& dataset, 00150 const double base, 00151 const size_t pointIndex, 00152 const int scale, 00153 const double parentDistance, 00154 arma::Col<size_t>& indices, 00155 arma::vec& distances, 00156 size_t nearSetSize, 00157 size_t& farSetSize, 00158 size_t& usedSetSize, 00159 MetricType& metric = NULL); 00160 00175 CoverTree(const arma::mat& dataset, 00176 const double base, 00177 const size_t pointIndex, 00178 const int scale, 00179 const double parentDistance, 00180 const double furthestDescendantDistance); 00181 00185 ~CoverTree(); 00186 00189 template<typename RuleType> 00190 class SingleTreeTraverser; 00191 00193 template<typename RuleType> 00194 class DualTreeTraverser; 00195 00197 const arma::mat& Dataset() const { return dataset; } 00198 00200 size_t Point() const { return point; } 00202 size_t Point(const size_t) const { return point; } 00203 00204 // Fake 00205 CoverTree* Left() const { return NULL; } 00206 CoverTree* Right() const { return NULL; } 00207 size_t Begin() const { return 0; } 00208 size_t Count() const { return 0; } 00209 size_t End() const { return 0; } 00210 bool IsLeaf() const { return (children.size() == 0); } 00211 size_t NumPoints() const { return 1; } 00212 00214 const CoverTree& Child(const size_t index) const { return *children[index]; } 00216 CoverTree& Child(const size_t index) { return *children[index]; } 00217 00219 size_t NumChildren() const { return children.size(); } 00220 00222 const std::vector<CoverTree*>& Children() const { return children; } 00224 std::vector<CoverTree*>& Children() { return children; } 00225 00227 int Scale() const { return scale; } 00229 int& Scale() { return scale; } 00230 00232 double Base() const { return base; } 00234 double& Base() { return base; } 00235 00237 const StatisticType& Stat() const { return stat; } 00239 StatisticType& Stat() { return stat; } 00240 00242 double MinDistance(const CoverTree* other) const; 00243 00246 double MinDistance(const CoverTree* other, const double distance) const; 00247 00249 double MinDistance(const arma::vec& other) const; 00250 00253 double MinDistance(const arma::vec& other, const double distance) const; 00254 00256 double MaxDistance(const CoverTree* other) const; 00257 00260 double MaxDistance(const CoverTree* other, const double distance) const; 00261 00263 double MaxDistance(const arma::vec& other) const; 00264 00267 double MaxDistance(const arma::vec& other, const double distance) const; 00268 00270 static bool HasSelfChildren() { return true; } 00271 00273 double ParentDistance() const { return parentDistance; } 00275 double& ParentDistance() { return parentDistance; } 00276 00278 double FurthestDescendantDistance() const 00279 { return furthestDescendantDistance; } 00281 double& FurthestDescendantDistance() { return furthestDescendantDistance; } 00282 00283 private: 00285 const arma::mat& dataset; 00286 00288 size_t point; 00289 00291 std::vector<CoverTree*> children; 00292 00294 int scale; 00295 00297 double base; 00298 00300 StatisticType stat; 00301 00303 double parentDistance; 00304 00306 double furthestDescendantDistance; 00307 00319 void ComputeDistances(const size_t pointIndex, 00320 const arma::Col<size_t>& indices, 00321 arma::vec& distances, 00322 const size_t pointSetSize, 00323 MetricType& metric); 00338 size_t SplitNearFar(arma::Col<size_t>& indices, 00339 arma::vec& distances, 00340 const double bound, 00341 const size_t pointSetSize); 00342 00362 size_t SortPointSet(arma::Col<size_t>& indices, 00363 arma::vec& distances, 00364 const size_t childFarSetSize, 00365 const size_t childUsedSetSize, 00366 const size_t farSetSize); 00367 00368 void MoveToUsedSet(arma::Col<size_t>& indices, 00369 arma::vec& distances, 00370 size_t& nearSetSize, 00371 size_t& farSetSize, 00372 size_t& usedSetSize, 00373 arma::Col<size_t>& childIndices, 00374 const size_t childFarSetSize, 00375 const size_t childUsedSetSize); 00376 size_t PruneFarSet(arma::Col<size_t>& indices, 00377 arma::vec& distances, 00378 const double bound, 00379 const size_t nearSetSize, 00380 const size_t pointSetSize); 00381 }; 00382 00383 }; // namespace tree 00384 }; // namespace mlpack 00385 00386 // Include implementation. 00387 #include "cover_tree_impl.hpp" 00388 00389 #endif