MLPACK  1.0.4
gmm.hpp
Go to the documentation of this file.
00001 
00023 #ifndef __MLPACK_METHODS_MOG_MOG_EM_HPP
00024 #define __MLPACK_METHODS_MOG_MOG_EM_HPP
00025 
00026 #include <mlpack/core.hpp>
00027 
00028 // This is the default fitting method class.
00029 #include "em_fit.hpp"
00030 
00031 namespace mlpack {
00032 namespace gmm  {
00033 
00088 template<typename FittingType = EMFit<> >
00089 class GMM
00090 {
00091  private:
00093   size_t gaussians;
00095   size_t dimensionality;
00097   std::vector<arma::vec> means;
00099   std::vector<arma::mat> covariances;
00101   arma::vec weights;
00102 
00103  public:
00107   GMM() :
00108       gaussians(0),
00109       dimensionality(0),
00110       localFitter(FittingType()),
00111       fitter(localFitter)
00112   {
00113     // Warn the user.  They probably don't want to do this.  If this constructor
00114     // is being used (because it is required by some template classes), the user
00115     // should know that it is potentially dangerous.
00116     Log::Debug << "GMM::GMM(): no parameters given; Estimate() may fail "
00117         << "unless parameters are set." << std::endl;
00118   }
00119 
00127   GMM(const size_t gaussians, const size_t dimensionality) :
00128       gaussians(gaussians),
00129       dimensionality(dimensionality),
00130       means(gaussians, arma::vec(dimensionality)),
00131       covariances(gaussians, arma::mat(dimensionality, dimensionality)),
00132       weights(gaussians),
00133       localFitter(FittingType()),
00134       fitter(localFitter) { /* Nothing to do. */ }
00135 
00146   GMM(const size_t gaussians,
00147       const size_t dimensionality,
00148       FittingType& fitter) :
00149       gaussians(gaussians),
00150       dimensionality(dimensionality),
00151       means(gaussians, arma::vec(dimensionality)),
00152       covariances(gaussians, arma::mat(dimensionality, dimensionality)),
00153       weights(gaussians),
00154       fitter(fitter) { /* Nothing to do. */ }
00155 
00163   GMM(const std::vector<arma::vec>& means,
00164       const std::vector<arma::mat>& covariances,
00165       const arma::vec& weights) :
00166       gaussians(means.size()),
00167       dimensionality((!means.empty()) ? means[0].n_elem : 0),
00168       means(means),
00169       covariances(covariances),
00170       weights(weights),
00171       localFitter(FittingType()),
00172       fitter(localFitter) { /* Nothing to do. */ }
00173 
00183   GMM(const std::vector<arma::vec>& means,
00184       const std::vector<arma::mat>& covariances,
00185       const arma::vec& weights,
00186       FittingType& fitter) :
00187       gaussians(means.size()),
00188       dimensionality((!means.empty()) ? means[0].n_elem : 0),
00189       means(means),
00190       covariances(covariances),
00191       weights(weights),
00192       fitter(fitter) { /* Nothing to do. */ }
00193 
00197   template<typename OtherFittingType>
00198   GMM(const GMM<OtherFittingType>& other);
00199 
00204   GMM(const GMM& other);
00205 
00209   template<typename OtherFittingType>
00210   GMM& operator=(const GMM<OtherFittingType>& other);
00211 
00216   GMM& operator=(const GMM& other);
00217 
00219   size_t Gaussians() const { return gaussians; }
00222   size_t& Gaussians() { return gaussians; }
00223 
00225   size_t Dimensionality() const { return dimensionality; }
00228   size_t& Dimensionality() { return dimensionality; }
00229 
00231   const std::vector<arma::vec>& Means() const { return means; }
00233   std::vector<arma::vec>& Means() { return means; }
00234 
00236   const std::vector<arma::mat>& Covariances() const { return covariances; }
00238   std::vector<arma::mat>& Covariances() { return covariances; }
00239 
00241   const arma::vec& Weights() const { return weights; }
00243   arma::vec& Weights() { return weights; }
00244 
00246   const FittingType& Fitter() const { return fitter; }
00248   FittingType& Fitter() { return fitter; }
00249 
00256   double Probability(const arma::vec& observation) const;
00257 
00265   double Probability(const arma::vec& observation,
00266                      const size_t component) const;
00267 
00274   arma::vec Random() const;
00275 
00291   double Estimate(const arma::mat& observations,
00292                   const size_t trials = 1);
00293 
00311   double Estimate(const arma::mat& observations,
00312                   const arma::vec& probabilities,
00313                   const size_t trials = 1);
00314 
00331   void Classify(const arma::mat& observations,
00332                 arma::Col<size_t>& labels) const;
00333 
00334  private:
00344   double LogLikelihood(const arma::mat& dataPoints,
00345                        const std::vector<arma::vec>& means,
00346                        const std::vector<arma::mat>& covars,
00347                        const arma::vec& weights) const;
00348 
00350   FittingType localFitter;
00351 
00353   FittingType& fitter;
00354 };
00355 
00356 }; // namespace gmm
00357 }; // namespace mlpack
00358 
00359 // Include implementation.
00360 #include "gmm_impl.hpp"
00361 
00362 #endif