00001 00030 #include <itpp/base/math/log_exp.h> 00031 #include <itpp/stat/mog_diag.h> 00032 #include <cstdlib> 00033 00034 00035 namespace itpp 00036 { 00037 00038 double MOG_diag::log_lhood_single_gaus_internal(const double * c_x_in, const int k) const 00039 { 00040 00041 const double * c_mean = c_means[k]; 00042 const double * c_diag_cov_inv_etc = c_diag_covs_inv_etc[k]; 00043 00044 double acc = 0.0; 00045 00046 for (int d = 0; d < D; d++) { 00047 double tmp_val = c_x_in[d] - c_mean[d]; 00048 acc += (tmp_val * tmp_val) * c_diag_cov_inv_etc[d]; 00049 } 00050 return(c_log_det_etc[k] - acc); 00051 } 00052 00053 00054 double MOG_diag::log_lhood_single_gaus_internal(const vec &x_in, const int k) const 00055 { 00056 return log_lhood_single_gaus_internal(x_in._data(), k); 00057 } 00058 00059 00060 double MOG_diag::log_lhood_single_gaus(const double * c_x_in, const int k) const 00061 { 00062 if (do_checks) { 00063 it_assert(valid, "MOG_diag::log_lhood_single_gaus(): model not valid"); 00064 it_assert(((k >= 0) && (k < K)), "MOG::log_lhood_single_gaus(): k specifies a non-existant Gaussian"); 00065 } 00066 return log_lhood_single_gaus_internal(c_x_in, k); 00067 } 00068 00069 00070 double MOG_diag::log_lhood_single_gaus(const vec &x_in, const int k) const 00071 { 00072 if (do_checks) { 00073 it_assert(valid, "MOG_diag::log_lhood_single_gaus(): model not valid"); 00074 it_assert(check_size(x_in), "MOG_diag::log_lhood_single_gaus(): x has wrong dimensionality"); 00075 it_assert(((k >= 0) && (k < K)), "MOG::log_lhood_single_gaus(): k specifies a non-existant Gaussian"); 00076 } 00077 return log_lhood_single_gaus_internal(x_in._data(), k); 00078 } 00079 00080 00081 double MOG_diag::log_lhood_internal(const double * c_x_in) 00082 { 00083 00084 bool danger = paranoid; 00085 00086 for (int k = 0;k < K;k++) { 00087 double tmp = c_log_weights[k] + log_lhood_single_gaus_internal(c_x_in, k); 00088 c_tmpvecK[k] = tmp; 00089 00090 if (tmp >= log_max_K) danger = true; 00091 } 00092 00093 00094 if (danger) { 00095 double log_sum = c_tmpvecK[0]; 00096 for (int k = 1; k < K; k++) log_sum = log_add(log_sum, c_tmpvecK[k]); 00097 return(log_sum); 00098 } 00099 else { 00100 double sum = 0.0; 00101 for (int k = 0;k < K;k++) sum += std::exp(c_tmpvecK[k]); 00102 return(std::log(sum)); 00103 } 00104 } 00105 00106 00107 double MOG_diag::log_lhood_internal(const vec &x_in) 00108 { 00109 return log_lhood_internal(x_in._data()); 00110 } 00111 00112 00113 double MOG_diag::log_lhood(const vec &x_in) 00114 { 00115 if (do_checks) { 00116 it_assert(valid, "MOG_diag::log_lhood(): model not valid"); 00117 it_assert(check_size(x_in), "MOG_diag::log_lhood(): x has wrong dimensionality"); 00118 } 00119 return log_lhood_internal(x_in._data()); 00120 } 00121 00122 00123 double MOG_diag::log_lhood(const double * c_x_in) 00124 { 00125 if (do_checks) { 00126 it_assert(valid, "MOG_diag::log_lhood(): model not valid"); 00127 it_assert((c_x_in != 0), "MOG_diag::log_lhood(): c_x_in is a null pointer"); 00128 } 00129 00130 return log_lhood_internal(c_x_in); 00131 } 00132 00133 00134 double MOG_diag::lhood_internal(const double * c_x_in) 00135 { 00136 00137 bool danger = paranoid; 00138 00139 for (int k = 0;k < K;k++) { 00140 double tmp = c_log_weights[k] + log_lhood_single_gaus_internal(c_x_in, k); 00141 c_tmpvecK[k] = tmp; 00142 00143 if (tmp >= log_max_K) danger = true; 00144 } 00145 00146 00147 if (danger) { 00148 double log_sum = c_tmpvecK[0]; 00149 for (int k = 1; k < K; k++) log_sum = log_add(log_sum, c_tmpvecK[k]); 00150 return(trunc_exp(log_sum)); 00151 } 00152 else { 00153 double sum = 0.0; 00154 for (int k = 0;k < K;k++) sum += std::exp(c_tmpvecK[k]); 00155 return(sum); 00156 } 00157 } 00158 00159 double MOG_diag::lhood_internal(const vec &x_in) { return lhood_internal(x_in._data()); } 00160 00161 double MOG_diag::lhood(const vec &x_in) 00162 { 00163 if (do_checks) { 00164 it_assert(valid, "MOG_diag::lhood(): model not valid"); 00165 it_assert(check_size(x_in), "MOG_diag::lhood(): x has wrong dimensionality"); 00166 } 00167 return lhood_internal(x_in._data()); 00168 } 00169 00170 00171 double MOG_diag::lhood(const double * c_x_in) 00172 { 00173 if (do_checks) { 00174 it_assert(valid, "MOG_diag::lhood(): model not valid"); 00175 it_assert((c_x_in != 0), "MOG_diag::lhood(): c_x_in is a null pointer"); 00176 } 00177 00178 return lhood_internal(c_x_in); 00179 } 00180 00181 00182 double MOG_diag::avg_log_lhood(const double ** c_x_in, const int N) 00183 { 00184 if (do_checks) { 00185 it_assert(valid, "MOG_diag::avg_log_lhood(): model not valid"); 00186 it_assert((c_x_in != 0), "MOG_diag::avg_log_lhood(): c_x_in is a null pointer"); 00187 it_assert((N >= 0), "MOG_diag::avg_log_lhood(): N is zero or negative"); 00188 } 00189 00190 double acc = 0.0; 00191 for (int n = 0;n < N;n++) acc += log_lhood_internal(c_x_in[n]); 00192 return(acc / N); 00193 } 00194 00195 00196 double MOG_diag::avg_log_lhood(const Array<vec> &X_in) 00197 { 00198 if (do_checks) { 00199 it_assert(valid, "MOG_diag::avg_log_lhood(): model not valid"); 00200 it_assert(check_size(X_in), "MOG_diag::avg_log_lhood(): X is empty or at least one vector has the wrong dimensionality"); 00201 } 00202 const int N = X_in.size(); 00203 double acc = 0.0; 00204 for (int n = 0;n < N;n++) acc += log_lhood_internal(X_in(n)._data()); 00205 return(acc / N); 00206 } 00207 00208 void MOG_diag::zero_all_ptrs() 00209 { 00210 c_means = 0; 00211 c_diag_covs = 0; 00212 c_diag_covs_inv_etc = 0; 00213 c_weights = 0; 00214 c_log_weights = 0; 00215 c_log_det_etc = 0; 00216 c_tmpvecK = 0; 00217 } 00218 00219 00220 void MOG_diag::free_all_ptrs() 00221 { 00222 c_means = disable_c_access(c_means); 00223 c_diag_covs = disable_c_access(c_diag_covs); 00224 c_diag_covs_inv_etc = disable_c_access(c_diag_covs_inv_etc); 00225 c_weights = disable_c_access(c_weights); 00226 c_log_weights = disable_c_access(c_log_weights); 00227 c_log_det_etc = disable_c_access(c_log_det_etc); 00228 c_tmpvecK = disable_c_access(c_tmpvecK); 00229 } 00230 00231 00232 void MOG_diag::setup_means() 00233 { 00234 MOG_generic::setup_means(); 00235 disable_c_access(c_means); 00236 c_means = enable_c_access(means); 00237 } 00238 00239 00240 void MOG_diag::setup_covs() 00241 { 00242 MOG_generic::setup_covs(); 00243 if (full) return; 00244 00245 disable_c_access(c_diag_covs); 00246 disable_c_access(c_diag_covs_inv_etc); 00247 disable_c_access(c_log_det_etc); 00248 00249 c_diag_covs = enable_c_access(diag_covs); 00250 c_diag_covs_inv_etc = enable_c_access(diag_covs_inv_etc); 00251 c_log_det_etc = enable_c_access(log_det_etc); 00252 } 00253 00254 00255 void MOG_diag::setup_weights() 00256 { 00257 MOG_generic::setup_weights(); 00258 00259 disable_c_access(c_weights); 00260 disable_c_access(c_log_weights); 00261 00262 c_weights = enable_c_access(weights); 00263 c_log_weights = enable_c_access(log_weights); 00264 } 00265 00266 00267 void MOG_diag::setup_misc() 00268 { 00269 disable_c_access(c_tmpvecK); 00270 tmpvecK.set_size(K); 00271 c_tmpvecK = enable_c_access(tmpvecK); 00272 00273 MOG_generic::setup_misc(); 00274 if (full) convert_to_diag_internal(); 00275 } 00276 00277 00278 void MOG_diag::load(const std::string &name_in) 00279 { 00280 MOG_generic::load(name_in); 00281 if (full) convert_to_diag(); 00282 } 00283 00284 00285 double ** MOG_diag::enable_c_access(Array<vec> & A_in) 00286 { 00287 int rows = A_in.size(); 00288 double ** A = (double **)std::malloc(rows * sizeof(double *)); 00289 if (A) for (int row = 0;row < rows;row++) A[row] = A_in(row)._data(); 00290 return(A); 00291 } 00292 00293 int ** MOG_diag::enable_c_access(Array<ivec> & A_in) 00294 { 00295 int rows = A_in.size(); 00296 int ** A = (int **)std::malloc(rows * sizeof(int *)); 00297 if (A) for (int row = 0;row < rows;row++) A[row] = A_in(row)._data(); 00298 return(A); 00299 } 00300 00301 double ** MOG_diag::disable_c_access(double ** A_in) { if (A_in) std::free(A_in); return(0); } 00302 int ** MOG_diag::disable_c_access(int ** A_in) { if (A_in) std::free(A_in); return(0); } 00303 00304 double * MOG_diag::enable_c_access(vec & v_in) { return v_in._data(); } 00305 int * MOG_diag::enable_c_access(ivec & v_in) { return v_in._data(); } 00306 00307 double * MOG_diag::disable_c_access(double *) { return(0); } 00308 int * MOG_diag::disable_c_access(int *) { return(0); } 00309 00310 }
Generated on Thu Apr 23 20:06:48 2009 for IT++ by Doxygen 1.5.8