MLPACK  1.0.4
als_update_rules.hpp
Go to the documentation of this file.
00001 
00028 #ifndef __MLPACK_METHODS_NMF_ALS_UPDATE_RULES_HPP
00029 #define __MLPACK_METHODS_NMF_ALS_UPDATE_RULES_HPP
00030 
00031 #include <mlpack/core.hpp>
00032 
00033 namespace mlpack {
00034 namespace nmf {
00035 
00042 class WAlternatingLeastSquaresRule
00043 {
00044  public:
00045   // Empty constructor required for the WUpdateRule template.
00046   WAlternatingLeastSquaresRule() { }
00047 
00056   inline static void Update(const arma::mat& V,
00057                             arma::mat& W,
00058                             const arma::mat& H)
00059   {
00060     // The call to inv() sometimes fails; so we are using the psuedoinverse.
00061     // W = (inv(H * H.t()) * H * V.t()).t();
00062     W = V * H.t() * pinv(H * H.t());
00063 
00064     // Set all negative numbers to machine epsilon
00065     for (size_t i = 0; i < W.n_elem; i++)
00066     {
00067       if (W(i) < 0.0)
00068       {
00069         W(i) = 0.0;
00070       }
00071     }
00072   }
00073 };
00074 
00081 class HAlternatingLeastSquaresRule
00082 {
00083  public:
00084   // Empty constructor required for the HUpdateRule template.
00085   HAlternatingLeastSquaresRule() { }
00086 
00095   inline static void Update(const arma::mat& V,
00096                             const arma::mat& W,
00097                             arma::mat& H)
00098   {
00099     H = pinv(W.t() * W) * W.t() * V;
00100 
00101     // Set all negative numbers to 0.
00102     for (size_t i = 0; i < H.n_elem; i++)
00103     {
00104       if (H(i) < 0.0)
00105       {
00106         H(i) = 0.0;
00107       }
00108     }
00109   }
00110 };
00111 
00112 }; // namespace nmf
00113 }; // namespace mlpack
00114 
00115 #endif