MLPACK  1.0.4
neighbor_search.hpp
Go to the documentation of this file.
00001 
00023 #ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_HPP
00024 #define __MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_HPP
00025 
00026 #include <mlpack/core.hpp>
00027 #include <vector>
00028 #include <string>
00029 
00030 #include <mlpack/core/tree/binary_space_tree.hpp>
00031 
00032 #include <mlpack/core/metrics/lmetric.hpp>
00033 #include "sort_policies/nearest_neighbor_sort.hpp"
00034 
00035 namespace mlpack {
00036 namespace neighbor  {
00039 
/**
 * Node statistic for the query tree of a neighbor search; QueryStat<SortPolicy>
 * is the statistic type used by NeighborSearch's default TreeType.
 *
 * Its only state is a single bound value. Every constructor sets that bound to
 * SortPolicy::WorstDistance(). The constructor arguments are accepted
 * (presumably to satisfy the tree's statistic-construction interface) but are
 * never read.
 *
 * @tparam SortPolicy Type providing a static WorstDistance() function.
 */
template<typename SortPolicy>
class QueryStat
{
 public:
  //! Default constructor: the bound starts at SortPolicy::WorstDistance().
  QueryStat() :
      bound(SortPolicy::WorstDistance())
  { }

  /**
   * Construct from a dataset and a (begin, count) pair. All three arguments
   * are ignored; the bound starts at SortPolicy::WorstDistance().
   */
  template<typename MatType>
  QueryStat(const MatType& /* dataset */,
            const size_t /* begin */,
            const size_t /* count */) :
      bound(SortPolicy::WorstDistance())
  { }

  /**
   * Construct from a dataset, a (begin, count) pair and two other statistics
   * (judging by their names, those of the left and right child nodes). Every
   * argument is ignored; the bound starts at SortPolicy::WorstDistance().
   */
  template<typename MatType>
  QueryStat(const MatType& /* dataset */,
            const size_t /* begin */,
            const size_t /* count */,
            const QueryStat& /* leftStat */,
            const QueryStat& /* rightStat */) :
      bound(SortPolicy::WorstDistance())
  { }

  //! Return a copy of the current bound.
  double Bound() const
  {
    return bound;
  }

  //! Return a mutable reference so that callers can update the bound in place.
  double& Bound()
  {
    return bound;
  }

 private:
  //! The stored bound (presumably consulted by the search when pruning).
  double bound;
};
00081 
/**
 * Neighbor search class template, parameterized on a sort policy, a metric
 * and a tree type. Only the declaration is here: every member function is
 * defined out of line, and this header includes neighbor_search_impl.hpp for
 * the implementation.
 *
 * @tparam SortPolicy Sort policy; defaults to NearestNeighborSort.
 * @tparam MetricType Metric type; defaults to
 *     mlpack::metric::SquaredEuclideanDistance.
 * @tparam TreeType Tree type; defaults to a BinarySpaceTree with an
 *     HRectBound<2> bound and QueryStat<SortPolicy> as the node statistic.
 *
 * NOTE(review): copy semantics.
 *   - No copy constructor is declared. The implicit one copies the
 *     referenceTree/queryTree pointers together with the ownReferenceTree and
 *     ownQueryTree flags.
 *   - If the destructor deletes owned trees (likely, given those flags; confirm
 *     in the implementation), copying an instance that owns its trees leads to
 *     a double delete.
 *   - Copy assignment is unusable regardless, because of the reference members.
 */
00100 template<typename SortPolicy = NearestNeighborSort,
00101          typename MetricType = mlpack::metric::SquaredEuclideanDistance,
00102          typename TreeType = tree::BinarySpaceTree<bound::HRectBound<2>,
00103                                                    QueryStat<SortPolicy> > >
00104 class NeighborSearch
00105 {
00106  public:
        /**
         * Construct from a reference dataset and a separate query dataset. No
         * trees are passed in, so any trees needed are presumably built by the
         * implementation (confirm in neighbor_search_impl.hpp).
         *
         * @param referenceSet Reference dataset.
         * @param querySet Query dataset.
         * @param naive By name, requests naive (brute-force) search; default
         *     false.
         * @param singleMode By name, requests single-tree instead of dual-tree
         *     search; default false.
         * @param leafSize By name, leaf size used for tree construction;
         *     default 20.
         * @param metric Metric instance; default-constructed if omitted.
         */
00127   NeighborSearch(const typename TreeType::Mat& referenceSet,
00128                  const typename TreeType::Mat& querySet,
00129                  const bool naive = false,
00130                  const bool singleMode = false,
00131                  const size_t leafSize = 20,
00132                  const MetricType metric = MetricType());
00133 
        /**
         * Construct from a reference dataset only. No query set is passed;
         * presumably the reference set is also used as the query set (confirm
         * in the implementation).
         *
         * @param referenceSet Reference dataset.
         * @param naive By name, requests naive (brute-force) search; default
         *     false.
         * @param singleMode By name, requests single-tree search; default false.
         * @param leafSize By name, leaf size for tree construction; default 20.
         * @param metric Metric instance; default-constructed if omitted.
         */
00155   NeighborSearch(const typename TreeType::Mat& referenceSet,
00156                  const bool naive = false,
00157                  const bool singleMode = false,
00158                  const size_t leafSize = 20,
00159                  const MetricType metric = MetricType());
00160 
        /**
         * Construct from caller-supplied reference and query trees plus their
         * datasets. Unlike the dataset-only overloads, there is no naive or
         * leafSize parameter.
         *
         * NOTE(review): tree ownership is not visible here. Presumably the
         * caller keeps ownership (ownReferenceTree/ownQueryTree false), and the
         * trees must have been built on referenceSet/querySet. Confirm both in
         * the implementation.
         *
         * @param referenceTree Pre-built reference tree.
         * @param queryTree Pre-built query tree.
         * @param referenceSet Reference dataset.
         * @param querySet Query dataset.
         * @param singleMode By name, requests single-tree search; default false.
         * @param metric Metric instance; default-constructed if omitted.
         */
00190   NeighborSearch(TreeType* referenceTree,
00191                  TreeType* queryTree,
00192                  const typename TreeType::Mat& referenceSet,
00193                  const typename TreeType::Mat& querySet,
00194                  const bool singleMode = false,
00195                  const MetricType metric = MetricType());
00196 
        /**
         * Construct from a caller-supplied reference tree and its dataset, with
         * no separate query set or query tree (presumably the reference data is
         * also the query data; confirm in the implementation). Ownership of
         * referenceTree is not visible here.
         *
         * @param referenceTree Pre-built reference tree.
         * @param referenceSet Reference dataset.
         * @param singleMode By name, requests single-tree search; default false.
         * @param metric Metric instance; default-constructed if omitted.
         */
00224   NeighborSearch(TreeType* referenceTree,
00225                  const typename TreeType::Mat& referenceSet,
00226                  const bool singleMode = false,
00227                  const MetricType metric = MetricType());
00228 
00229 
        /**
         * Destructor, defined out of line. Presumably it deletes referenceTree
         * and queryTree only when the corresponding own*Tree flag is set
         * (confirm in the implementation).
         */
00234   ~NeighborSearch();
00235 
        /**
         * Run the search for k neighbors, writing the neighbor indices into
         * resultingNeighbors and the corresponding distances into distances.
         * The layout of the outputs is not visible here (presumably k rows and
         * one column per query point; confirm in the implementation).
         *
         * @param k Number of neighbors to find.
         * @param resultingNeighbors Output matrix of neighbor indices.
         * @param distances Output matrix of neighbor distances.
         */
00248   void Search(const size_t k,
00249               arma::Mat<size_t>& resultingNeighbors,
00250               arma::mat& distances);
00251 
00252  private:
        // NOTE(review): class members are initialized in declaration order. If
        // the constructors bind referenceSet/querySet to referenceCopy/queryCopy,
        // the copies must stay declared before the references. Do not reorder
        // these members without checking the constructor initializer lists.
        //! Copy of the reference dataset (by name).
00255   arma::mat referenceCopy;
        //! Copy of the query dataset (by name).
00257   arma::mat queryCopy;
00258 
        //! Reference dataset used by the search (may refer to referenceCopy).
        // NOTE(review): typed const arma::mat& while the constructors take
        // const typename TreeType::Mat&; the two agree only when TreeType::Mat
        // is arma::mat.
00260   const arma::mat& referenceSet;
        //! Query dataset used by the search (may refer to queryCopy).
00262   const arma::mat& querySet;
00263 
        //! Pointer to the reference tree.
00265   TreeType* referenceTree;
        //! Pointer to the query tree.
00267   TreeType* queryTree;
00268 
        //! By name, whether this object is responsible for deleting referenceTree.
00270   bool ownReferenceTree;
        //! By name, whether this object is responsible for deleting queryTree.
00272   bool ownQueryTree;
00273 
        //! Whether naive search was requested.
00275   bool naive;
        //! Whether single-tree search was requested.
00277   bool singleMode;
00278 
        //! Metric instance used for distance computations.
00280   MetricType metric;
00281 
        //! By name, maps reference point indices after tree construction back to
        //! their original indices.
00283   std::vector<size_t> oldFromNewReferences;
        //! By name, the same mapping for query point indices.
00285   std::vector<size_t> oldFromNewQueries;
00286 
        //! By name, a counter of prunes performed during the search.
00288   size_t numberOfPrunes;
00289 }; // class NeighborSearch
00290 
00291 }; // namespace neighbor
00292 }; // namespace mlpack
00293 
00294 // Include implementation.
00295 #include "neighbor_search_impl.hpp"
00296 
00297 // Include convenience typedefs.
00298 #include "typedef.hpp"
00299 
00300 #endif