MLPACK
1.0.4
|
00001 00028 #ifndef __MLPACK_METHODS_NMF_ALS_UPDATE_RULES_HPP 00029 #define __MLPACK_METHODS_NMF_ALS_UPDATE_RULES_HPP 00030 00031 #include <mlpack/core.hpp> 00032 00033 namespace mlpack { 00034 namespace nmf { 00035 00042 class WAlternatingLeastSquaresRule 00043 { 00044 public: 00045 // Empty constructor required for the WUpdateRule template. 00046 WAlternatingLeastSquaresRule() { } 00047 00056 inline static void Update(const arma::mat& V, 00057 arma::mat& W, 00058 const arma::mat& H) 00059 { 00060 // The call to inv() sometimes fails; so we are using the psuedoinverse. 00061 // W = (inv(H * H.t()) * H * V.t()).t(); 00062 W = V * H.t() * pinv(H * H.t()); 00063 00064 // Set all negative numbers to machine epsilon 00065 for (size_t i = 0; i < W.n_elem; i++) 00066 { 00067 if (W(i) < 0.0) 00068 { 00069 W(i) = 0.0; 00070 } 00071 } 00072 } 00073 }; 00074 00081 class HAlternatingLeastSquaresRule 00082 { 00083 public: 00084 // Empty constructor required for the HUpdateRule template. 00085 HAlternatingLeastSquaresRule() { } 00086 00095 inline static void Update(const arma::mat& V, 00096 const arma::mat& W, 00097 arma::mat& H) 00098 { 00099 H = pinv(W.t() * W) * W.t() * V; 00100 00101 // Set all negative numbers to 0. 00102 for (size_t i = 0; i < H.n_elem; i++) 00103 { 00104 if (H(i) < 0.0) 00105 { 00106 H(i) = 0.0; 00107 } 00108 } 00109 } 00110 }; 00111 00112 }; // namespace nmf 00113 }; // namespace mlpack 00114 00115 #endif