00001 00033 #include <itpp/srccode/gmm.h> 00034 #include <itpp/srccode/vqtrain.h> 00035 #include <itpp/base/matfunc.h> 00036 #include <itpp/base/specmat.h> 00037 #include <itpp/base/random.h> 00038 #include <itpp/base/timing.h> 00039 #include <itpp/base/elmatfunc.h> 00040 #include <iostream> 00041 #include <fstream> 00042 00043 00044 namespace itpp { 00045 00046 GMM::GMM() 00047 { 00048 d=0; 00049 M=0; 00050 } 00051 00052 GMM::GMM(std::string filename) 00053 { 00054 load(filename); 00055 } 00056 00057 GMM::GMM(int M_in, int d_in) 00058 { 00059 M=M_in; 00060 d=d_in; 00061 m=zeros(M*d); 00062 sigma=zeros(M*d); 00063 w=1./M*ones(M); 00064 00065 for (int i=0;i<M;i++) { 00066 w(i)=1.0/M; 00067 } 00068 compute_internals(); 00069 } 00070 00071 void GMM::init_from_vq(const vec &codebook, int dim) 00072 { 00073 00074 mat C(dim,dim); 00075 int i; 00076 vec v; 00077 00078 d=dim; 00079 M=codebook.length()/dim; 00080 00081 m=codebook; 00082 w=ones(M)/double(M); 00083 00084 C.clear(); 00085 for (i=0;i<M;i++) { 00086 v=codebook.mid(i*d,d); 00087 C=C+outer_product(v,v); 00088 } 00089 C=1./M*C; 00090 sigma.set_length(M*d); 00091 for (i=0;i<M;i++) { 00092 sigma.replace_mid(i*d,diag(C)); 00093 } 00094 00095 compute_internals(); 00096 } 00097 00098 //void GMM::init(const vec &m_in, const vec &sigma_in, const vec &w_in) 00099 //{ 00100 // m=m_in; 00101 // sigma=sigma_in; 00102 // w=w_in; 00103 // 00104 // compute_internals(); 00105 //} 00106 void GMM::init(const vec &w_in, const mat &m_in, const mat &sigma_in) 00107 { 00108 int i,j; 00109 d=m_in.rows(); 00110 M=m_in.cols(); 00111 00112 m.set_length(M*d); 00113 sigma.set_length(M*d); 00114 for (i=0;i<M;i++) { 00115 for (j=0;j<d;j++) { 00116 m(i*d+j)=m_in(j,i); 00117 sigma(i*d+j)=sigma_in(j,i); 00118 } 00119 } 00120 w=w_in; 00121 00122 compute_internals(); 00123 } 00124 00125 void GMM::set_mean(const mat &m_in) 00126 { 00127 int i,j; 00128 00129 d=m_in.rows(); 00130 M=m_in.cols(); 00131 00132 m.set_length(M*d); 00133 for (i=0;i<M;i++) { 00134 for (j=0;j<d;j++) { 00135 m(i*d+j)=m_in(j,i); 00136 } 00137 } 00138 compute_internals(); 00139 } 00140 00141 void GMM::set_mean(int i, const vec &means, bool compflag) 00142 { 00143 m.replace_mid(i*length(means),means); 00144 if (compflag) compute_internals(); 00145 } 00146 00147 void GMM::set_covariance(const mat &sigma_in) 00148 { 00149 int i,j; 00150 00151 d=sigma_in.rows(); 00152 M=sigma_in.cols(); 00153 00154 sigma.set_length(M*d); 00155 for (i=0;i<M;i++) { 00156 for (j=0;j<d;j++) { 00157 sigma(i*d+j)=sigma_in(j,i); 00158 } 00159 } 00160 compute_internals(); 00161 } 00162 00163 void GMM::set_covariance(int i, const vec &covariances, bool compflag) 00164 { 00165 sigma.replace_mid(i*length(covariances),covariances); 00166 if (compflag) compute_internals(); 00167 } 00168 00169 void GMM::marginalize(int d_new) 00170 { 00171 it_error_if(d_new>d,"GMM.marginalize: cannot change to a larger dimension"); 00172 00173 vec mnew(d_new*M),sigmanew(d_new*M); 00174 int i,j; 00175 00176 for (i=0;i<M;i++) { 00177 for (j=0;j<d_new;j++) { 00178 mnew(i*d_new+j)=m(i*d+j); 00179 sigmanew(i*d_new+j)=sigma(i*d+j); 00180 } 00181 } 00182 m=mnew; 00183 sigma=sigmanew; 00184 d=d_new; 00185 00186 compute_internals(); 00187 } 00188 00189 void GMM::join(const GMM &newgmm) 00190 { 00191 if (d==0) { 00192 w=newgmm.w; 00193 m=newgmm.m; 00194 sigma=newgmm.sigma; 00195 d=newgmm.d; 00196 M=newgmm.M; 00197 } else { 00198 it_error_if( d!=newgmm.d,"GMM.join: cannot join GMMs of different dimension"); 00199 00200 w=concat(double(M)/(M+newgmm.M)*w,double(newgmm.M)/(M+newgmm.M)*newgmm.w); 00201 w=w/sum(w); 00202 m=concat(m,newgmm.m); 00203 sigma=concat(sigma,newgmm.sigma); 00204 00205 M=M+newgmm.M; 00206 } 00207 compute_internals(); 00208 } 00209 00210 void GMM::clear() 00211 { 00212 w.set_length(0); 00213 m.set_length(0); 00214 sigma.set_length(0); 00215 d=0; 00216 M=0; 00217 } 00218 00219 void GMM::save(std::string filename) 00220 { 00221 std::ofstream f(filename.c_str()); 00222 int i,j; 00223 00224 f << M << " " << d << std::endl ; 00225 for (i=0;i<w.length();i++) { 00226 f << w(i) << std::endl ; 00227 } 00228 for (i=0;i<M;i++) { 00229 f << m(i*d) ; 00230 for (j=1;j<d;j++) { 00231 f << " " << m(i*d+j) ; 00232 } 00233 f << std::endl ; 00234 } 00235 for (i=0;i<M;i++) { 00236 f << sigma(i*d) ; 00237 for (j=1;j<d;j++) { 00238 f << " " << sigma(i*d+j) ; 00239 } 00240 f << std::endl ; 00241 } 00242 } 00243 00244 void GMM::load(std::string filename) 00245 { 00246 std::ifstream GMMFile(filename.c_str()); 00247 long i,j; 00248 00249 it_error_if(!GMMFile,std::string("GMM::load : cannot open file ")+filename); 00250 00251 GMMFile >> M >> d ; 00252 00253 00254 w.set_length(M); 00255 for (i=0;i<M;i++) { 00256 GMMFile >> w(i) ; 00257 } 00258 m.set_length(M*d); 00259 for (i=0;i<M;i++) { 00260 for (j=0;j<d;j++) { 00261 GMMFile >> m(i*d+j) ; 00262 } 00263 } 00264 sigma.set_length(M*d); 00265 for (i=0;i<M;i++) { 00266 for (j=0;j<d;j++) { 00267 GMMFile >> sigma(i*d+j) ; 00268 } 00269 } 00270 compute_internals(); 00271 std::cout << " mixtures:" << M << " dim:" << d << std::endl ; 00272 } 00273 00274 double GMM::likelihood(const vec &x) 00275 { 00276 double fx=0; 00277 int i; 00278 00279 for (i=0;i<M;i++) { 00280 fx+=w(i)*likelihood_aposteriori(x, i); 00281 } 00282 return fx; 00283 } 00284 00285 vec GMM::likelihood_aposteriori(const vec &x) 00286 { 00287 vec v(M); 00288 int i; 00289 00290 for (i=0;i<M;i++) { 00291 v(i)=w(i)*likelihood_aposteriori(x, i); 00292 } 00293 return v; 00294 } 00295 00296 double GMM::likelihood_aposteriori(const vec &x, int mixture) 00297 { 00298 int j; 00299 double s; 00300 00301 it_error_if(d!=x.length(),"GMM::likelihood_aposteriori : dimensions does not match"); 00302 s=0; 00303 for (j=0;j<d;j++) { 00304 s+=normexp(mixture*d+j)*sqr(x(j)-m(mixture*d+j)); 00305 } 00306 return normweight(mixture)*std::exp(s);; 00307 } 00308 00309 void GMM::compute_internals() 00310 { 00311 int i,j; 00312 double s; 00313 double constant=1.0/std::pow(2*pi,d/2.0); 00314 00315 normweight.set_length(M); 00316 normexp.set_length(M*d); 00317 00318 for (i=0;i<M;i++) { 00319 s=1; 00320 for (j=0;j<d;j++) { 00321 normexp(i*d+j)=-0.5/sigma(i*d+j); // check time 00322 s*=sigma(i*d+j); 00323 } 00324 normweight(i) = constant/std::sqrt(s); 00325 } 00326 00327 } 00328 00329 vec GMM::draw_sample() 00330 { 00331 static bool first=true; 00332 static vec cumweight; 00333 double u=randu(); 00334 int k; 00335 00336 if (first) { 00337 first=false; 00338 cumweight=cumsum(w); 00339 it_error_if(std::abs(cumweight(length(cumweight)-1)-1)>1e-6,"weight does not sum to 0"); 00340 cumweight(length(cumweight)-1)=1; 00341 } 00342 k=0; 00343 while (u>cumweight(k)) k++; 00344 00345 return elem_mult(sqrt(sigma.mid(k*d,d)),randn(d))+m.mid(k*d,d); 00346 } 00347 00348 GMM gmmtrain(Array<vec> &TrainingData, int M, int NOITER, bool VERBOSE) 00349 { 00350 mat mean; 00351 int i,j,d=TrainingData(0).length(); 00352 vec sig; 00353 GMM gmm(M,d); 00354 vec m(d*M); 00355 vec sigma(d*M); 00356 vec w(M); 00357 vec normweight(M); 00358 vec normexp(d*M); 00359 double LL=0,LLold,fx; 00360 double constant=1.0/std::pow(2*pi,d/2.0); 00361 int T=TrainingData.length(); 00362 vec x1; 00363 int t,n; 00364 vec msum(d*M); 00365 vec sigmasum(d*M); 00366 vec wsum(M); 00367 vec p_aposteriori(M); 00368 vec x2; 00369 double s; 00370 vec temp1,temp2; 00371 //double MINIMUM_VARIANCE=0.03; 00372 00373 //-----------initialization----------------------------------- 00374 00375 mean=vqtrain(TrainingData,M,200000,0.5,VERBOSE); 00376 for (i=0;i<M;i++) gmm.set_mean(i,mean.get_col(i),false); 00377 // for (i=0;i<M;i++) gmm.set_mean(i,TrainingData(randi(0,TrainingData.length()-1)),false); 00378 sig=zeros(d); 00379 for (i=0;i<TrainingData.length();i++) sig+=sqr(TrainingData(i)); 00380 sig/=TrainingData.length(); 00381 for (i=0;i<M;i++) gmm.set_covariance(i,0.5*sig,false); 00382 00383 gmm.set_weight(1.0/M*ones(M)); 00384 00385 //-----------optimization----------------------------------- 00386 00387 tic(); 00388 for (i=0;i<M;i++) { 00389 temp1=gmm.get_mean(i); 00390 temp2=gmm.get_covariance(i); 00391 for (j=0;j<d;j++) { 00392 m(i*d+j)=temp1(j); 00393 sigma(i*d+j)=temp2(j); 00394 } 00395 w(i)=gmm.get_weight(i); 00396 } 00397 for (n=0;n<NOITER;n++) { 00398 for (i=0;i<M;i++) { 00399 s=1; 00400 for (j=0;j<d;j++) { 00401 normexp(i*d+j)=-0.5/sigma(i*d+j); // check time 00402 s*=sigma(i*d+j); 00403 } 00404 normweight(i) = constant*w(i)/std::sqrt(s); 00405 } 00406 LLold=LL; 00407 wsum.clear(); 00408 msum.clear(); 00409 sigmasum.clear(); 00410 LL=0; 00411 for (t=0;t<T;t++) { 00412 x1=TrainingData(t); 00413 x2=sqr(x1); 00414 fx=0; 00415 for (i=0;i<M;i++) { 00416 s=0; 00417 for (j=0;j<d;j++) { 00418 s+=normexp(i*d+j)*sqr(x1(j)-m(i*d+j)); 00419 } 00420 p_aposteriori(i)=normweight(i)*std::exp(s); 00421 fx+=p_aposteriori(i); 00422 } 00423 p_aposteriori/=fx; 00424 LL=LL+std::log(fx); 00425 00426 for (i=0;i<M;i++) { 00427 wsum(i)+=p_aposteriori(i); 00428 for (j=0;j<d;j++) { 00429 msum(i*d+j)+=p_aposteriori(i)*x1(j); 00430 sigmasum(i*d+j)+=p_aposteriori(i)*x2(j); 00431 } 00432 } 00433 } 00434 for (i=0;i<M;i++) { 00435 for (j=0;j<d;j++) { 00436 m(i*d+j)=msum(i*d+j)/wsum(i); 00437 sigma(i*d+j)=sigmasum(i*d+j)/wsum(i)-sqr(m(i*d+j)); 00438 } 00439 w(i)=wsum(i)/T; 00440 } 00441 LL=LL/T; 00442 00443 if (std::abs((LL-LLold)/LL) < 1e-6) break; 00444 if (VERBOSE) { 00445 std::cout << n << ": " << LL << " " << std::abs((LL-LLold)/LL) << " " << toc() << std::endl ; 00446 std::cout << "---------------------------------------" << std::endl ; 00447 tic(); 00448 } else { 00449 std::cout << n << ": LL = " << LL << " " << std::abs((LL-LLold)/LL) << "\r" ;std::cout.flush(); 00450 } 00451 } 00452 for (i=0;i<M;i++) { 00453 gmm.set_mean(i,m.mid(i*d,d),false); 00454 gmm.set_covariance(i,sigma.mid(i*d,d),false); 00455 } 00456 gmm.set_weight(w); 00457 return gmm; 00458 } 00459 00460 } // namespace itpp
Generated on Wed Apr 18 11:45:35 2007 for IT++ by Doxygen 1.5.2