MLPACK
1.0.4
|
00001 00028 #ifndef __MLPACK_METHODS_NMF_MULT_DIV_UPDATE_RULES_HPP 00029 #define __MLPACK_METHODS_NMF_MULT_DIV_UPDATE_RULES_HPP 00030 00031 #include <mlpack/core.hpp> 00032 00033 namespace mlpack { 00034 namespace nmf { 00035 00043 class WMultiplicativeDivergenceRule 00044 { 00045 public: 00046 // Empty constructor required for the WUpdateRule template. 00047 WMultiplicativeDivergenceRule() { } 00048 00057 inline static void Update(const arma::mat& V, 00058 arma::mat& W, 00059 const arma::mat& H) 00060 { 00061 // Simple implementation left in the header file. 00062 arma::mat t1; 00063 arma::rowvec t2; 00064 00065 t1 = W * H; 00066 for (size_t i = 0; i < W.n_rows; ++i) 00067 { 00068 for (size_t j = 0; j < W.n_cols; ++j) 00069 { 00070 t2 = H.row(j) % V.row(i) / t1.row(i); 00071 W(i, j) = W(i, j) * sum(t2) / sum(H.row(j)); 00072 } 00073 } 00074 } 00075 }; 00076 00084 class HMultiplicativeDivergenceRule 00085 { 00086 public: 00087 // Empty constructor required for the HUpdateRule template. 00088 HMultiplicativeDivergenceRule() { } 00089 00098 inline static void Update(const arma::mat& V, 00099 const arma::mat& W, 00100 arma::mat& H) 00101 { 00102 // Simple implementation left in the header file. 00103 arma::mat t1; 00104 arma::colvec t2; 00105 00106 t1 = W * H; 00107 for (size_t i = 0; i < H.n_rows; i++) 00108 { 00109 for (size_t j = 0; j < H.n_cols; j++) 00110 { 00111 t2 = W.col(i) % V.col(j) / t1.col(j); 00112 H(i,j) = H(i,j) * sum(t2) / sum(W.col(i)); 00113 } 00114 } 00115 } 00116 }; 00117 00118 }; // namespace nmf 00119 }; // namespace mlpack 00120 00121 #endif