MLPACK
1.0.4
|
00001 00023 #ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_HPP 00024 #define __MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_HPP 00025 00026 #include <mlpack/core.hpp> 00027 #include <vector> 00028 #include <string> 00029 00030 #include <mlpack/core/tree/binary_space_tree.hpp> 00031 00032 #include <mlpack/core/metrics/lmetric.hpp> 00033 #include "sort_policies/nearest_neighbor_sort.hpp" 00034 00035 namespace mlpack { 00036 namespace neighbor { 00039 00044 template<typename SortPolicy> 00045 class QueryStat 00046 { 00047 private: 00049 double bound; 00050 00051 public: 00056 QueryStat() : bound(SortPolicy::WorstDistance()) { } 00057 00061 template<typename MatType> 00062 QueryStat(const MatType& /* dataset */, const size_t /* begin */, const size_t /* count */) 00063 : bound(SortPolicy::WorstDistance()) { } 00064 00068 template<typename MatType> 00069 QueryStat(const MatType& /* dataset */, 00070 const size_t /* begin */, 00071 const size_t /* count */, 00072 const QueryStat& /* leftStat */, 00073 const QueryStat& /* rightStat */) 00074 : bound(SortPolicy::WorstDistance()) { } 00075 00077 double Bound() const { return bound; } 00079 double& Bound() { return bound; } 00080 }; 00081 00100 template<typename SortPolicy = NearestNeighborSort, 00101 typename MetricType = mlpack::metric::SquaredEuclideanDistance, 00102 typename TreeType = tree::BinarySpaceTree<bound::HRectBound<2>, 00103 QueryStat<SortPolicy> > > 00104 class NeighborSearch 00105 { 00106 public: 00127 NeighborSearch(const typename TreeType::Mat& referenceSet, 00128 const typename TreeType::Mat& querySet, 00129 const bool naive = false, 00130 const bool singleMode = false, 00131 const size_t leafSize = 20, 00132 const MetricType metric = MetricType()); 00133 00155 NeighborSearch(const typename TreeType::Mat& referenceSet, 00156 const bool naive = false, 00157 const bool singleMode = false, 00158 const size_t leafSize = 20, 00159 const MetricType metric = MetricType()); 00160 00190 NeighborSearch(TreeType* referenceTree, 00191 TreeType* queryTree, 00192 const typename TreeType::Mat& referenceSet, 00193 const typename TreeType::Mat& querySet, 00194 const bool singleMode = false, 00195 const MetricType metric = MetricType()); 00196 00224 NeighborSearch(TreeType* referenceTree, 00225 const typename TreeType::Mat& referenceSet, 00226 const bool singleMode = false, 00227 const MetricType metric = MetricType()); 00228 00229 00234 ~NeighborSearch(); 00235 00248 void Search(const size_t k, 00249 arma::Mat<size_t>& resultingNeighbors, 00250 arma::mat& distances); 00251 00252 private: 00255 arma::mat referenceCopy; 00257 arma::mat queryCopy; 00258 00260 const arma::mat& referenceSet; 00262 const arma::mat& querySet; 00263 00265 TreeType* referenceTree; 00267 TreeType* queryTree; 00268 00270 bool ownReferenceTree; 00272 bool ownQueryTree; 00273 00275 bool naive; 00277 bool singleMode; 00278 00280 MetricType metric; 00281 00283 std::vector<size_t> oldFromNewReferences; 00285 std::vector<size_t> oldFromNewQueries; 00286 00288 size_t numberOfPrunes; 00289 }; // class NeighborSearch 00290 00291 }; // namespace neighbor 00292 }; // namespace mlpack 00293 00294 // Include implementation. 00295 #include "neighbor_search_impl.hpp" 00296 00297 // Include convenience typedefs. 00298 #include "typedef.hpp" 00299 00300 #endif