IT++ Logo

mog_generic.h

Go to the documentation of this file.
00001 
00030 #ifndef MOG_GENERIC_H
00031 #define MOG_GENERIC_H
00032 
00033 #include <itpp/base/vec.h>
00034 #include <itpp/base/mat.h>
00035 #include <itpp/base/array.h>
00036 
00037 
00038 namespace itpp {
00039 
00056   class MOG_generic {
          //! Model class holding K entries of dimension D: per-entry mean
          //! vectors, covariances (either diagonal, stored as vectors, or full,
          //! stored as matrices) and a weight vector, plus cached quantities
          //! used by the likelihood functions.
          //!
          //! Judging by the member names (means, covariances, weights,
          //! log_lhood_single_gaus) this presumably models a mixture of
          //! Gaussians -- confirm against the library documentation.
          //!
          //! Most members are only declared here; their definitions are not
          //! part of this header. The destructor and many members are virtual,
          //! so the class is evidently meant to be derived from.
          //!
          //! NOTE(review): std::string is used in this class, but <string> is
          //! not included directly by this header; presumably it arrives
          //! transitively through the itpp headers -- confirm.
00057 
00058   public:
00059 
          //! Default constructor; forwards to the parameterless init().
00065     MOG_generic() { init(); }
00066 
          //! Construct by calling load() with \c name_in.
          //! load() is virtual, but a virtual call made from a constructor is
          //! resolved to MOG_generic::load(), never to an override in a derived
          //! class.
          //! NOTE(review): this constructor is not \c explicit, so a
          //! std::string converts implicitly to MOG_generic.
00070     MOG_generic(const std::string &name_in) { load(name_in); }
00071 
          //! Construct by forwarding to init(K_in, D_in, full_in).
          //! \c full_in defaults to false.
00077     MOG_generic(const int &K_in, const int &D_in, bool full_in=false) { init(K_in, D_in, full_in); }
00078 
          //! Construct by forwarding to init(means_in, full_in).
          //! \c full_in defaults to false. \c means_in is a non-const
          //! reference, so a temporary or const Array cannot be passed.
00086     MOG_generic(Array<vec> &means_in, bool full_in=false) { init(means_in, full_in); }
00087 
          //! Construct from means, diagonal covariances (one vec per entry)
          //! and weights by forwarding to the matching init() overload.
00094     MOG_generic(Array<vec> &means_in, Array<vec> &diag_covs_in, vec &weights_in) { init(means_in, diag_covs_in, weights_in); }
00095 
          //! Construct from means, full covariances (one mat per entry) and
          //! weights by forwarding to the matching init() overload.
00102     MOG_generic(Array<vec> &means_in, Array<mat> &full_covs_in, vec &weights_in) { init(means_in, full_covs_in, weights_in); }
00103 
          //! Virtual destructor; calls cleanup().
          //! cleanup() is virtual, but inside a destructor the call resolves to
          //! MOG_generic::cleanup(), not to a derived-class override.
00105     virtual ~MOG_generic() { cleanup(); }
00106 
          //! Parameterless initialisation (defined elsewhere). Used by the
          //! default constructor.
00111     void init();
00112 
          //! Initialise from the integers \c K_in and \c D_in and the flag
          //! \c full_in (default false). Defined elsewhere; presumably K_in
          //! and D_in become K and D and full_in selects full covariances --
          //! confirm against the definition.
00118     void init(const int &K_in, const int &D_in, bool full_in=false);
00119 
          //! Initialise from an array of mean vectors and the flag \c full_in
          //! (default false). Defined elsewhere.
00127     void init(Array<vec> &means_in, bool full_in=false);
00128 
          //! Initialise from means, diagonal covariances and weights.
          //! Defined elsewhere.
00135     void init(Array<vec> &means_in, Array<vec> &diag_covs_in, vec &weights_in);
00136 
          //! Initialise from means, full covariance matrices and weights.
          //! Defined elsewhere.
00143     void init(Array<vec> &means_in, Array<mat> &full_covs_in, vec &weights_in);
00144 
          //! Virtual clean-up hook (defined elsewhere). Called by the
          //! destructor.
00149     virtual void cleanup();
00150 
          //! Returns the \c valid flag.
00152     bool is_valid() const { return valid; }
00153 
          //! Returns the \c full flag.
00155     bool is_full() const { return full; }
00156 
          //! Returns K when the \c valid flag is set, otherwise 0.
00158     int get_K() const { if(valid) return(K); else return(0); }
00159 
          //! Returns D when the \c valid flag is set, otherwise 0.
00161     int get_D() const { if(valid) return(D); else return(0); }
00162 
          //! Returns a copy of \c weights when \c valid is set; otherwise a
          //! default-constructed vec.
00164     vec get_weights() const { vec tmp;  if(valid) { tmp = weights; } return tmp; }
00165 
          //! Returns a copy of \c means when \c valid is set; otherwise a
          //! default-constructed Array.
00167     Array<vec> get_means() const { Array<vec> tmp; if(valid) { tmp = means; } return tmp; }
00168 
          //! Returns a copy of \c diag_covs only when \c valid is set and
          //! \c full is not; otherwise a default-constructed Array.
00170     Array<vec> get_diag_covs() const { Array<vec> tmp; if(valid && !full) { tmp = diag_covs; } return tmp; }
00171 
          //! Returns a copy of \c full_covs only when both \c valid and
          //! \c full are set; otherwise a default-constructed Array.
00173     Array<mat> get_full_covs() const { Array<mat> tmp; if(valid && full) { tmp = full_covs; } return tmp; }
00174 
          //! Set the means from \c means_in (defined elsewhere). The argument
          //! is a non-const reference.
00178     void set_means(Array<vec> &means_in);
00179 
          //! Set the diagonal covariances from \c diag_covs_in (defined
          //! elsewhere).
00183     void set_diag_covs(Array<vec> &diag_covs_in);
00184 
          //! Set the full covariances from \c full_covs_in (defined
          //! elsewhere).
00188     void set_full_covs(Array<mat> &full_covs_in);
00189 
          //! Set the weights from \c weights_in (defined elsewhere).
00193     void set_weights(vec &weights_in);
00194 
          //! Defined elsewhere; by its name, presumably sets every mean to
          //! zero -- confirm.
00196     void set_means_zero();
00197 
          //! Defined elsewhere; by its name, presumably sets the diagonal
          //! covariances to unity -- confirm.
00199     void set_diag_covs_unity();
00200 
          //! Defined elsewhere; by its name, presumably sets the full
          //! covariances to unity (identity) -- confirm.
00202     void set_full_covs_unity();
00203 
          //! Defined elsewhere; by its name, presumably makes all weights
          //! equal -- confirm.
00205     void set_weights_uniform();
00206 
          //! Stores \c do_checks_in in the \c do_checks flag. How the flag is
          //! consumed is not visible in this header.
00212     void set_checks(bool do_checks_in) { do_checks = do_checks_in; }
00213 
          //! Stores \c paranoid_in in the \c paranoid flag. How the flag is
          //! consumed is not visible in this header.
00217     void set_paranoid(bool paranoid_in) { paranoid = paranoid_in; }
00218 
          //! Virtual; defined elsewhere. Takes a name (presumably a file
          //! name -- confirm) and is used by the std::string constructor.
00222     virtual void load(const std::string &name_in);
00223 
          //! Virtual and const; defined elsewhere. Presumably the counterpart
          //! of load() that writes the model under \c name_in -- confirm.
00227     virtual void save(const std::string &name_in) const;
00228 
          //! Virtual; defined elsewhere. Takes another model \c B_in by const
          //! reference; presumably merges it into this one -- confirm.
00245     virtual void join(const MOG_generic &B_in);
00246 
          //! Virtual; defined elsewhere. By its name, presumably switches the
          //! model to diagonal covariances -- confirm.
00254     virtual void convert_to_diag();
00255 
          //! Virtual; defined elsewhere. By its name, presumably switches the
          //! model to full covariances -- confirm.
00261     virtual void convert_to_full();
00262 
          //! Virtual; defined elsewhere. Takes a vector \c x_in and an index
          //! \c k; presumably the log-likelihood of x_in under entry k --
          //! confirm. Not const-qualified, so it cannot be called on a const
          //! MOG_generic.
00264     virtual double log_lhood_single_gaus(const vec &x_in, const int k);
00265 
          //! Virtual; defined elsewhere. Presumably the log-likelihood of
          //! \c x_in under the whole model -- confirm. Not const-qualified.
00267     virtual double log_lhood(const vec &x_in);
00268 
          //! Virtual; defined elsewhere. Presumably the likelihood of \c x_in
          //! under the whole model -- confirm. Not const-qualified.
00270     virtual double lhood(const vec &x_in);
00271 
          //! Virtual; defined elsewhere. Takes an array of vectors; presumably
          //! the average of log_lhood() over them -- confirm. Not
          //! const-qualified.
00273     virtual double avg_log_lhood(const Array<vec> &X_in);
00274 
00275   protected:
00276 
          //! Flag written by set_checks().
00278     bool do_checks;
00279 
          //! Validity flag: the public getters return real data only when it
          //! is set (see get_K(), get_weights(), ...).
00281     bool valid;
00282 
          //! Selects which covariance storage the getters expose: full_covs
          //! when set, diag_covs when clear.
00284     bool full;
00285 
          //! Flag written by set_paranoid().
00287     bool paranoid;
00288 
          //! Value reported by get_K(); presumably the number of model
          //! entries -- confirm.
00290     int K;
00291 
          //! Value reported by get_D(); presumably the vector
          //! dimensionality -- confirm.
00293     int D;
00294 
          //! Mean vectors, as returned by get_means().
00296     Array<vec> means;
00297 
          //! Diagonal covariances, as returned by get_diag_covs().
00299     Array<vec> diag_covs;
00300 
          //! Full covariance matrices, as returned by get_full_covs().
00302     Array<mat> full_covs;
00303 
          //! Weights, as returned by get_weights().
00305     vec weights;
00306 
          //! Cached scalar; its exact meaning is not visible in this header.
00308     double log_max_K;
00309 
          //! Cached vector; by its name, presumably log-determinant related
          //! terms for each entry -- confirm.
00315     vec log_det_etc;
00316 
          //! Cached vector; by its name, presumably the logarithms of
          //! \c weights -- confirm.
00318     vec log_weights;
00319 
          //! Cached matrices; by name, presumably inverses of \c full_covs --
          //! confirm.
00321     Array<mat> full_covs_inv;
00322 
          //! Cached vectors; by name, presumably derived from the inverses of
          //! \c diag_covs -- confirm.
00324     Array<vec> diag_covs_inv_etc;
00325 
          //! Const check on a single vector (defined elsewhere); presumably
          //! compares its size against D -- confirm.
00327     bool check_size(const vec &x_in) const;
00328 
          //! Const check on an array of vectors (defined elsewhere).
00330     bool check_size(const Array<vec> &X_in) const;
00331 
          //! Const check on an array of vectors (defined elsewhere); by name,
          //! presumably verifies all elements have the same size -- confirm.
00333     bool check_array_uniformity(const Array<vec> & A) const;
00334 
          //! Non-virtual internal counterparts of the public setters
          //! (defined elsewhere).
00336     void set_means_internal(Array<vec> &means_in);
00338     void set_diag_covs_internal(Array<vec> &diag_covs_in);
00340     void set_full_covs_internal(Array<mat> &full_covs_in);
          // NOTE(review): the parameter name below looks like a typo of
          // "weights" and differs from the *_in naming used by the siblings;
          // it is harmless in a declaration.
00342     void set_weights_internal(vec &_weigths);
00343 
          //! Non-virtual internal counterparts of the public
          //! set_*_zero/unity/uniform functions (defined elsewhere).
00345     void set_means_zero_internal();
00347     void set_diag_covs_unity_internal();
00349     void set_full_covs_unity_internal();
00351     void set_weights_uniform_internal();
00352 
          //! Non-virtual internal counterparts of convert_to_diag() and
          //! convert_to_full() (defined elsewhere).
00354     void convert_to_diag_internal();
00356     void convert_to_full_internal();
00357 
          //! Virtual set-up hooks (defined elsewhere) that derived classes can
          //! override; presumably they refresh the cached members above after
          //! the corresponding parameters change -- confirm.
00359     virtual void setup_means();
00360 
00362     virtual void setup_covs();
00363 
00365     virtual void setup_weights();
00366 
00368     virtual void setup_misc();
00369 
          //! Virtual internal counterparts of the public likelihood functions
          //! (defined elsewhere), overridable by derived classes.
00371     virtual double log_lhood_single_gaus_internal(const vec &x_in, const int k);
00373     virtual double log_lhood_internal(const vec &x_in);
00375     virtual double lhood_internal(const vec &x_in);
00376 
00377   private:
          //! Private vectors; by name, presumably scratch buffers of length D
          //! and K respectively -- confirm.
00378     vec tmpvecD;
00379     vec tmpvecK;
00380 
00381   };
00382 
00383 } // namespace itpp
00384 
00385 #endif // #ifndef MOG_GENERIC_H
SourceForge Logo

Generated on Sat Apr 19 10:42:06 2008 for IT++ by Doxygen 1.5.5