00001 00031 #include <itpp/stat/mog_diag_em.h> 00032 #include <itpp/base/math/log_exp.h> 00033 #include <itpp/base/timing.h> 00034 00035 #include <iostream> 00036 #include <iomanip> 00037 00038 namespace itpp 00039 { 00040 00042 void inline MOG_diag_EM_sup::update_internals() 00043 { 00044 00045 double Ddiv2_log_2pi = D / 2.0 * std::log(m_2pi); 00046 00047 for (int k = 0;k < K;k++) c_log_weights[k] = std::log(c_weights[k]); 00048 00049 for (int k = 0;k < K;k++) { 00050 double acc = 0.0; 00051 double * c_diag_cov = c_diag_covs[k]; 00052 double * c_diag_cov_inv_etc = c_diag_covs_inv_etc[k]; 00053 00054 for (int d = 0;d < D;d++) { 00055 double tmp = c_diag_cov[d]; 00056 c_diag_cov_inv_etc[d] = 1.0 / (2.0 * tmp); 00057 acc += std::log(tmp); 00058 } 00059 00060 c_log_det_etc[k] = -Ddiv2_log_2pi - 0.5 * acc; 00061 } 00062 00063 } 00064 00065 00067 void inline MOG_diag_EM_sup::sanitise_params() 00068 { 00069 00070 double acc = 0.0; 00071 for (int k = 0;k < K;k++) { 00072 if (c_weights[k] < weight_floor) c_weights[k] = weight_floor; 00073 if (c_weights[k] > 1.0) c_weights[k] = 1.0; 00074 acc += c_weights[k]; 00075 } 00076 for (int k = 0;k < K;k++) c_weights[k] /= acc; 00077 00078 for (int k = 0;k < K;k++) 00079 for (int d = 0;d < D;d++) 00080 if (c_diag_covs[k][d] < var_floor) c_diag_covs[k][d] = var_floor; 00081 00082 } 00083 00085 double MOG_diag_EM_sup::ml_update_params() 00086 { 00087 00088 double acc_loglhood = 0.0; 00089 00090 for (int k = 0;k < K;k++) { 00091 c_acc_loglhood_K[k] = 0.0; 00092 00093 double * c_acc_mean = c_acc_means[k]; 00094 double * c_acc_cov = c_acc_covs[k]; 00095 00096 for (int d = 0;d < D;d++) { c_acc_mean[d] = 0.0; c_acc_cov[d] = 0.0; } 00097 } 00098 00099 for (int n = 0;n < N;n++) { 00100 double * c_x = c_X[n]; 00101 00102 bool danger = paranoid; 00103 for (int k = 0;k < K;k++) { 00104 double tmp = c_log_weights[k] + MOG_diag::log_lhood_single_gaus_internal(c_x, k); 00105 c_tmpvecK[k] = tmp; 00106 if (tmp >= log_max_K) danger = true; 00107 } 
00108 00109 if (danger) { 00110 00111 double log_sum = c_tmpvecK[0]; 00112 for (int k = 1;k < K;k++) log_sum = log_add(log_sum, c_tmpvecK[k]); 00113 acc_loglhood += log_sum; 00114 00115 for (int k = 0;k < K;k++) { 00116 00117 double * c_acc_mean = c_acc_means[k]; 00118 double * c_acc_cov = c_acc_covs[k]; 00119 00120 double tmp_k = trunc_exp(c_tmpvecK[k] - log_sum); 00121 acc_loglhood_K[k] += tmp_k; 00122 00123 for (int d = 0;d < D;d++) { 00124 double tmp_x = c_x[d]; 00125 c_acc_mean[d] += tmp_k * tmp_x; 00126 c_acc_cov[d] += tmp_k * tmp_x * tmp_x; 00127 } 00128 } 00129 } 00130 else { 00131 00132 double sum = 0.0; 00133 for (int k = 0;k < K;k++) { double tmp = std::exp(c_tmpvecK[k]); c_tmpvecK[k] = tmp; sum += tmp; } 00134 acc_loglhood += std::log(sum); 00135 00136 for (int k = 0;k < K;k++) { 00137 00138 double * c_acc_mean = c_acc_means[k]; 00139 double * c_acc_cov = c_acc_covs[k]; 00140 00141 double tmp_k = c_tmpvecK[k] / sum; 00142 c_acc_loglhood_K[k] += tmp_k; 00143 00144 for (int d = 0;d < D;d++) { 00145 double tmp_x = c_x[d]; 00146 c_acc_mean[d] += tmp_k * tmp_x; 00147 c_acc_cov[d] += tmp_k * tmp_x * tmp_x; 00148 } 00149 } 00150 } 00151 } 00152 00153 for (int k = 0;k < K;k++) { 00154 00155 double * c_mean = c_means[k]; 00156 double * c_diag_cov = c_diag_covs[k]; 00157 00158 double * c_acc_mean = c_acc_means[k]; 00159 double * c_acc_cov = c_acc_covs[k]; 00160 00161 double tmp_k = c_acc_loglhood_K[k]; 00162 00163 c_weights[k] = tmp_k / N; 00164 00165 for (int d = 0;d < D;d++) { 00166 double tmp_mean = c_acc_mean[d] / tmp_k; 00167 c_mean[d] = tmp_mean; 00168 c_diag_cov[d] = c_acc_cov[d] / tmp_k - tmp_mean * tmp_mean; 00169 } 00170 } 00171 00172 return(acc_loglhood / N); 00173 00174 } 00175 00176 00177 void MOG_diag_EM_sup::ml_iterate() 00178 { 00179 using std::cout; 00180 using std::endl; 00181 using std::setw; 00182 using std::showpos; 00183 using std::noshowpos; 00184 using std::scientific; 00185 using std::fixed; 00186 using std::flush; 00187 using 
std::setprecision; 00188 00189 double avg_log_lhood_old = -1.0 * std::numeric_limits<double>::max(); 00190 00191 Real_Timer tt; 00192 00193 if (verbose) { 00194 cout << "MOG_diag_EM_sup::ml_iterate()" << endl; 00195 cout << setw(14) << "iteration"; 00196 cout << setw(14) << "avg_loglhood"; 00197 cout << setw(14) << "delta"; 00198 cout << setw(10) << "toc"; 00199 cout << endl; 00200 } 00201 00202 for (int i = 0; i < max_iter; i++) { 00203 sanitise_params(); 00204 update_internals(); 00205 00206 if (verbose) tt.tic(); 00207 double avg_log_lhood_new = ml_update_params(); 00208 00209 if (verbose) { 00210 double delta = avg_log_lhood_new - avg_log_lhood_old; 00211 00212 cout << noshowpos << fixed; 00213 cout << setw(14) << i; 00214 cout << showpos << scientific << setprecision(3); 00215 cout << setw(14) << avg_log_lhood_new; 00216 cout << setw(14) << delta; 00217 cout << noshowpos << fixed; 00218 cout << setw(10) << tt.toc(); 00219 cout << endl << flush; 00220 } 00221 00222 if (avg_log_lhood_new <= avg_log_lhood_old) break; 00223 00224 avg_log_lhood_old = avg_log_lhood_new; 00225 } 00226 } 00227 00228 00229 void MOG_diag_EM_sup::ml(MOG_diag &model_in, Array<vec> &X_in, int max_iter_in, double var_floor_in, double weight_floor_in, bool verbose_in) 00230 { 00231 00232 it_assert(model_in.is_valid(), "MOG_diag_EM_sup::ml(): initial model not valid"); 00233 it_assert(check_array_uniformity(X_in), "MOG_diag_EM_sup::ml(): 'X' is empty or contains vectors of varying dimensionality"); 00234 it_assert((max_iter_in > 0), "MOG_diag_EM_sup::ml(): 'max_iter' needs to be greater than zero"); 00235 00236 verbose = verbose_in; 00237 00238 N = X_in.size(); 00239 00240 Array<vec> means_in = model_in.get_means(); 00241 Array<vec> diag_covs_in = model_in.get_diag_covs(); 00242 vec weights_in = model_in.get_weights(); 00243 00244 init(means_in, diag_covs_in, weights_in); 00245 00246 means_in.set_size(0); 00247 diag_covs_in.set_size(0); 00248 weights_in.set_size(0); 00249 00250 if (K > N) { 
00251 it_warning("MOG_diag_EM_sup::ml(): WARNING: K > N"); 00252 } 00253 else { 00254 if (K > N / 10) { 00255 it_warning("MOG_diag_EM_sup::ml(): WARNING: K > N/10"); 00256 } 00257 } 00258 00259 var_floor = var_floor_in; 00260 weight_floor = weight_floor_in; 00261 00262 const double tiny = std::numeric_limits<double>::min(); 00263 if (var_floor < tiny) var_floor = tiny; 00264 if (weight_floor < tiny) weight_floor = tiny; 00265 if (weight_floor > 1.0 / K) weight_floor = 1.0 / K; 00266 00267 max_iter = max_iter_in; 00268 00269 tmpvecK.set_size(K); 00270 tmpvecD.set_size(D); 00271 acc_loglhood_K.set_size(K); 00272 00273 acc_means.set_size(K); 00274 for (int k = 0;k < K;k++) acc_means(k).set_size(D); 00275 acc_covs.set_size(K); 00276 for (int k = 0;k < K;k++) acc_covs(k).set_size(D); 00277 00278 c_X = enable_c_access(X_in); 00279 c_tmpvecK = enable_c_access(tmpvecK); 00280 c_tmpvecD = enable_c_access(tmpvecD); 00281 c_acc_loglhood_K = enable_c_access(acc_loglhood_K); 00282 c_acc_means = enable_c_access(acc_means); 00283 c_acc_covs = enable_c_access(acc_covs); 00284 00285 ml_iterate(); 00286 00287 model_in.init(means, diag_covs, weights); 00288 00289 disable_c_access(c_X); 00290 disable_c_access(c_tmpvecK); 00291 disable_c_access(c_tmpvecD); 00292 disable_c_access(c_acc_loglhood_K); 00293 disable_c_access(c_acc_means); 00294 disable_c_access(c_acc_covs); 00295 00296 00297 tmpvecK.set_size(0); 00298 tmpvecD.set_size(0); 00299 acc_loglhood_K.set_size(0); 00300 acc_means.set_size(0); 00301 acc_covs.set_size(0); 00302 00303 cleanup(); 00304 00305 } 00306 00307 void MOG_diag_EM_sup::map(MOG_diag &, MOG_diag &, Array<vec> &, int, double, 00308 double, double, bool) 00309 { 00310 it_error("MOG_diag_EM_sup::map(): not implemented yet"); 00311 } 00312 00313 00314 // 00315 // convenience functions 00316 00317 void MOG_diag_ML(MOG_diag &model_in, Array<vec> &X_in, int max_iter_in, double var_floor_in, double weight_floor_in, bool verbose_in) 00318 { 00319 MOG_diag_EM_sup EM; 00320 
EM.ml(model_in, X_in, max_iter_in, var_floor_in, weight_floor_in, verbose_in); 00321 } 00322 00323 void MOG_diag_MAP(MOG_diag &, MOG_diag &, Array<vec> &, int, double, double, 00324 double, bool) 00325 { 00326 it_error("MOG_diag_MAP(): not implemented yet"); 00327 } 00328 00329 } 00330
// Generated on Thu Apr 23 20:06:48 2009 for IT++ by Doxygen 1.5.8