mult_div_update_rules.hpp
Go to the documentation of this file.
00028 #ifndef __MLPACK_METHODS_NMF_MULT_DIV_UPDATE_RULES_HPP
00029 #define __MLPACK_METHODS_NMF_MULT_DIV_UPDATE_RULES_HPP
00030
00031 #include <mlpack/core.hpp>
00032
00033 namespace mlpack {
00034 namespace nmf {
00035
00043 class WMultiplicativeDivergenceRule
00044 {
00045 public:
00046
00047 WMultiplicativeDivergenceRule() { }
00048
00057 template<typename MatType>
00058 inline static void Update(const MatType& V,
00059 arma::mat& W,
00060 const arma::mat& H)
00061 {
00062
00063 arma::mat t1;
00064 arma::rowvec t2;
00065
00066 t1 = W * H;
00067 for (size_t i = 0; i < W.n_rows; ++i)
00068 {
00069 for (size_t j = 0; j < W.n_cols; ++j)
00070 {
00071
00072
00073
00074
00075 t2.set_size(H.n_cols);
00076 for (size_t k = 0; k < t2.n_elem; ++k)
00077 {
00078 t2(k) = H(j, k) * V(i, k) / t1(i, k);
00079 }
00080
00081 W(i, j) = W(i, j) * sum(t2) / sum(H.row(j));
00082 }
00083 }
00084 }
00085 };
00086
00094 class HMultiplicativeDivergenceRule
00095 {
00096 public:
00097
00098 HMultiplicativeDivergenceRule() { }
00099
00108 template<typename MatType>
00109 inline static void Update(const MatType& V,
00110 const arma::mat& W,
00111 arma::mat& H)
00112 {
00113
00114 arma::mat t1;
00115 arma::colvec t2;
00116
00117 t1 = W * H;
00118 for (size_t i = 0; i < H.n_rows; i++)
00119 {
00120 for (size_t j = 0; j < H.n_cols; j++)
00121 {
00122
00123
00124
00125
00126 t2.set_size(W.n_rows);
00127 for (size_t k = 0; k < t2.n_elem; ++k)
00128 {
00129 t2(k) = W(k, i) * V(k, j) / t1(k, j);
00130 }
00131
00132 H(i,j) = H(i,j) * sum(t2) / sum(W.col(i));
00133 }
00134 }
00135 }
00136 };
00137
00138 };
00139 };
00140
00141 #endif