als_update_rules.hpp

Go to the documentation of this file.
00001 
00028 #ifndef __MLPACK_METHODS_NMF_ALS_UPDATE_RULES_HPP
00029 #define __MLPACK_METHODS_NMF_ALS_UPDATE_RULES_HPP
00030 
00031 #include <mlpack/core.hpp>
00032 
00033 namespace mlpack {
00034 namespace nmf {
00035 
00042 class WAlternatingLeastSquaresRule
00043 {
00044  public:
00045   // Empty constructor required for the WUpdateRule template.
00046   WAlternatingLeastSquaresRule() { }
00047 
00056   template<typename MatType>
00057   inline static void Update(const MatType& V,
00058                             arma::mat& W,
00059                             const arma::mat& H)
00060   {
00061     // The call to inv() sometimes fails; so we are using the psuedoinverse.
00062     // W = (inv(H * H.t()) * H * V.t()).t();
00063     W = V * H.t() * pinv(H * H.t());
00064 
00065     // Set all negative numbers to machine epsilon
00066     for (size_t i = 0; i < W.n_elem; i++)
00067     {
00068       if (W(i) < 0.0)
00069       {
00070         W(i) = 0.0;
00071       }
00072     }
00073   }
00074 };
00075 
00082 class HAlternatingLeastSquaresRule
00083 {
00084  public:
00085   // Empty constructor required for the HUpdateRule template.
00086   HAlternatingLeastSquaresRule() { }
00087 
00096   template<typename MatType>
00097   inline static void Update(const MatType& V,
00098                             const arma::mat& W,
00099                             arma::mat& H)
00100   {
00101     H = pinv(W.t() * W) * W.t() * V;
00102 
00103     // Set all negative numbers to 0.
00104     for (size_t i = 0; i < H.n_elem; i++)
00105     {
00106       if (H(i) < 0.0)
00107       {
00108         H(i) = 0.0;
00109       }
00110     }
00111   }
00112 };
00113 
00114 }; // namespace nmf
00115 }; // namespace mlpack
00116 
00117 #endif

Generated on 13 Aug 2014 for MLPACK by doxygen 1.6.1