#include <itpp/srccode/gmm.h>
#include <itpp/srccode/vqtrain.h>
#include <itpp/base/math/elem_math.h>
#include <itpp/base/matfunc.h>
#include <itpp/base/specmat.h>
#include <itpp/base/random.h>
#include <itpp/base/timing.h>
#include <iostream>
#include <fstream>

namespace itpp
{

GMM::GMM()
{
  d = 0;
  M = 0;
}

GMM::GMM(std::string filename)
{
  load(filename);
}

GMM::GMM(int M_in, int d_in)
{
  M = M_in;
  d = d_in;
  m = zeros(M * d);
  sigma = zeros(M * d);
  w = 1.0 / M * ones(M); // equal weights for all mixtures
  compute_internals();
}

void GMM::init_from_vq(const vec &codebook, int dim)
{
  mat C(dim, dim);
  int i;
  vec v;

  d = dim;
  M = codebook.length() / dim;

  m = codebook;
  w = ones(M) / double(M);

  // use the average outer product of the codevectors as a common diagonal covariance
  C.clear();
  for (i = 0; i < M; i++) {
    v = codebook.mid(i * d, d);
    C = C + outer_product(v, v);
  }
  C = 1.0 / M * C;
  sigma.set_length(M * d);
  for (i = 0; i < M; i++) {
    sigma.replace_mid(i * d, diag(C));
  }

  compute_internals();
}

void GMM::init(const vec &w_in, const mat &m_in, const mat &sigma_in)
{
  int i, j;
  d = m_in.rows();
  M = m_in.cols();

  m.set_length(M * d);
  sigma.set_length(M * d);
  for (i = 0; i < M; i++) {
    for (j = 0; j < d; j++) {
      m(i * d + j) = m_in(j, i);
      sigma(i * d + j) = sigma_in(j, i);
    }
  }
  w = w_in;

  compute_internals();
}

void GMM::set_mean(const mat &m_in)
{
  int i, j;

  d = m_in.rows();
  M = m_in.cols();

  m.set_length(M * d);
  for (i = 0; i < M; i++) {
    for (j = 0; j < d; j++) {
      m(i * d + j) = m_in(j, i);
    }
  }
  compute_internals();
}

void GMM::set_mean(int i, const vec &means, bool compflag)
{
  m.replace_mid(i * length(means), means);
  if (compflag) compute_internals();
}

void GMM::set_covariance(const mat &sigma_in)
{
  int i, j;

  d = sigma_in.rows();
  M = sigma_in.cols();

  sigma.set_length(M * d);
  for (i = 0; i < M; i++) {
    for (j = 0; j < d; j++) {
      sigma(i * d + j) = sigma_in(j, i);
    }
  }
  compute_internals();
}

void GMM::set_covariance(int i, const vec &covariances, bool compflag)
{
  sigma.replace_mid(i * length(covariances), covariances);
  if (compflag) compute_internals();
}

void GMM::marginalize(int d_new)
{
  it_error_if(d_new > d, "GMM::marginalize: cannot change to a larger dimension");

  vec mnew(d_new * M), sigmanew(d_new * M);
  int i, j;

  // keep only the first d_new components of each mean and covariance
  for (i = 0; i < M; i++) {
    for (j = 0; j < d_new; j++) {
      mnew(i * d_new + j) = m(i * d + j);
      sigmanew(i * d_new + j) = sigma(i * d + j);
    }
  }
  m = mnew;
  sigma = sigmanew;
  d = d_new;

  compute_internals();
}

void GMM::join(const GMM &newgmm)
{
  if (d == 0) {
    w = newgmm.w;
    m = newgmm.m;
    sigma = newgmm.sigma;
    d = newgmm.d;
    M = newgmm.M;
  }
  else {
    it_error_if(d != newgmm.d, "GMM::join: cannot join GMMs of different dimension");

    // weight each model by its relative number of mixtures
    w = concat(double(M) / (M + newgmm.M) * w,
               double(newgmm.M) / (M + newgmm.M) * newgmm.w);
    w = w / sum(w);
    m = concat(m, newgmm.m);
    sigma = concat(sigma, newgmm.sigma);

    M = M + newgmm.M;
  }
  compute_internals();
}

void GMM::clear()
{
  w.set_length(0);
  m.set_length(0);
  sigma.set_length(0);
  d = 0;
  M = 0;
}

void GMM::save(std::string filename)
{
  std::ofstream f(filename.c_str());
  int i, j;

  f << M << " " << d << std::endl;
  for (i = 0; i < w.length(); i++) {
    f << w(i) << std::endl;
  }
  for (i = 0; i < M; i++) {
    f << m(i * d);
    for (j = 1; j < d; j++) {
      f << " " << m(i * d + j);
    }
    f << std::endl;
  }
  for (i = 0; i < M; i++) {
    f << sigma(i * d);
    for (j = 1; j < d; j++) {
      f << " " << sigma(i * d + j);
    }
    f << std::endl;
  }
}

void GMM::load(std::string filename)
{
  std::ifstream GMMFile(filename.c_str());
  int i, j;

  it_error_if(!GMMFile, std::string("GMM::load : cannot open file ") + filename);

  GMMFile >> M >> d;

  w.set_length(M);
  for (i = 0; i < M; i++) {
    GMMFile >> w(i);
  }
  m.set_length(M * d);
  for (i = 0; i < M; i++) {
    for (j = 0; j < d; j++) {
      GMMFile >> m(i * d + j);
    }
  }
  sigma.set_length(M * d);
  for (i = 0; i < M; i++) {
    for (j = 0; j < d; j++) {
      GMMFile >> sigma(i * d + j);
    }
  }
  compute_internals();
  std::cout << " mixtures:" << M << " dim:" << d << std::endl;
}

double GMM::likelihood(const vec &x)
{
  double fx = 0;
  int i;

  for (i = 0; i < M; i++) {
    fx += w(i) * likelihood_aposteriori(x, i);
  }
  return fx;
}

vec GMM::likelihood_aposteriori(const vec &x)
{
  vec v(M);
  int i;

  for (i = 0; i < M; i++) {
    v(i) = w(i) * likelihood_aposteriori(x, i);
  }
  return v;
}

double GMM::likelihood_aposteriori(const vec &x, int mixture)
{
  int j;
  double s;

  it_error_if(d != x.length(), "GMM::likelihood_aposteriori : dimensions do not match");
  s = 0;
  for (j = 0; j < d; j++) {
    s += normexp(mixture * d + j) * sqr(x(j) - m(mixture * d + j));
  }
  return normweight(mixture) * std::exp(s);
}

void GMM::compute_internals()
{
  int i, j;
  double s;
  double constant = 1.0 / std::pow(2 * pi, d / 2.0);

  // precompute the normalization factors and exponent coefficients of the
  // diagonal-covariance Gaussians
  normweight.set_length(M);
  normexp.set_length(M * d);

  for (i = 0; i < M; i++) {
    s = 1;
    for (j = 0; j < d; j++) {
      normexp(i * d + j) = -0.5 / sigma(i * d + j);
      s *= sigma(i * d + j);
    }
    normweight(i) = constant / std::sqrt(s);
  }
}

vec GMM::draw_sample()
{
  static bool first = true;
  static vec cumweight;
  double u = randu();
  int k;

  if (first) {
    first = false;
    cumweight = cumsum(w);
    it_error_if(std::abs(cumweight(length(cumweight) - 1) - 1) > 1e-6, "weights do not sum to 1");
    cumweight(length(cumweight) - 1) = 1;
  }
  // pick a mixture according to the weights, then draw from its diagonal Gaussian
  k = 0;
  while (u > cumweight(k)) k++;

  return elem_mult(sqrt(sigma.mid(k * d, d)), randn(d)) + m.mid(k * d, d);
}

GMM gmmtrain(Array<vec> &TrainingData, int M, int NOITER, bool VERBOSE)
{
  mat mean;
  int i, j, d = TrainingData(0).length();
  vec sig;
  GMM gmm(M, d);
  vec m(d * M);
  vec sigma(d * M);
  vec w(M);
  vec normweight(M);
  vec normexp(d * M);
  double LL = 0, LLold, fx;
  double constant = 1.0 / std::pow(2 * pi, d / 2.0);
  int T = TrainingData.length();
  vec x1;
  int t, n;
  vec msum(d * M);
  vec sigmasum(d * M);
  vec wsum(M);
  vec p_aposteriori(M);
  vec x2;
  double s;
  vec temp1, temp2;
  //double MINIMUM_VARIANCE=0.03;

  //----------- initialization -----------------------------------

  // means from a VQ codebook, common diagonal covariance, equal weights
  mean = vqtrain(TrainingData, M, 200000, 0.5, VERBOSE);
  for (i = 0; i < M; i++) gmm.set_mean(i, mean.get_col(i), false);
  // for (i=0;i<M;i++) gmm.set_mean(i,TrainingData(randi(0,TrainingData.length()-1)),false);
  sig = zeros(d);
  for (i = 0; i < TrainingData.length(); i++) sig += sqr(TrainingData(i));
  sig /= TrainingData.length();
  for (i = 0; i < M; i++) gmm.set_covariance(i, 0.5 * sig, false);

  gmm.set_weight(1.0 / M * ones(M));

  //----------- optimization (EM iterations) ----------------------

  tic();
  for (i = 0; i < M; i++) {
    temp1 = gmm.get_mean(i);
    temp2 = gmm.get_covariance(i);
    for (j = 0; j < d; j++) {
      m(i * d + j) = temp1(j);
      sigma(i * d + j) = temp2(j);
    }
    w(i) = gmm.get_weight(i);
  }
  for (n = 0; n < NOITER; n++) {
    // precompute Gaussian constants for the current parameters
    for (i = 0; i < M; i++) {
      s = 1;
      for (j = 0; j < d; j++) {
        normexp(i * d + j) = -0.5 / sigma(i * d + j);
        s *= sigma(i * d + j);
      }
      normweight(i) = constant * w(i) / std::sqrt(s);
    }
    LLold = LL;
    wsum.clear();
    msum.clear();
    sigmasum.clear();
    LL = 0;
    // E-step: accumulate a posteriori statistics over the training set
    for (t = 0; t < T; t++) {
      x1 = TrainingData(t);
      x2 = sqr(x1);
      fx = 0;
      for (i = 0; i < M; i++) {
        s = 0;
        for (j = 0; j < d; j++) {
          s += normexp(i * d + j) * sqr(x1(j) - m(i * d + j));
        }
        p_aposteriori(i) = normweight(i) * std::exp(s);
        fx += p_aposteriori(i);
      }
      p_aposteriori /= fx;
      LL = LL + std::log(fx);

      for (i = 0; i < M; i++) {
        wsum(i) += p_aposteriori(i);
        for (j = 0; j < d; j++) {
          msum(i * d + j) += p_aposteriori(i) * x1(j);
          sigmasum(i * d + j) += p_aposteriori(i) * x2(j);
        }
      }
    }
    // M-step: re-estimate means, variances and weights
    for (i = 0; i < M; i++) {
      for (j = 0; j < d; j++) {
        m(i * d + j) = msum(i * d + j) / wsum(i);
        sigma(i * d + j) = sigmasum(i * d + j) / wsum(i) - sqr(m(i * d + j));
      }
      w(i) = wsum(i) / T;
    }
    LL = LL / T;

    if (std::abs((LL - LLold) / LL) < 1e-6) break;
    if (VERBOSE) {
      std::cout << n << ": " << LL << " " << std::abs((LL - LLold) / LL) << " " << toc() << std::endl;
      std::cout << "---------------------------------------" << std::endl;
      tic();
    }
    else {
      std::cout << n << ": LL = " << LL << " " << std::abs((LL - LLold) / LL) << "\r";
      std::cout.flush();
    }
  }
  for (i = 0; i < M; i++) {
    gmm.set_mean(i, m.mid(i * d, d), false);
    gmm.set_covariance(i, sigma.mid(i * d, d), false);
  }
  gmm.set_weight(w);
  return gmm;
}

} // namespace itpp
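A minimal usage sketch of the functions listed above; it is not part of the library source. It assumes only what this file defines or includes (GMM, gmmtrain, Array<vec>, randn from itpp/base/random.h); the data sizes, mixture count, iteration limit, and file name are arbitrary illustrative choices.

#include <itpp/srccode/gmm.h>
#include <itpp/base/array.h>
#include <itpp/base/random.h>

int main()
{
  // toy training set: 1000 random 2-dimensional vectors
  itpp::Array<itpp::vec> data(1000);
  for (int t = 0; t < data.length(); t++)
    data(t) = itpp::randn(2);

  // train a 4-mixture GMM with at most 100 EM iterations, non-verbose
  itpp::GMM gmm = itpp::gmmtrain(data, 4, 100, false);

  itpp::vec x = gmm.draw_sample();   // one sample from the trained model
  double fx = gmm.likelihood(x);     // model density at x
  gmm.save("gmm.txt");               // plain-text format written by GMM::save above

  return (fx > 0.0) ? 0 : 1;
}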