00001 00030 #ifndef MOG_GENERIC_H 00031 #define MOG_GENERIC_H 00032 00033 #include <itpp/base/vec.h> 00034 #include <itpp/base/mat.h> 00035 #include <itpp/base/array.h> 00036 00037 00038 namespace itpp 00039 { 00040 00057 class MOG_generic 00058 { 00059 00060 public: 00061 00067 MOG_generic() { init(); } 00068 00072 MOG_generic(const std::string &name_in) { load(name_in); } 00073 00079 MOG_generic(const int &K_in, const int &D_in, bool full_in = false) { init(K_in, D_in, full_in); } 00080 00088 MOG_generic(Array<vec> &means_in, bool full_in = false) { init(means_in, full_in); } 00089 00096 MOG_generic(Array<vec> &means_in, Array<vec> &diag_covs_in, vec &weights_in) { init(means_in, diag_covs_in, weights_in); } 00097 00104 MOG_generic(Array<vec> &means_in, Array<mat> &full_covs_in, vec &weights_in) { init(means_in, full_covs_in, weights_in); } 00105 00107 virtual ~MOG_generic() { cleanup(); } 00108 00113 void init(); 00114 00120 void init(const int &K_in, const int &D_in, bool full_in = false); 00121 00129 void init(Array<vec> &means_in, bool full_in = false); 00130 00137 void init(Array<vec> &means_in, Array<vec> &diag_covs_in, vec &weights_in); 00138 00145 void init(Array<vec> &means_in, Array<mat> &full_covs_in, vec &weights_in); 00146 00151 virtual void cleanup(); 00152 00154 bool is_valid() const { return valid; } 00155 00157 bool is_full() const { return full; } 00158 00160 int get_K() const { if (valid) return(K); else return(0); } 00161 00163 int get_D() const { if (valid) return(D); else return(0); } 00164 00166 vec get_weights() const { vec tmp; if (valid) { tmp = weights; } return tmp; } 00167 00169 Array<vec> get_means() const { Array<vec> tmp; if (valid) { tmp = means; } return tmp; } 00170 00172 Array<vec> get_diag_covs() const { Array<vec> tmp; if (valid && !full) { tmp = diag_covs; } return tmp; } 00173 00175 Array<mat> get_full_covs() const { Array<mat> tmp; if (valid && full) { tmp = full_covs; } return tmp; } 00176 00180 void set_means(Array<vec> &means_in); 00181 00185 void set_diag_covs(Array<vec> &diag_covs_in); 00186 00190 void set_full_covs(Array<mat> &full_covs_in); 00191 00195 void set_weights(vec &weights_in); 00196 00198 void set_means_zero(); 00199 00201 void set_diag_covs_unity(); 00202 00204 void set_full_covs_unity(); 00205 00207 void set_weights_uniform(); 00208 00214 void set_checks(bool do_checks_in) { do_checks = do_checks_in; } 00215 00219 void set_paranoid(bool paranoid_in) { paranoid = paranoid_in; } 00220 00224 virtual void load(const std::string &name_in); 00225 00229 virtual void save(const std::string &name_in) const; 00230 00247 virtual void join(const MOG_generic &B_in); 00248 00256 virtual void convert_to_diag(); 00257 00263 virtual void convert_to_full(); 00264 00266 virtual double log_lhood_single_gaus(const vec &x_in, const int k); 00267 00269 virtual double log_lhood(const vec &x_in); 00270 00272 virtual double lhood(const vec &x_in); 00273 00275 virtual double avg_log_lhood(const Array<vec> &X_in); 00276 00277 protected: 00278 00280 bool do_checks; 00281 00283 bool valid; 00284 00286 bool full; 00287 00289 bool paranoid; 00290 00292 int K; 00293 00295 int D; 00296 00298 Array<vec> means; 00299 00301 Array<vec> diag_covs; 00302 00304 Array<mat> full_covs; 00305 00307 vec weights; 00308 00310 double log_max_K; 00311 00317 vec log_det_etc; 00318 00320 vec log_weights; 00321 00323 Array<mat> full_covs_inv; 00324 00326 Array<vec> diag_covs_inv_etc; 00327 00329 bool check_size(const vec &x_in) const; 00330 00332 bool check_size(const Array<vec> &X_in) const; 00333 00335 bool check_array_uniformity(const Array<vec> & A) const; 00336 00338 void set_means_internal(Array<vec> &means_in); 00340 void set_diag_covs_internal(Array<vec> &diag_covs_in); 00342 void set_full_covs_internal(Array<mat> &full_covs_in); 00344 void set_weights_internal(vec &_weigths); 00345 00347 void set_means_zero_internal(); 00349 void set_diag_covs_unity_internal(); 00351 void set_full_covs_unity_internal(); 00353 void set_weights_uniform_internal(); 00354 00356 void convert_to_diag_internal(); 00358 void convert_to_full_internal(); 00359 00361 virtual void setup_means(); 00362 00364 virtual void setup_covs(); 00365 00367 virtual void setup_weights(); 00368 00370 virtual void setup_misc(); 00371 00373 virtual double log_lhood_single_gaus_internal(const vec &x_in, const int k); 00375 virtual double log_lhood_internal(const vec &x_in); 00377 virtual double lhood_internal(const vec &x_in); 00378 00379 private: 00380 vec tmpvecD; 00381 vec tmpvecK; 00382 00383 }; 00384 00385 } // namespace itpp 00386 00387 #endif // #ifndef MOG_GENERIC_H
Generated on Thu Apr 23 20:04:05 2009 for IT++ by Doxygen 1.5.8