IT++ Logo

mog_generic.h

Go to the documentation of this file.
00001 
00030 #ifndef MOG_GENERIC_H
00031 #define MOG_GENERIC_H
00032 
00033 #include <itpp/base/vec.h>
00034 #include <itpp/base/mat.h>
00035 #include <itpp/base/array.h>
00036 
00037 
00038 namespace itpp
00039 {
00040 
00057 class MOG_generic
00058 {
00059 
00060 public:
00061 
00067   MOG_generic() { init(); }
00068 
00072   MOG_generic(const std::string &name_in) { load(name_in); }
00073 
00079   MOG_generic(const int &K_in, const int &D_in, bool full_in = false) { init(K_in, D_in, full_in); }
00080 
00088   MOG_generic(Array<vec> &means_in, bool full_in = false) { init(means_in, full_in); }
00089 
00096   MOG_generic(Array<vec> &means_in, Array<vec> &diag_covs_in, vec &weights_in) { init(means_in, diag_covs_in, weights_in); }
00097 
00104   MOG_generic(Array<vec> &means_in, Array<mat> &full_covs_in, vec &weights_in) { init(means_in, full_covs_in, weights_in); }
00105 
00107   virtual ~MOG_generic() { cleanup(); }
00108 
00113   void init();
00114 
00120   void init(const int &K_in, const int &D_in, bool full_in = false);
00121 
00129   void init(Array<vec> &means_in, bool full_in = false);
00130 
00137   void init(Array<vec> &means_in, Array<vec> &diag_covs_in, vec &weights_in);
00138 
00145   void init(Array<vec> &means_in, Array<mat> &full_covs_in, vec &weights_in);
00146 
00151   virtual void cleanup();
00152 
00154   bool is_valid() const { return valid; }
00155 
00157   bool is_full() const { return full; }
00158 
00160   int get_K() const { if (valid) return(K); else return(0); }
00161 
00163   int get_D() const { if (valid) return(D); else return(0); }
00164 
00166   vec get_weights() const { vec tmp;  if (valid) { tmp = weights; } return tmp; }
00167 
00169   Array<vec> get_means() const { Array<vec> tmp; if (valid) { tmp = means; } return tmp; }
00170 
00172   Array<vec> get_diag_covs() const { Array<vec> tmp; if (valid && !full) { tmp = diag_covs; } return tmp; }
00173 
00175   Array<mat> get_full_covs() const { Array<mat> tmp; if (valid && full) { tmp = full_covs; } return tmp; }
00176 
00180   void set_means(Array<vec> &means_in);
00181 
00185   void set_diag_covs(Array<vec> &diag_covs_in);
00186 
00190   void set_full_covs(Array<mat> &full_covs_in);
00191 
00195   void set_weights(vec &weights_in);
00196 
00198   void set_means_zero();
00199 
00201   void set_diag_covs_unity();
00202 
00204   void set_full_covs_unity();
00205 
00207   void set_weights_uniform();
00208 
00214   void set_checks(bool do_checks_in) { do_checks = do_checks_in; }
00215 
00219   void set_paranoid(bool paranoid_in) { paranoid = paranoid_in; }
00220 
00224   virtual void load(const std::string &name_in);
00225 
00229   virtual void save(const std::string &name_in) const;
00230 
00247   virtual void join(const MOG_generic &B_in);
00248 
00256   virtual void convert_to_diag();
00257 
00263   virtual void convert_to_full();
00264 
00266   virtual double log_lhood_single_gaus(const vec &x_in, const int k);
00267 
00269   virtual double log_lhood(const vec &x_in);
00270 
00272   virtual double lhood(const vec &x_in);
00273 
00275   virtual double avg_log_lhood(const Array<vec> &X_in);
00276 
00277 protected:
00278 
00280   bool do_checks;
00281 
00283   bool valid;
00284 
00286   bool full;
00287 
00289   bool paranoid;
00290 
00292   int K;
00293 
00295   int D;
00296 
00298   Array<vec> means;
00299 
00301   Array<vec> diag_covs;
00302 
00304   Array<mat> full_covs;
00305 
00307   vec weights;
00308 
00310   double log_max_K;
00311 
00317   vec log_det_etc;
00318 
00320   vec log_weights;
00321 
00323   Array<mat> full_covs_inv;
00324 
00326   Array<vec> diag_covs_inv_etc;
00327 
00329   bool check_size(const vec &x_in) const;
00330 
00332   bool check_size(const Array<vec> &X_in) const;
00333 
00335   bool check_array_uniformity(const Array<vec> & A) const;
00336 
00338   void set_means_internal(Array<vec> &means_in);
00340   void set_diag_covs_internal(Array<vec> &diag_covs_in);
00342   void set_full_covs_internal(Array<mat> &full_covs_in);
00344   void set_weights_internal(vec &_weigths);
00345 
00347   void set_means_zero_internal();
00349   void set_diag_covs_unity_internal();
00351   void set_full_covs_unity_internal();
00353   void set_weights_uniform_internal();
00354 
00356   void convert_to_diag_internal();
00358   void convert_to_full_internal();
00359 
00361   virtual void setup_means();
00362 
00364   virtual void setup_covs();
00365 
00367   virtual void setup_weights();
00368 
00370   virtual void setup_misc();
00371 
00373   virtual double log_lhood_single_gaus_internal(const vec &x_in, const int k);
00375   virtual double log_lhood_internal(const vec &x_in);
00377   virtual double lhood_internal(const vec &x_in);
00378 
00379 private:
00380   vec tmpvecD;
00381   vec tmpvecK;
00382 
00383 };
00384 
00385 } // namespace itpp
00386 
00387 #endif // #ifndef MOG_GENERIC_H
SourceForge Logo

Generated on Thu Apr 23 20:04:05 2009 for IT++ by Doxygen 1.5.8