gmm.cpp

#include <itpp/srccode/gmm.h>
#include <itpp/srccode/vqtrain.h>
#include <itpp/base/math/elem_math.h>
#include <itpp/base/matfunc.h>
#include <itpp/base/specmat.h>
#include <itpp/base/random.h>
#include <itpp/base/timing.h>
#include <iostream>
#include <fstream>

namespace itpp {

  GMM::GMM()
  {
    d=0;
    M=0;
  }

  GMM::GMM(std::string filename)
  {
    load(filename);
  }

  GMM::GMM(int M_in, int d_in)
  {
    M=M_in;
    d=d_in;
    m=zeros(M*d);
    sigma=zeros(M*d);
    w=1.0/M*ones(M);

    compute_internals();
  }

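  // Illustrative usage sketch (not part of the original IT++ source): means and
  // diagonal covariances are stored as flat vectors of length M*d, with mixture
  // i occupying elements [i*d, i*d+d). The helper name below is hypothetical.
  inline GMM make_two_component_example()
  {
    GMM gmm(2, 3);                                     // M=2 mixtures in d=3 dimensions
    gmm.set_mean(0, vec("0 0 0"), false);              // defer the internal cache update
    gmm.set_mean(1, vec("1 2 3"), false);
    gmm.set_covariance(0, vec("1 1 1"), false);
    gmm.set_covariance(1, vec("0.5 0.5 0.5"), true);   // last call recomputes the cache
    return gmm;
  }
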
  void GMM::init_from_vq(const vec &codebook, int dim)
  {
    // Initialize the mixture from a VQ codebook: the codebook vectors become
    // the means, the weights are uniform, and every mixture gets the same
    // diagonal covariance taken from the second moment of the codebook.
    mat         C(dim,dim);
    int         i;
    vec         v;

    d=dim;
    M=codebook.length()/dim;

    m=codebook;
    w=ones(M)/double(M);

    C.clear();
    for (i=0;i<M;i++) {
      v=codebook.mid(i*d,d);
      C=C+outer_product(v,v);
    }
    C=1./M*C;
    sigma.set_length(M*d);
    for (i=0;i<M;i++) {
      sigma.replace_mid(i*d,diag(C));
    }

    compute_internals();
  }

  void GMM::init(const vec &w_in, const mat &m_in, const mat &sigma_in)
  {
    // Each column of m_in and sigma_in holds one mixture's mean and diagonal
    // covariance; they are flattened into the internal M*d vectors.
    int         i,j;
    d=m_in.rows();
    M=m_in.cols();

    m.set_length(M*d);
    sigma.set_length(M*d);
    for (i=0;i<M;i++) {
      for (j=0;j<d;j++) {
        m(i*d+j)=m_in(j,i);
        sigma(i*d+j)=sigma_in(j,i);
      }
    }
    w=w_in;

    compute_internals();
  }

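  // Illustrative sketch (not part of the original file): filling the
  // column-per-mixture matrices expected by init(). The helper name is
  // hypothetical.
  inline GMM init_example()
  {
    int d = 2, M = 2;
    mat means(d, M), vars(d, M);
    means.set_col(0, vec("0 0"));
    means.set_col(1, vec("3 3"));
    vars.set_col(0, vec("1 1"));
    vars.set_col(1, vec("2 2"));

    GMM gmm;
    gmm.init(vec("0.5 0.5"), means, vars);  // weights, means, diagonal covariances
    return gmm;
  }
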
  void GMM::set_mean(const mat &m_in)
  {
    int         i,j;

    d=m_in.rows();
    M=m_in.cols();

    m.set_length(M*d);
    for (i=0;i<M;i++) {
      for (j=0;j<d;j++) {
        m(i*d+j)=m_in(j,i);
      }
    }
    compute_internals();
  }

  void GMM::set_mean(int i, const vec &means, bool compflag)
  {
    // compflag=false defers the update of the cached normalization terms,
    // which is useful when several mixtures are modified in a row.
    m.replace_mid(i*length(means),means);
    if (compflag) compute_internals();
  }

  void GMM::set_covariance(const mat &sigma_in)
  {
    int         i,j;

    d=sigma_in.rows();
    M=sigma_in.cols();

    sigma.set_length(M*d);
    for (i=0;i<M;i++) {
      for (j=0;j<d;j++) {
        sigma(i*d+j)=sigma_in(j,i);
      }
    }
    compute_internals();
  }

  void GMM::set_covariance(int i, const vec &covariances, bool compflag)
  {
    sigma.replace_mid(i*length(covariances),covariances);
    if (compflag) compute_internals();
  }

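  // Illustrative sketch (not part of the original file): when several mixtures
  // are updated in a row, the per-mixture setters can be called with
  // compflag=false and the cached terms recomputed only once, similar to how
  // gmmtrain() below defers the update while filling in all mixtures. The
  // helper name and the rescaling operation are hypothetical.
  inline void rescale_covariances_example(GMM &gmm, int M, double factor)
  {
    for (int i = 0; i < M; i++) {
      bool last = (i == M - 1);  // recompute the cache only on the last write
      gmm.set_covariance(i, factor * gmm.get_covariance(i), last);
    }
  }
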
  void GMM::marginalize(int d_new)
  {
    // Keep only the first d_new dimensions; with diagonal covariances the
    // marginal is obtained by dropping the trailing mean/variance entries.
    it_error_if(d_new>d,"GMM.marginalize: cannot change to a larger dimension");

    vec         mnew(d_new*M),sigmanew(d_new*M);
    int         i,j;

    for (i=0;i<M;i++) {
      for (j=0;j<d_new;j++) {
        mnew(i*d_new+j)=m(i*d+j);
        sigmanew(i*d_new+j)=sigma(i*d+j);
      }
    }
    m=mnew;
    sigma=sigmanew;
    d=d_new;

    compute_internals();
  }

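  // Illustrative sketch (not part of the original file): a model of dimension 3
  // is assumed; after marginalize(2) the model describes only the first two
  // coordinates, so the sample is truncated accordingly. The helper name is
  // hypothetical.
  inline void marginalize_example(GMM &gmm)
  {
    vec x3 = gmm.draw_sample();   // sample in the full (assumed 3-dimensional) space
    gmm.marginalize(2);           // keep only the first two coordinates
    std::cout << "marginal density: " << gmm.likelihood(x3.left(2)) << std::endl;
  }
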
  void GMM::join(const GMM &newgmm)
  {
    if (d==0) {
      w=newgmm.w;
      m=newgmm.m;
      sigma=newgmm.sigma;
      d=newgmm.d;
      M=newgmm.M;
    } else {
      it_error_if(d!=newgmm.d,"GMM.join: cannot join GMMs of different dimension");

      w=concat(double(M)/(M+newgmm.M)*w,double(newgmm.M)/(M+newgmm.M)*newgmm.w);
      w=w/sum(w);
      m=concat(m,newgmm.m);
      sigma=concat(sigma,newgmm.sigma);

      M=M+newgmm.M;
    }
    compute_internals();
  }

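  // Illustrative sketch (not part of the original file): joining two models of
  // the same dimension pools their mixtures and reweights them in proportion to
  // their mixture counts. The helper name is hypothetical.
  inline GMM join_example(const GMM &a, const GMM &b)
  {
    GMM pooled;
    pooled.join(a);   // joining into an empty model simply copies it
    pooled.join(b);   // the second join concatenates and renormalizes the weights
    return pooled;
  }
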
  void GMM::clear()
  {
    w.set_length(0);
    m.set_length(0);
    sigma.set_length(0);
    d=0;
    M=0;
  }

  void GMM::save(std::string filename)
  {
    // Plain-text format: "M d" on the first line, followed by the M weights
    // (one per line), then the M mean vectors and the M diagonal covariance
    // vectors (one mixture per line).
    std::ofstream       f(filename.c_str());
    int                 i,j;

    f << M << " " << d << std::endl ;
    for (i=0;i<w.length();i++) {
      f << w(i) << std::endl ;
    }
    for (i=0;i<M;i++) {
      f << m(i*d) ;
      for (j=1;j<d;j++) {
        f << " " << m(i*d+j) ;
      }
      f << std::endl ;
    }
    for (i=0;i<M;i++) {
      f << sigma(i*d) ;
      for (j=1;j<d;j++) {
        f << " " << sigma(i*d+j) ;
      }
      f << std::endl ;
    }
  }

  void GMM::load(std::string filename)
  {
    std::ifstream       GMMFile(filename.c_str());
    int                 i,j;

    it_error_if(!GMMFile,std::string("GMM::load : cannot open file ")+filename);

    GMMFile >> M >> d ;

    w.set_length(M);
    for (i=0;i<M;i++) {
      GMMFile >> w(i) ;
    }
    m.set_length(M*d);
    for (i=0;i<M;i++) {
      for (j=0;j<d;j++) {
        GMMFile >> m(i*d+j) ;
      }
    }
    sigma.set_length(M*d);
    for (i=0;i<M;i++) {
      for (j=0;j<d;j++) {
        GMMFile >> sigma(i*d+j) ;
      }
    }
    compute_internals();
    std::cout << "  mixtures:" << M << "  dim:" << d << std::endl ;
  }

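  // Illustrative sketch (not part of the original file): round-tripping a model
  // through the plain-text format written by save() and read by load(). The
  // file name and helper name are arbitrary.
  inline void save_load_example(GMM &gmm)
  {
    gmm.save("example.gmm");       // writes "M d", then weights, means, variances
    GMM copy;
    copy.load("example.gmm");      // prints "  mixtures:M  dim:d" when done
    std::cout << copy.likelihood(gmm.draw_sample()) << std::endl;
  }
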
  double GMM::likelihood(const vec &x)
  {
    // Mixture density: the weighted sum of the component densities.
    double      fx=0;
    int         i;

    for (i=0;i<M;i++) {
      fx+=w(i)*likelihood_aposteriori(x, i);
    }
    return fx;
  }

  vec GMM::likelihood_aposteriori(const vec &x)
  {
    // Returns w(i) times the i-th component density for every mixture; the
    // result is not normalized to sum to one.
    vec         v(M);
    int         i;

    for (i=0;i<M;i++) {
      v(i)=w(i)*likelihood_aposteriori(x, i);
    }
    return v;
  }

  double GMM::likelihood_aposteriori(const vec &x, int mixture)
  {
    // Density of a single mixture component (the mixture weight is not included).
    int         j;
    double      s;

    it_error_if(d!=x.length(),"GMM::likelihood_aposteriori : dimensions do not match");
    s=0;
    for (j=0;j<d;j++) {
      s+=normexp(mixture*d+j)*sqr(x(j)-m(mixture*d+j));
    }
    return normweight(mixture)*std::exp(s);
  }

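  // Illustrative check (not part of the original file): with a diagonal
  // covariance, likelihood_aposteriori(x, i) above evaluates the Gaussian
  // density
  //   N(x; m_i, diag(sigma_i))
  //     = (2*pi)^(-d/2) * prod_j sigma_ij^(-1/2)
  //       * exp(-0.5 * sum_j (x_j - m_ij)^2 / sigma_ij).
  // The helper name is hypothetical.
  inline double diag_gauss_density(const vec &x, const vec &mean, const vec &var)
  {
    double expo=0, det=1;
    for (int j=0;j<x.length();j++) {
      expo += -0.5*sqr(x(j)-mean(j))/var(j);
      det  *= var(j);
    }
    return std::exp(expo)/(std::pow(2*pi,x.length()/2.0)*std::sqrt(det));
  }
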
  void GMM::compute_internals()
  {
    // Cache the per-dimension exponent factors -0.5/sigma and the mixture
    // normalization constants (2*pi)^(-d/2)/sqrt(prod(sigma)).
    int         i,j;
    double      s;
    double      constant=1.0/std::pow(2*pi,d/2.0);

    normweight.set_length(M);
    normexp.set_length(M*d);

    for (i=0;i<M;i++) {
      s=1;
      for (j=0;j<d;j++) {
        normexp(i*d+j)=-0.5/sigma(i*d+j);
        s*=sigma(i*d+j);
      }
      normweight(i) = constant/std::sqrt(s);
    }
  }

  vec GMM::draw_sample()
  {
    // Pick a mixture according to the weights, then draw from its diagonal
    // Gaussian. The cumulative weights are rebuilt on every call so that the
    // sampler stays correct after the weights change (the original code cached
    // them in a static variable shared by all GMM instances).
    vec         cumweight=cumsum(w);
    double      u=randu();
    int         k;

    it_error_if(std::abs(cumweight(length(cumweight)-1)-1)>1e-6,"GMM::draw_sample: weights do not sum to 1");
    cumweight(length(cumweight)-1)=1;

    k=0;
    while (u>cumweight(k)) k++;

    return elem_mult(sqrt(sigma.mid(k*d,d)),randn(d))+m.mid(k*d,d);
  }

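  // Illustrative sketch (not part of the original file): estimating the model
  // mean by Monte Carlo sampling; with enough samples the estimate approaches
  // sum_i w(i)*m_i. The helper name is hypothetical.
  inline vec sample_mean_example(GMM &gmm, int dim, int N)
  {
    vec acc = zeros(dim);
    for (int n = 0; n < N; n++)
      acc += gmm.draw_sample();
    return acc / N;
  }
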
  GMM gmmtrain(Array<vec> &TrainingData, int M, int NOITER, bool VERBOSE)
  {
    mat                 mean;
    int                 i,j,d=TrainingData(0).length();
    vec                 sig;
    GMM                 gmm(M,d);
    vec                 m(d*M);
    vec                 sigma(d*M);
    vec                 w(M);
    vec                 normweight(M);
    vec                 normexp(d*M);
    double              LL=0,LLold,fx;
    double              constant=1.0/std::pow(2*pi,d/2.0);
    int                 T=TrainingData.length();
    vec                 x1;
    int                 t,n;
    vec                 msum(d*M);
    vec                 sigmasum(d*M);
    vec                 wsum(M);
    vec                 p_aposteriori(M);
    vec                 x2;
    double              s;
    vec                 temp1,temp2;
    //double            MINIMUM_VARIANCE=0.03;

    //-----------initialization-----------------------------------

    // Means from a VQ codebook, uniform weights, and a common diagonal
    // covariance set to half of the data's per-dimension second moment.
    mean=vqtrain(TrainingData,M,200000,0.5,VERBOSE);
    for (i=0;i<M;i++) gmm.set_mean(i,mean.get_col(i),false);
    //  for (i=0;i<M;i++) gmm.set_mean(i,TrainingData(randi(0,TrainingData.length()-1)),false);
    sig=zeros(d);
    for (i=0;i<TrainingData.length();i++) sig+=sqr(TrainingData(i));
    sig/=TrainingData.length();
    for (i=0;i<M;i++) gmm.set_covariance(i,0.5*sig,false);

    gmm.set_weight(1.0/M*ones(M));

    //-----------optimization-----------------------------------

    tic();
    for (i=0;i<M;i++) {
      temp1=gmm.get_mean(i);
      temp2=gmm.get_covariance(i);
      for (j=0;j<d;j++) {
        m(i*d+j)=temp1(j);
        sigma(i*d+j)=temp2(j);
      }
      w(i)=gmm.get_weight(i);
    }
    for (n=0;n<NOITER;n++) {
      // E-step constants for the current parameters.
      for (i=0;i<M;i++) {
        s=1;
        for (j=0;j<d;j++) {
          normexp(i*d+j)=-0.5/sigma(i*d+j);
          s*=sigma(i*d+j);
        }
        normweight(i) = constant*w(i)/std::sqrt(s);
      }
      LLold=LL;
      wsum.clear();
      msum.clear();
      sigmasum.clear();
      LL=0;
      for (t=0;t<T;t++) {
        x1=TrainingData(t);
        x2=sqr(x1);
        fx=0;
        for (i=0;i<M;i++) {
          s=0;
          for (j=0;j<d;j++) {
            s+=normexp(i*d+j)*sqr(x1(j)-m(i*d+j));
          }
          p_aposteriori(i)=normweight(i)*std::exp(s);
          fx+=p_aposteriori(i);
        }
        p_aposteriori/=fx;
        LL=LL+std::log(fx);

        for (i=0;i<M;i++) {
          wsum(i)+=p_aposteriori(i);
          for (j=0;j<d;j++) {
            msum(i*d+j)+=p_aposteriori(i)*x1(j);
            sigmasum(i*d+j)+=p_aposteriori(i)*x2(j);
          }
        }
      }
      // M-step: update means, variances and weights from the accumulated sums.
      for (i=0;i<M;i++) {
        for (j=0;j<d;j++) {
          m(i*d+j)=msum(i*d+j)/wsum(i);
          sigma(i*d+j)=sigmasum(i*d+j)/wsum(i)-sqr(m(i*d+j));
        }
        w(i)=wsum(i)/T;
      }
      LL=LL/T;

      if (std::abs((LL-LLold)/LL) < 1e-6) break;
      if (VERBOSE) {
        std::cout << n << ":   " << LL << "   " << std::abs((LL-LLold)/LL) << "   " << toc() << std::endl ;
        std::cout << "---------------------------------------" << std::endl ;
        tic();
      } else {
        std::cout << n << ": LL =  " << LL << "   " << std::abs((LL-LLold)/LL) << "\r" ;
        std::cout.flush();
      }
    }
    for (i=0;i<M;i++) {
      gmm.set_mean(i,m.mid(i*d,d),false);
      gmm.set_covariance(i,sigma.mid(i*d,d),false);
    }
    gmm.set_weight(w);
    return gmm;
  }

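  // Illustrative sketch (not part of the original file): training a model with
  // gmmtrain() from an Array of observation vectors. The helper name, mixture
  // count, iteration limit and file name are made up.
  inline GMM train_example(Array<vec> &observations)
  {
    // 8 mixtures, at most 100 EM iterations, verbose progress output.
    GMM gmm = gmmtrain(observations, 8, 100, true);
    gmm.save("trained.gmm");
    return gmm;
  }
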
} // namespace itpp