MLPACK  1.0.4
mult_div_update_rules.hpp
Go to the documentation of this file.
00001 
00028 #ifndef __MLPACK_METHODS_NMF_MULT_DIV_UPDATE_RULES_HPP
00029 #define __MLPACK_METHODS_NMF_MULT_DIV_UPDATE_RULES_HPP
00030 
00031 #include <mlpack/core.hpp>
00032 
00033 namespace mlpack {
00034 namespace nmf {
00035 
00043 class WMultiplicativeDivergenceRule
00044 {
00045  public:
00046   // Empty constructor required for the WUpdateRule template.
00047   WMultiplicativeDivergenceRule() { }
00048 
00057   inline static void Update(const arma::mat& V,
00058                             arma::mat& W,
00059                             const arma::mat& H)
00060   {
00061     // Simple implementation left in the header file.
00062     arma::mat t1;
00063     arma::rowvec t2;
00064 
00065     t1 = W * H;
00066     for (size_t i = 0; i < W.n_rows; ++i)
00067     {
00068       for (size_t j = 0; j < W.n_cols; ++j)
00069       {
00070         t2 = H.row(j) % V.row(i) / t1.row(i);
00071         W(i, j) = W(i, j) * sum(t2) / sum(H.row(j));
00072       }
00073     }
00074   }
00075 };
00076 
00084 class HMultiplicativeDivergenceRule
00085 {
00086  public:
00087   // Empty constructor required for the HUpdateRule template.
00088   HMultiplicativeDivergenceRule() { }
00089 
00098   inline static void Update(const arma::mat& V,
00099                             const arma::mat& W,
00100                             arma::mat& H)
00101   {
00102     // Simple implementation left in the header file.
00103     arma::mat t1;
00104     arma::colvec t2;
00105 
00106     t1 = W * H;
00107     for (size_t i = 0; i < H.n_rows; i++)
00108     {
00109       for (size_t j = 0; j < H.n_cols; j++)
00110       {
00111         t2 = W.col(i) % V.col(j) / t1.col(j);
00112         H(i,j) = H(i,j) * sum(t2) / sum(W.col(i));
00113       }
00114     }
00115   }
00116 };
00117 
00118 }; // namespace nmf
00119 }; // namespace mlpack
00120 
00121 #endif