IT++ Logo

mog_diag_em.cpp

Go to the documentation of this file.
00001 
00031 #include <itpp/stat/mog_diag_em.h>
00032 #include <itpp/base/math/log_exp.h>
00033 #include <itpp/base/timing.h>
00034 
00035 #include <iostream>
00036 #include <iomanip>
00037 
00038 namespace itpp {
00039 
00041   void inline MOG_diag_EM_sup::update_internals() {
00042 
00043     double Ddiv2_log_2pi = D/2.0 * std::log(m_2pi);
00044 
00045     for(int k=0;k<K;k++)  c_log_weights[k] = std::log(c_weights[k]);
00046 
00047     for(int k=0;k<K;k++) {
00048       double acc = 0.0;
00049       double * c_diag_cov = c_diag_covs[k];
00050       double * c_diag_cov_inv_etc = c_diag_covs_inv_etc[k];
00051 
00052       for(int d=0;d<D;d++) {
00053         double tmp = c_diag_cov[d];
00054         c_diag_cov_inv_etc[d] = 1.0/(2.0*tmp);
00055         acc += std::log(tmp);
00056         }
00057 
00058       c_log_det_etc[k] = -Ddiv2_log_2pi - 0.5*acc;
00059     }
00060 
00061   }
00062 
00063 
00065   void inline MOG_diag_EM_sup::sanitise_params() {
00066 
00067     double acc = 0.0;
00068     for(int k=0;k<K;k++) {
00069       if(c_weights[k] < weight_floor)  c_weights[k] = weight_floor;
00070       if(c_weights[k] > 1.0)  c_weights[k] = 1.0;
00071       acc += c_weights[k];
00072       }
00073     for(int k=0;k<K;k++)  c_weights[k] /= acc;
00074 
00075     for(int k=0;k<K;k++)
00076       for(int d=0;d<D;d++)
00077         if(c_diag_covs[k][d] < var_floor)  c_diag_covs[k][d] = var_floor;
00078 
00079   }
00080 
00082   double MOG_diag_EM_sup::ml_update_params() {
00083 
00084     double acc_loglhood = 0.0;
00085 
00086     for(int k=0;k<K;k++)  {
00087       c_acc_loglhood_K[k] = 0.0;
00088 
00089       double * c_acc_mean = c_acc_means[k];
00090       double * c_acc_cov  = c_acc_covs[k];
00091 
00092       for(int d=0;d<D;d++) { c_acc_mean[d] = 0.0; c_acc_cov[d] = 0.0; }
00093     }
00094 
00095     for(int n=0;n<N;n++) {
00096       double * c_x =  c_X[n];
00097 
00098       bool danger = paranoid;
00099       for(int k=0;k<K;k++)  {
00100         double tmp = c_log_weights[k] + MOG_diag::log_lhood_single_gaus_internal(c_x, k);
00101         c_tmpvecK[k] = tmp;
00102         if(tmp >= log_max_K)  danger = true;
00103       }
00104 
00105       if(danger) {
00106 
00107         double log_sum = c_tmpvecK[0];  for(int k=1;k<K;k++)  log_sum = log_add( log_sum, c_tmpvecK[k] );
00108         acc_loglhood += log_sum;
00109 
00110         for(int k=0;k<K;k++) {
00111 
00112           double * c_acc_mean = c_acc_means[k];
00113           double * c_acc_cov = c_acc_covs[k];
00114 
00115           double tmp_k = trunc_exp(c_tmpvecK[k] - log_sum);
00116           acc_loglhood_K[k] += tmp_k;
00117 
00118           for(int d=0;d<D;d++) {
00119             double tmp_x = c_x[d];
00120             c_acc_mean[d] +=  tmp_k * tmp_x;
00121             c_acc_cov[d] += tmp_k * tmp_x*tmp_x;
00122           }
00123         }
00124       }
00125       else {
00126 
00127         double sum = 0.0; for(int k=0;k<K;k++) { double tmp = std::exp(c_tmpvecK[k]); c_tmpvecK[k] = tmp; sum += tmp; }
00128         acc_loglhood += std::log(sum);
00129 
00130         for(int k=0;k<K;k++) {
00131 
00132           double * c_acc_mean = c_acc_means[k];
00133           double * c_acc_cov = c_acc_covs[k];
00134 
00135           double tmp_k = c_tmpvecK[k] / sum;
00136           c_acc_loglhood_K[k] += tmp_k;
00137 
00138           for(int d=0;d<D;d++) {
00139             double tmp_x = c_x[d];
00140             c_acc_mean[d] +=  tmp_k * tmp_x;
00141             c_acc_cov[d] += tmp_k * tmp_x*tmp_x;
00142           }
00143         }
00144       }
00145     }
00146 
00147     for(int k=0;k<K;k++) {
00148 
00149       double * c_mean = c_means[k];
00150       double * c_diag_cov = c_diag_covs[k];
00151 
00152       double * c_acc_mean = c_acc_means[k];
00153       double * c_acc_cov = c_acc_covs[k];
00154 
00155       double tmp_k = c_acc_loglhood_K[k];
00156 
00157       c_weights[k] = tmp_k / N;
00158 
00159       for(int d=0;d<D;d++) {
00160         double tmp_mean = c_acc_mean[d] / tmp_k;
00161         c_mean[d] = tmp_mean;
00162         c_diag_cov[d] = c_acc_cov[d] / tmp_k - tmp_mean*tmp_mean;
00163       }
00164     }
00165 
00166     return(acc_loglhood/N);
00167 
00168   }
00169 
00170 
00171   void MOG_diag_EM_sup::ml_iterate() {
00172     using std::cout;
00173     using std::endl;
00174     using std::setw;
00175     using std::showpos;
00176     using std::noshowpos;
00177     using std::scientific;
00178     using std::fixed;
00179     using std::flush;
00180     using std::setprecision;
00181 
00182     double avg_log_lhood_old = -1.0*std::numeric_limits<double>::max();
00183 
00184     Real_Timer tt;
00185 
00186     if(verbose) {
00187       cout << "MOG_diag_EM_sup::ml_iterate()" << endl;
00188       cout << setw(14) << "iteration";
00189       cout << setw(14) << "avg_loglhood";
00190       cout << setw(14) << "delta";
00191       cout << setw(10) << "toc";
00192       cout << endl;
00193     }
00194 
00195     for(int i=0; i<max_iter; i++) {
00196       sanitise_params();
00197       update_internals();
00198 
00199       if(verbose) tt.tic();
00200       double avg_log_lhood_new = ml_update_params();
00201 
00202       if(verbose) {
00203         double delta = avg_log_lhood_new - avg_log_lhood_old;
00204 
00205         cout << noshowpos << fixed;
00206         cout << setw(14) << i;
00207         cout << showpos << scientific << setprecision(3);
00208         cout << setw(14) << avg_log_lhood_new;
00209         cout << setw(14) << delta;
00210         cout << noshowpos << fixed;
00211         cout << setw(10) << tt.toc();
00212         cout << endl << flush;
00213       }
00214 
00215       if(avg_log_lhood_new <= avg_log_lhood_old)  break;
00216 
00217       avg_log_lhood_old = avg_log_lhood_new;
00218     }
00219   }
00220 
00221 
00222   void MOG_diag_EM_sup::ml(MOG_diag &model_in, Array<vec> &X_in, int max_iter_in, double var_floor_in, double weight_floor_in, bool verbose_in) {
00223 
00224     it_assert(model_in.is_valid(), "MOG_diag_EM_sup::ml(): initial model not valid" );
00225     it_assert(check_array_uniformity(X_in), "MOG_diag_EM_sup::ml(): 'X' is empty or contains vectors of varying dimensionality" );
00226     it_assert( (max_iter_in > 0), "MOG_diag_EM_sup::ml(): 'max_iter' needs to be greater than zero" );
00227 
00228     verbose = verbose_in;
00229 
00230     N = X_in.size();
00231 
00232     Array<vec> means_in = model_in.get_means();
00233     Array<vec> diag_covs_in = model_in.get_diag_covs();
00234     vec weights_in = model_in.get_weights();
00235 
00236     init(means_in, diag_covs_in, weights_in);
00237 
00238     means_in.set_size(0); diag_covs_in.set_size(0); weights_in.set_size(0);
00239 
00240     if(K > N)    it_warning("MOG_diag_EM_sup::ml(): WARNING: K > N");
00241     else
00242     if(K > N/10) it_warning("MOG_diag_EM_sup::ml(): WARNING: K > N/10");
00243 
00244     var_floor = var_floor_in;
00245     weight_floor = weight_floor_in;
00246 
00247     const double tiny = std::numeric_limits<double>::min();
00248     if(var_floor < tiny) var_floor = tiny;
00249     if(weight_floor < tiny) weight_floor = tiny;
00250     if(weight_floor > 1.0/K ) weight_floor = 1.0/K;
00251 
00252     max_iter = max_iter_in;
00253 
00254     tmpvecK.set_size(K);
00255     tmpvecD.set_size(D);
00256     acc_loglhood_K.set_size(K);
00257 
00258     acc_means.set_size(K); for(int k=0;k<K;k++) acc_means(k).set_size(D);
00259     acc_covs.set_size(K);  for(int k=0;k<K;k++) acc_covs(k).set_size(D);
00260 
00261     c_X = enable_c_access(X_in);
00262     c_tmpvecK = enable_c_access(tmpvecK);
00263     c_tmpvecD = enable_c_access(tmpvecD);
00264     c_acc_loglhood_K = enable_c_access(acc_loglhood_K);
00265     c_acc_means = enable_c_access(acc_means);
00266     c_acc_covs = enable_c_access(acc_covs);
00267 
00268     ml_iterate();
00269 
00270     model_in.init(means, diag_covs, weights);
00271 
00272     disable_c_access(c_X);
00273     disable_c_access(c_tmpvecK);
00274     disable_c_access(c_tmpvecD);
00275     disable_c_access(c_acc_loglhood_K);
00276     disable_c_access(c_acc_means);
00277     disable_c_access(c_acc_covs);
00278 
00279 
00280     tmpvecK.set_size(0);
00281     tmpvecD.set_size(0);
00282     acc_loglhood_K.set_size(0);
00283     acc_means.set_size(0);
00284     acc_covs.set_size(0);
00285 
00286     cleanup();
00287 
00288   }
00289 
00290   void MOG_diag_EM_sup::map(MOG_diag &model_in, MOG_diag &prior_model_in, Array<vec> &X_in, int max_iter_in, double alpha_in, double var_floor_in, double weight_floor_in, bool verbose_in) {
00291     it_assert(false, "MOG_diag_EM_sup::map(): not implemented yet");
00292   }
00293 
00294 
00295   //
00296   // convenience functions
00297 
00298   void MOG_diag_ML(MOG_diag &model_in, Array<vec> &X_in, int max_iter_in, double var_floor_in, double weight_floor_in, bool verbose_in) {
00299     MOG_diag_EM_sup EM;
00300     EM.ml(model_in, X_in, max_iter_in, var_floor_in, weight_floor_in, verbose_in);
00301   }
00302 
00303   void MOG_diag_MAP(MOG_diag &model_in, MOG_diag &prior_model_in, Array<vec> &X_in, int max_iter_in, double alpha_in, double var_floor_in, double weight_floor_in, bool verbose_in) {
00304     it_assert(false, "MOG_diag_MAP(): not implemented yet");
00305   }
00306 
00307 }
00308 
SourceForge Logo

Generated on Sat Apr 19 10:59:23 2008 for IT++ by Doxygen 1.5.5