mlpack  2.0.1
hmm.hpp
Go to the documentation of this file.
1 
16 #ifndef __MLPACK_METHODS_HMM_HMM_HPP
17 #define __MLPACK_METHODS_HMM_HMM_HPP
18 
19 #include <mlpack/core.hpp>
20 
21 namespace mlpack {
22 namespace hmm {
23 
85 template<typename Distribution = distribution::DiscreteDistribution>
86 class HMM
87 {
88  public:
106  HMM(const size_t states = 0,
107  const Distribution emissions = Distribution(),
108  const double tolerance = 1e-5);
109 
137  HMM(const arma::vec& initial,
138  const arma::mat& transition,
139  const std::vector<Distribution>& emission,
140  const double tolerance = 1e-5);
141 
170  void Train(const std::vector<arma::mat>& dataSeq);
171 
193  void Train(const std::vector<arma::mat>& dataSeq,
194  const std::vector<arma::Row<size_t> >& stateSeq);
195 
214  double Estimate(const arma::mat& dataSeq,
215  arma::mat& stateProb,
216  arma::mat& forwardProb,
217  arma::mat& backwardProb,
218  arma::vec& scales) const;
219 
231  double Estimate(const arma::mat& dataSeq,
232  arma::mat& stateProb) const;
233 
245  void Generate(const size_t length,
246  arma::mat& dataSequence,
247  arma::Row<size_t>& stateSequence,
248  const size_t startState = 0) const;
249 
260  double Predict(const arma::mat& dataSeq,
261  arma::Row<size_t>& stateSeq) const;
262 
269  double LogLikelihood(const arma::mat& dataSeq) const;
270 
283  void Filter(const arma::mat& dataSeq,
284  arma::mat& filterSeq,
285  size_t ahead = 0) const;
286 
298  void Smooth(const arma::mat& dataSeq,
299  arma::mat& smoothSeq) const;
300 
302  const arma::vec& Initial() const { return initial; }
304  arma::vec& Initial() { return initial; }
305 
307  const arma::mat& Transition() const { return transition; }
309  arma::mat& Transition() { return transition; }
310 
312  const std::vector<Distribution>& Emission() const { return emission; }
314  std::vector<Distribution>& Emission() { return emission; }
315 
317  size_t Dimensionality() const { return dimensionality; }
319  size_t& Dimensionality() { return dimensionality; }
320 
322  double Tolerance() const { return tolerance; }
324  double& Tolerance() { return tolerance; }
325 
329  template<typename Archive>
330  void Serialize(Archive& ar, const unsigned int version);
331 
332  protected:
333  // Helper functions.
344  void Forward(const arma::mat& dataSeq,
345  arma::vec& scales,
346  arma::mat& forwardProb) const;
347 
359  void Backward(const arma::mat& dataSeq,
360  const arma::vec& scales,
361  arma::mat& backwardProb) const;
362 
364  std::vector<Distribution> emission;
365 
367  arma::mat transition;
368 
369  private:
371  arma::vec initial;
372 
375 
377  double tolerance;
378 };
379 
380 } // namespace hmm
381 } // namespace mlpack
382 
383 // Include implementation.
384 #include "hmm_impl.hpp"
385 
386 #endif
arma::vec & Initial()
Modify the vector of initial state probabilities.
Definition: hmm.hpp:304
size_t Dimensionality() const
Get the dimensionality of observations.
Definition: hmm.hpp:317
std::vector< Distribution > emission
Set of emission probability distributions; one for each state.
Definition: hmm.hpp:364
void Smooth(const arma::mat &dataSeq, arma::mat &smoothSeq) const
HMM smoothing.
Linear algebra utility functions, generally performed on matrices or vectors.
const arma::vec & Initial() const
Return the vector of initial state probabilities.
Definition: hmm.hpp:302
size_t & Dimensionality()
Set the dimensionality of observations.
Definition: hmm.hpp:319
const arma::mat & Transition() const
Return the transition matrix.
Definition: hmm.hpp:307
std::vector< Distribution > & Emission()
Return a modifiable reference to the emission distributions.
Definition: hmm.hpp:314
void Forward(const arma::mat &dataSeq, arma::vec &scales, arma::mat &forwardProb) const
The Forward algorithm (part of the Forward-Backward algorithm).
double tolerance
Tolerance of Baum-Welch algorithm.
Definition: hmm.hpp:377
arma::vec initial
Initial state probability vector.
Definition: hmm.hpp:371
double & Tolerance()
Modify the tolerance of the Baum-Welch algorithm.
Definition: hmm.hpp:324
double LogLikelihood(const arma::mat &dataSeq) const
Compute the log-likelihood of the given data sequence.
double Tolerance() const
Get the tolerance of the Baum-Welch algorithm.
Definition: hmm.hpp:322
Include all of the base components required to write MLPACK methods, and the main MLPACK Doxygen documentation.
A class that represents a Hidden Markov Model with an arbitrary type of emission distribution.
Definition: hmm.hpp:86
const std::vector< Distribution > & Emission() const
Return the emission distributions.
Definition: hmm.hpp:312
void Generate(const size_t length, arma::mat &dataSequence, arma::Row< size_t > &stateSequence, const size_t startState=0) const
Generate a random data sequence of the given length.
double Predict(const arma::mat &dataSeq, arma::Row< size_t > &stateSeq) const
Compute the most probable hidden state sequence for the given data sequence, using the Viterbi algorithm.
double Estimate(const arma::mat &dataSeq, arma::mat &stateProb, arma::mat &forwardProb, arma::mat &backwardProb, arma::vec &scales) const
Estimate the probabilities of each hidden state at each time step for each given data observation, using the Forward-Backward algorithm.
void Backward(const arma::mat &dataSeq, const arma::vec &scales, arma::mat &backwardProb) const
The Backward algorithm (part of the Forward-Backward algorithm).
void Train(const std::vector< arma::mat > &dataSeq)
Train the model using the Baum-Welch algorithm, with only the given unlabeled observations.
size_t dimensionality
Dimensionality of observations.
Definition: hmm.hpp:374
arma::mat & Transition()
Return a modifiable transition matrix reference.
Definition: hmm.hpp:309
void Serialize(Archive &ar, const unsigned int version)
Serialize the object.
void Filter(const arma::mat &dataSeq, arma::mat &filterSeq, size_t ahead=0) const
HMM filtering.
HMM(const size_t states=0, const Distribution emissions=Distribution(), const double tolerance=1e-5)
Create the Hidden Markov Model with the given number of hidden states and the given default distribution for emissions.
arma::mat transition
Transition probability matrix.
Definition: hmm.hpp:367