// MLPACK 1.0.4
00001 00023 #ifndef __MLPACK_METHODS_DET_DTREE_HPP 00024 #define __MLPACK_METHODS_DET_DTREE_HPP 00025 00026 #include <assert.h> 00027 00028 #include <mlpack/core.hpp> 00029 00030 namespace mlpack { 00031 namespace det { 00032 00056 class DTree 00057 { 00058 public: 00062 DTree(); 00063 00072 DTree(const arma::vec& maxVals, 00073 const arma::vec& minVals, 00074 const size_t totalPoints); 00075 00084 DTree(arma::mat& data); 00085 00098 DTree(const arma::vec& maxVals, 00099 const arma::vec& minVals, 00100 const size_t start, 00101 const size_t end, 00102 const double logNegError); 00103 00115 DTree(const arma::vec& maxVals, 00116 const arma::vec& minVals, 00117 const size_t totalPoints, 00118 const size_t start, 00119 const size_t end); 00120 00122 ~DTree(); 00123 00134 double Grow(arma::mat& data, 00135 arma::Col<size_t>& oldFromNew, 00136 const bool useVolReg = false, 00137 const size_t maxLeafSize = 10, 00138 const size_t minLeafSize = 5); 00139 00148 double PruneAndUpdate(const double oldAlpha, 00149 const size_t points, 00150 const bool useVolReg = false); 00151 00157 double ComputeValue(const arma::vec& query) const; 00158 00166 void WriteTree(FILE *fp, const size_t level = 0) const; 00167 00175 int TagTree(const int tag = 0); 00176 00183 int FindBucket(const arma::vec& query) const; 00184 00190 void ComputeVariableImportance(arma::vec& importances) const; 00191 00198 double LogNegativeError(const size_t totalPoints) const; 00199 00203 bool WithinRange(const arma::vec& query) const; 00204 00205 private: 00206 // The indices in the complete set of points 00207 // (after all forms of swapping in the original data 00208 // matrix to align all the points in a node 00209 // consecutively in the matrix. The 'old_from_new' array 00210 // maps the points back to their original indices. 
00211 00214 size_t start; 00217 size_t end; 00218 00220 arma::vec maxVals; 00222 arma::vec minVals; 00223 00225 size_t splitDim; 00226 00228 double splitValue; 00229 00231 double logNegError; 00232 00234 double subtreeLeavesLogNegError; 00235 00237 size_t subtreeLeaves; 00238 00240 bool root; 00241 00243 double ratio; 00244 00246 double logVolume; 00247 00249 int bucketTag; 00250 00252 double alphaUpper; 00253 00255 DTree* left; 00257 DTree* right; 00258 00259 public: 00261 size_t Start() const { return start; } 00263 size_t End() const { return end; } 00265 size_t SplitDim() const { return splitDim; } 00267 double SplitValue() const { return splitValue; } 00269 double LogNegError() const { return logNegError; } 00271 double SubtreeLeavesLogNegError() const { return subtreeLeavesLogNegError; } 00273 size_t SubtreeLeaves() const { return subtreeLeaves; } 00276 double Ratio() const { return ratio; } 00278 double LogVolume() const { return logVolume; } 00280 DTree* Left() const { return left; } 00282 DTree* Right() const { return right; } 00284 bool Root() const { return root; } 00286 double AlphaUpper() const { return alphaUpper; } 00287 00289 const arma::vec& MaxVals() const { return maxVals; } 00291 arma::vec& MaxVals() { return maxVals; } 00292 00294 const arma::vec& MinVals() const { return minVals; } 00296 arma::vec& MinVals() { return minVals; } 00297 00298 private: 00299 00300 // Utility methods. 00301 00305 bool FindSplit(const arma::mat& data, 00306 size_t& splitDim, 00307 double& splitValue, 00308 double& leftError, 00309 double& rightError, 00310 const size_t maxLeafSize = 10, 00311 const size_t minLeafSize = 5) const; 00312 00316 size_t SplitData(arma::mat& data, 00317 const size_t splitDim, 00318 const double splitValue, 00319 arma::Col<size_t>& oldFromNew) const; 00320 00321 }; 00322 00323 }; // namespace det 00324 }; // namespace mlpack 00325 00326 #endif // __MLPACK_METHODS_DET_DTREE_HPP