// MLPACK 1.0.4
00001 00023 #ifndef __MLPACK_METHODS_GMM_EM_FIT_HPP 00024 #define __MLPACK_METHODS_GMM_EM_FIT_HPP 00025 00026 #include <mlpack/core.hpp> 00027 00028 // Default clustering mechanism. 00029 #include <mlpack/methods/kmeans/kmeans.hpp> 00030 00031 namespace mlpack { 00032 namespace gmm { 00033 00047 template<typename InitialClusteringType = kmeans::KMeans<> > 00048 class EMFit 00049 { 00050 public: 00055 EMFit(InitialClusteringType clusterer = InitialClusteringType()) : 00056 clusterer(clusterer) { /* Nothing to do. */ } 00057 00068 void Estimate(const arma::mat& observations, 00069 std::vector<arma::vec>& means, 00070 std::vector<arma::mat>& covariances, 00071 arma::vec& weights); 00072 00085 void Estimate(const arma::mat& observations, 00086 const arma::vec& probabilities, 00087 std::vector<arma::vec>& means, 00088 std::vector<arma::mat>& covariances, 00089 arma::vec& weights); 00090 00091 private: 00102 void InitialClustering(const arma::mat& observations, 00103 std::vector<arma::vec>& means, 00104 std::vector<arma::mat>& covariances, 00105 arma::vec& weights); 00106 00117 double LogLikelihood(const arma::mat& data, 00118 const std::vector<arma::vec>& means, 00119 const std::vector<arma::mat>& covariances, 00120 const arma::vec& weights) const; 00121 00122 InitialClusteringType clusterer; 00123 }; 00124 00125 }; // namespace gmm 00126 }; // namespace mlpack 00127 00128 // Include implementation. 00129 #include "em_fit_impl.hpp" 00130 00131 #endif