MLPACK  1.0.10
svd_incomplete_incremental_learning.hpp
Go to the documentation of this file.
1 #ifndef SVD_INCREMENTAL_LEARNING_HPP_INCLUDED
2 #define SVD_INCREMENTAL_LEARNING_HPP_INCLUDED
3 
4 namespace mlpack
5 {
6 namespace amf
7 {
9 {
10  public:
12  double kw = 0,
13  double kh = 0)
14  : u(u), kw(kw), kh(kh)
15  {}
16 
17  template<typename MatType>
18  void Initialize(const MatType& dataset, const size_t rank)
19  {
20  (void)rank;
21 
22  n = dataset.n_rows;
23  m = dataset.n_cols;
24 
25  currentUserIndex = 0;
26  }
27 
52  template<typename MatType>
53  inline void WUpdate(const MatType& V,
54  arma::mat& W,
55  const arma::mat& H)
56  {
57  arma::mat deltaW(n, W.n_cols);
58  deltaW.zeros();
59  for(size_t i = 0;i < n;i++)
60  {
61  double val;
62  if((val = V(i, currentUserIndex)) != 0)
63  deltaW.row(i) += (val - arma::dot(W.row(i), H.col(currentUserIndex))) *
64  arma::trans(H.col(currentUserIndex));
65  if(kw != 0) deltaW.row(i) -= kw * W.row(i);
66  }
67 
68  W += u*deltaW;
69  }
70 
80  template<typename MatType>
81  inline void HUpdate(const MatType& V,
82  const arma::mat& W,
83  arma::mat& H)
84  {
85  arma::mat deltaH(H.n_rows, 1);
86  deltaH.zeros();
87 
88  for(size_t i = 0;i < n;i++)
89  {
90  double val;
91  if((val = V(i, currentUserIndex)) != 0)
92  deltaH += (val - arma::dot(W.row(i), H.col(currentUserIndex))) *
93  arma::trans(W.row(i));
94  }
95  if(kh != 0) deltaH -= kh * H.col(currentUserIndex);
96 
97  H.col(currentUserIndex++) += u * deltaH;
99  }
100 
101  private:
102  double u;
103  double kw;
104  double kh;
105 
106  size_t n;
107  size_t m;
108 
110 };
111 
112 template<>
113 inline void SVDIncompleteIncrementalLearning::
114  WUpdate<arma::sp_mat>(const arma::sp_mat& V,
115  arma::mat& W,
116  const arma::mat& H)
117 {
118  arma::mat deltaW(n, W.n_cols);
119  deltaW.zeros();
120  for(arma::sp_mat::const_iterator it = V.begin_col(currentUserIndex);
121  it != V.end_col(currentUserIndex);it++)
122  {
123  double val = *it;
124  size_t i = it.row();
125  deltaW.row(i) += (val - arma::dot(W.row(i), H.col(currentUserIndex))) *
126  arma::trans(H.col(currentUserIndex));
127  if(kw != 0) deltaW.row(i) -= kw * W.row(i);
128  }
129 
130  W += u*deltaW;
131 }
132 
133 template<>
134 inline void SVDIncompleteIncrementalLearning::
135  HUpdate<arma::sp_mat>(const arma::sp_mat& V,
136  const arma::mat& W,
137  arma::mat& H)
138 {
139  arma::mat deltaH(H.n_rows, 1);
140  deltaH.zeros();
141 
142  for(arma::sp_mat::const_iterator it = V.begin_col(currentUserIndex);
143  it != V.end_col(currentUserIndex);it++)
144  {
145  double val = *it;
146  size_t i = it.row();
147  if((val = V(i, currentUserIndex)) != 0)
148  deltaH += (val - arma::dot(W.row(i), H.col(currentUserIndex))) *
149  arma::trans(W.row(i));
150  }
151  if(kh != 0) deltaH -= kh * H.col(currentUserIndex);
152 
153  H.col(currentUserIndex++) += u * deltaH;
154  currentUserIndex = currentUserIndex % m;
155 }
156 
157 }; // namepsace amf
158 }; // namespace mlpack
159 
160 
161 #endif // SVD_INCREMENTAL_LEARNING_HPP_INCLUDED
162 
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: load.hpp:31
void Initialize(const MatType &dataset, const size_t rank)
void WUpdate(const MatType &V, arma::mat &W, const arma::mat &H)
The update rule for the basis matrix W.
void HUpdate(const MatType &V, const arma::mat &W, arma::mat &H)
The update rule for the encoding matrix H.
SVDIncompleteIncrementalLearning(double u=0.001, double kw=0, double kh=0)