MLPACK  1.0.4
dtree.hpp
Go to the documentation of this file.
00001 
00023 #ifndef __MLPACK_METHODS_DET_DTREE_HPP
00024 #define __MLPACK_METHODS_DET_DTREE_HPP
00025 
00026 #include <assert.h>
00027 
00028 #include <mlpack/core.hpp>
00029 
00030 namespace mlpack {
00031 namespace det  {
00032 
00056 class DTree
00057 {
00058  public:
00062   DTree();
00063 
00072   DTree(const arma::vec& maxVals,
00073         const arma::vec& minVals,
00074         const size_t totalPoints);
00075 
00084   DTree(arma::mat& data);
00085 
00098   DTree(const arma::vec& maxVals,
00099         const arma::vec& minVals,
00100         const size_t start,
00101         const size_t end,
00102         const double logNegError);
00103 
00115   DTree(const arma::vec& maxVals,
00116         const arma::vec& minVals,
00117         const size_t totalPoints,
00118         const size_t start,
00119         const size_t end);
00120 
00122   ~DTree();
00123 
00134   double Grow(arma::mat& data,
00135               arma::Col<size_t>& oldFromNew,
00136               const bool useVolReg = false,
00137               const size_t maxLeafSize = 10,
00138               const size_t minLeafSize = 5);
00139 
00148   double PruneAndUpdate(const double oldAlpha,
00149                         const size_t points,
00150                         const bool useVolReg = false);
00151 
00157   double ComputeValue(const arma::vec& query) const;
00158 
00166   void WriteTree(FILE *fp, const size_t level = 0) const;
00167 
00175   int TagTree(const int tag = 0);
00176 
00183   int FindBucket(const arma::vec& query) const;
00184 
00190   void ComputeVariableImportance(arma::vec& importances) const;
00191 
00198   double LogNegativeError(const size_t totalPoints) const;
00199 
00203   bool WithinRange(const arma::vec& query) const;
00204 
00205  private:
00206   // The indices in the complete set of points
00207   // (after all forms of swapping in the original data
00208   // matrix to align all the points in a node
00209   // consecutively in the matrix. The 'old_from_new' array
00210   // maps the points back to their original indices.
00211 
00214   size_t start;
00217   size_t end;
00218 
00220   arma::vec maxVals;
00222   arma::vec minVals;
00223 
00225   size_t splitDim;
00226 
00228   double splitValue;
00229 
00231   double logNegError;
00232 
00234   double subtreeLeavesLogNegError;
00235 
00237   size_t subtreeLeaves;
00238 
00240   bool root;
00241 
00243   double ratio;
00244 
00246   double logVolume;
00247 
00249   int bucketTag;
00250 
00252   double alphaUpper;
00253 
00255   DTree* left;
00257   DTree* right;
00258 
00259  public:
00261   size_t Start() const { return start; }
00263   size_t End() const { return end; }
00265   size_t SplitDim() const { return splitDim; }
00267   double SplitValue() const { return splitValue; }
00269   double LogNegError() const { return logNegError; }
00271   double SubtreeLeavesLogNegError() const { return subtreeLeavesLogNegError; }
00273   size_t SubtreeLeaves() const { return subtreeLeaves; }
00276   double Ratio() const { return ratio; }
00278   double LogVolume() const { return logVolume; }
00280   DTree* Left() const { return left; }
00282   DTree* Right() const { return right; }
00284   bool Root() const { return root; }
00286   double AlphaUpper() const { return alphaUpper; }
00287 
00289   const arma::vec& MaxVals() const { return maxVals; }
00291   arma::vec& MaxVals() { return maxVals; }
00292 
00294   const arma::vec& MinVals() const { return minVals; }
00296   arma::vec& MinVals() { return minVals; }
00297 
00298  private:
00299 
00300   // Utility methods.
00301 
00305   bool FindSplit(const arma::mat& data,
00306                  size_t& splitDim,
00307                  double& splitValue,
00308                  double& leftError,
00309                  double& rightError,
00310                  const size_t maxLeafSize = 10,
00311                  const size_t minLeafSize = 5) const;
00312 
00316   size_t SplitData(arma::mat& data,
00317                    const size_t splitDim,
00318                    const double splitValue,
00319                    arma::Col<size_t>& oldFromNew) const;
00320 
00321 };
00322 
00323 }; // namespace det
00324 }; // namespace mlpack
00325 
00326 #endif // __MLPACK_METHODS_DET_DTREE_HPP