// MLPACK 1.0.4
00001 00035 #ifndef __MLPACK_METHODS_EMST_DTB_HPP 00036 #define __MLPACK_METHODS_EMST_DTB_HPP 00037 00038 #include "edge_pair.hpp" 00039 00040 #include <mlpack/core.hpp> 00041 #include <mlpack/core/metrics/lmetric.hpp> 00042 00043 #include <mlpack/core/tree/binary_space_tree.hpp> 00044 00045 namespace mlpack { 00046 namespace emst { 00047 00052 class DTBStat 00053 { 00054 private: 00057 double maxNeighborDistance; 00062 int componentMembership; 00063 00064 public: 00068 DTBStat(); 00069 00073 template<typename MatType> 00074 DTBStat(const MatType& dataset, const size_t start, const size_t count); 00075 00079 template<typename MatType> 00080 DTBStat(const MatType& dataset, const size_t start, const size_t count, 00081 const DTBStat& leftStat, const DTBStat& rightStat); 00082 00084 double MaxNeighborDistance() const { return maxNeighborDistance; } 00086 double& MaxNeighborDistance() { return maxNeighborDistance; } 00087 00089 int ComponentMembership() const { return componentMembership; } 00091 int& ComponentMembership() { return componentMembership; } 00092 00093 }; // class DTBStat 00094 00133 template< 00134 typename MetricType = metric::SquaredEuclideanDistance, 00135 typename TreeType = tree::BinarySpaceTree<bound::HRectBound<2>, DTBStat> 00136 > 00137 class DualTreeBoruvka 00138 { 00139 private: 00141 typename TreeType::Mat dataCopy; 00143 typename TreeType::Mat& data; 00144 00146 TreeType* tree; 00148 bool ownTree; 00149 00151 bool naive; 00152 00154 std::vector<EdgePair> edges; // We must use vector with non-numerical types. 00155 00157 UnionFind connections; 00158 00160 std::vector<size_t> oldFromNew; 00162 arma::Col<size_t> neighborsInComponent; 00164 arma::Col<size_t> neighborsOutComponent; 00166 arma::vec neighborsDistances; 00167 00169 double totalDist; 00170 00172 MetricType metric; 00173 00174 // For sorting the edge list after the computation. 
00175 struct SortEdgesHelper 00176 { 00177 bool operator()(const EdgePair& pairA, const EdgePair& pairB) 00178 { 00179 return (pairA.Distance() < pairB.Distance()); 00180 } 00181 } SortFun; 00182 00183 public: 00192 DualTreeBoruvka(const typename TreeType::Mat& dataset, 00193 const bool naive = false, 00194 const size_t leafSize = 1, 00195 const MetricType metric = MetricType()); 00196 00214 DualTreeBoruvka(TreeType* tree, const typename TreeType::Mat& dataset, 00215 const MetricType metric = MetricType()); 00216 00220 ~DualTreeBoruvka(); 00221 00231 void ComputeMST(arma::mat& results); 00232 00233 private: 00237 void AddEdge(const size_t e1, const size_t e2, const double distance); 00238 00242 void AddAllEdges(); 00243 00247 void EmitResults(arma::mat& results); 00248 00253 void CleanupHelper(TreeType* tree); 00254 00258 void Cleanup(); 00259 00260 }; // class DualTreeBoruvka 00261 00262 }; // namespace emst 00263 }; // namespace mlpack 00264 00265 #include "dtb_impl.hpp" 00266 00267 #endif // __MLPACK_METHODS_EMST_DTB_HPP