// MLPACK 1.0.4
00001 00038 #ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_LSH_SEARCH_HPP 00039 #define __MLPACK_METHODS_NEIGHBOR_SEARCH_LSH_SEARCH_HPP 00040 00041 #include <mlpack/core.hpp> 00042 #include <vector> 00043 #include <string> 00044 00045 #include <mlpack/core/metrics/lmetric.hpp> 00046 #include <mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp> 00047 00048 namespace mlpack { 00049 namespace neighbor { 00050 00058 template<typename SortPolicy = NearestNeighborSort> 00059 class LSHSearch 00060 { 00061 public: 00083 LSHSearch(const arma::mat& referenceSet, 00084 const arma::mat& querySet, 00085 const size_t numProj, 00086 const size_t numTables, 00087 const double hashWidth = 0.0, 00088 const size_t secondHashSize = 99901, 00089 const size_t bucketSize = 500); 00090 00111 LSHSearch(const arma::mat& referenceSet, 00112 const size_t numProj, 00113 const size_t numTables, 00114 const double hashWidth = 0.0, 00115 const size_t secondHashSize = 99901, 00116 const size_t bucketSize = 500); 00117 00136 void Search(const size_t k, 00137 arma::Mat<size_t>& resultingNeighbors, 00138 arma::mat& distances, 00139 const size_t numTablesToSearch = 0); 00140 00141 private: 00155 void BuildHash(); 00156 00168 void ReturnIndicesFromTable(const size_t queryIndex, 00169 arma::uvec& referenceIndices, 00170 size_t numTablesToSearch); 00171 00179 double BaseCase(const size_t queryIndex, const size_t referenceIndex); 00180 00193 void InsertNeighbor(const size_t queryIndex, const size_t pos, 00194 const size_t neighbor, const double distance); 00195 00196 private: 00198 const arma::mat& referenceSet; 00199 00201 const arma::mat& querySet; 00202 00204 const size_t numProj; 00205 00207 const size_t numTables; 00208 00210 std::vector<arma::mat> projections; // should be [numProj x dims] x numTables 00211 00213 arma::mat offsets; // should be numProj x numTables 00214 00216 double hashWidth; 00217 00219 const size_t secondHashSize; 00220 00222 arma::vec secondHashWeights; 00223 00225 
const size_t bucketSize; 00226 00228 metric::SquaredEuclideanDistance metric; 00229 00231 arma::Mat<size_t> secondHashTable; 00232 00235 arma::Col<size_t> bucketContentSize; 00236 00239 arma::Col<size_t> bucketRowInHashTable; 00240 00242 arma::mat* distancePtr; 00243 00245 arma::Mat<size_t>* neighborPtr; 00246 }; // class LSHSearch 00247 00248 }; // namespace neighbor 00249 }; // namespace mlpack 00250 00251 // Include implementation. 00252 #include "lsh_search_impl.hpp" 00253 00254 #endif