// MLPACK 1.0.4
00001 00034 #ifndef __MLPACK_METHODS_LARS_LARS_HPP 00035 #define __MLPACK_METHODS_LARS_LARS_HPP 00036 00037 #include <armadillo> 00038 #include <mlpack/core.hpp> 00039 00040 namespace mlpack { 00041 namespace regression { 00042 00043 // beta is the estimator 00044 // yHat is the prediction from the current estimator 00045 00100 class LARS 00101 { 00102 public: 00113 LARS(const bool useCholesky, 00114 const double lambda1 = 0.0, 00115 const double lambda2 = 0.0, 00116 const double tolerance = 1e-16); 00117 00130 LARS(const bool useCholesky, 00131 const arma::mat& gramMatrix, 00132 const double lambda1 = 0.0, 00133 const double lambda2 = 0.0, 00134 const double tolerance = 1e-16); 00135 00150 void Regress(const arma::mat& data, 00151 const arma::vec& responses, 00152 arma::vec& beta, 00153 const bool rowMajor = false); 00154 00156 const std::vector<size_t>& ActiveSet() const { return activeSet; } 00157 00160 const std::vector<arma::vec>& BetaPath() const { return betaPath; } 00161 00164 const std::vector<double>& LambdaPath() const { return lambdaPath; } 00165 00167 const arma::mat& MatUtriCholFactor() const { return matUtriCholFactor; } 00168 00169 private: 00171 arma::mat matGramInternal; 00172 00174 const arma::mat& matGram; 00175 00177 arma::mat matUtriCholFactor; 00178 00180 bool useCholesky; 00181 00183 bool lasso; 00185 double lambda1; 00186 00188 bool elasticNet; 00190 double lambda2; 00191 00193 double tolerance; 00194 00196 std::vector<arma::vec> betaPath; 00197 00199 std::vector<double> lambdaPath; 00200 00202 std::vector<size_t> activeSet; 00203 00205 std::vector<bool> isActive; 00206 00212 void Deactivate(const size_t activeVarInd); 00213 00219 void Activate(const size_t varInd); 00220 00221 // compute "equiangular" direction in output space 00222 void ComputeYHatDirection(const arma::mat& matX, 00223 const arma::vec& betaDirection, 00224 arma::vec& yHatDirection); 00225 00226 // interpolate to compute last solution vector 00227 void InterpolateBeta(); 
00228 00229 void CholeskyInsert(const arma::vec& newX, const arma::mat& X); 00230 00231 void CholeskyInsert(double sqNormNewX, const arma::vec& newGramCol); 00232 00233 void GivensRotate(const arma::vec::fixed<2>& x, 00234 arma::vec::fixed<2>& rotatedX, 00235 arma::mat& G); 00236 00237 void CholeskyDelete(const size_t colToKill); 00238 }; 00239 00240 }; // namespace regression 00241 }; // namespace mlpack 00242 00243 #endif