MLPACK  1.0.4
dtb.hpp
Go to the documentation of this file.
00001 
00035 #ifndef __MLPACK_METHODS_EMST_DTB_HPP
00036 #define __MLPACK_METHODS_EMST_DTB_HPP
00037 
00038 #include "edge_pair.hpp"
00039 
00040 #include <mlpack/core.hpp>
00041 #include <mlpack/core/metrics/lmetric.hpp>
00042 
00043 #include <mlpack/core/tree/binary_space_tree.hpp>
00044 
00045 namespace mlpack {
00046 namespace emst  {
00047 
/**
 * Statistic stored in every node of the tree used by the dual-tree Boruvka
 * algorithm.  Each node carries two values: a maximum neighbor distance and a
 * component-membership marker.  Both are exposed through read-only and
 * mutable accessors so the algorithm can update them in place.
 */
class DTBStat
{
 public:
  //! Default-construct the statistic (defined out of line).
  DTBStat();

  /**
   * Construct the statistic for a node built over part of a dataset.
   * NOTE(review): `start` and `count` presumably delimit the node's points
   * inside `dataset`; this overload receives no child statistics, so it is
   * presumably the one used for leaves -- confirm in the implementation.
   */
  template<typename MatType>
  DTBStat(const MatType& dataset, size_t start, size_t count);

  /**
   * Construct the statistic for a node given the statistics of its left and
   * right children (defined out of line).
   */
  template<typename MatType>
  DTBStat(const MatType& dataset, size_t start, size_t count,
          const DTBStat& leftStat, const DTBStat& rightStat);

  //! Read the maximum neighbor distance recorded for this node.
  double MaxNeighborDistance() const
  {
    return maxNeighborDistance;
  }

  //! Obtain a writable reference to the maximum neighbor distance.
  double& MaxNeighborDistance()
  {
    return maxNeighborDistance;
  }

  //! Read the component-membership marker of this node.
  int ComponentMembership() const
  {
    return componentMembership;
  }

  //! Obtain a writable reference to the component-membership marker.
  int& ComponentMembership()
  {
    return componentMembership;
  }

 private:
  //! Maximum neighbor distance tracked for the points of this node
  //! (presumably an upper bound used for pruning -- confirm in dtb_impl.hpp).
  double maxNeighborDistance;

  //! Component the node's points belong to; the meaning of special values
  //! (e.g. a "mixed components" sentinel) is defined by the implementation.
  int componentMembership;

}; // class DTBStat
00094 
00133 template<
00134   typename MetricType = metric::SquaredEuclideanDistance,
00135   typename TreeType = tree::BinarySpaceTree<bound::HRectBound<2>, DTBStat>
00136 >
00137 class DualTreeBoruvka
00138 {
00139  private:
00141   typename TreeType::Mat dataCopy;
00143   typename TreeType::Mat& data;
00144 
00146   TreeType* tree;
00148   bool ownTree;
00149 
00151   bool naive;
00152 
00154   std::vector<EdgePair> edges; // We must use vector with non-numerical types.
00155 
00157   UnionFind connections;
00158 
00160   std::vector<size_t> oldFromNew;
00162   arma::Col<size_t> neighborsInComponent;
00164   arma::Col<size_t> neighborsOutComponent;
00166   arma::vec neighborsDistances;
00167 
00169   double totalDist;
00170   
00172   MetricType metric;
00173 
00174   // For sorting the edge list after the computation.
00175   struct SortEdgesHelper
00176   {
00177     bool operator()(const EdgePair& pairA, const EdgePair& pairB)
00178     {
00179       return (pairA.Distance() < pairB.Distance());
00180     }
00181   } SortFun;
00182 
00183  public:
00192   DualTreeBoruvka(const typename TreeType::Mat& dataset,
00193                   const bool naive = false,
00194                   const size_t leafSize = 1,
00195                   const MetricType metric = MetricType());
00196 
00214   DualTreeBoruvka(TreeType* tree, const typename TreeType::Mat& dataset,
00215                   const MetricType metric = MetricType());
00216 
00220   ~DualTreeBoruvka();
00221 
00231   void ComputeMST(arma::mat& results);
00232 
00233  private:
00237   void AddEdge(const size_t e1, const size_t e2, const double distance);
00238 
00242   void AddAllEdges();
00243 
00247   void EmitResults(arma::mat& results);
00248 
00253   void CleanupHelper(TreeType* tree);
00254 
00258   void Cleanup();
00259 
00260 }; // class DualTreeBoruvka
00261 
00262 }; // namespace emst
00263 }; // namespace mlpack
00264 
00265 #include "dtb_impl.hpp"
00266 
00267 #endif // __MLPACK_METHODS_EMST_DTB_HPP