// MLPACK 1.0.4
00001 00023 #ifndef __MLPACK_METHODS_MOG_MOG_EM_HPP 00024 #define __MLPACK_METHODS_MOG_MOG_EM_HPP 00025 00026 #include <mlpack/core.hpp> 00027 00028 // This is the default fitting method class. 00029 #include "em_fit.hpp" 00030 00031 namespace mlpack { 00032 namespace gmm { 00033 00088 template<typename FittingType = EMFit<> > 00089 class GMM 00090 { 00091 private: 00093 size_t gaussians; 00095 size_t dimensionality; 00097 std::vector<arma::vec> means; 00099 std::vector<arma::mat> covariances; 00101 arma::vec weights; 00102 00103 public: 00107 GMM() : 00108 gaussians(0), 00109 dimensionality(0), 00110 localFitter(FittingType()), 00111 fitter(localFitter) 00112 { 00113 // Warn the user. They probably don't want to do this. If this constructor 00114 // is being used (because it is required by some template classes), the user 00115 // should know that it is potentially dangerous. 00116 Log::Debug << "GMM::GMM(): no parameters given; Estimate() may fail " 00117 << "unless parameters are set." << std::endl; 00118 } 00119 00127 GMM(const size_t gaussians, const size_t dimensionality) : 00128 gaussians(gaussians), 00129 dimensionality(dimensionality), 00130 means(gaussians, arma::vec(dimensionality)), 00131 covariances(gaussians, arma::mat(dimensionality, dimensionality)), 00132 weights(gaussians), 00133 localFitter(FittingType()), 00134 fitter(localFitter) { /* Nothing to do. */ } 00135 00146 GMM(const size_t gaussians, 00147 const size_t dimensionality, 00148 FittingType& fitter) : 00149 gaussians(gaussians), 00150 dimensionality(dimensionality), 00151 means(gaussians, arma::vec(dimensionality)), 00152 covariances(gaussians, arma::mat(dimensionality, dimensionality)), 00153 weights(gaussians), 00154 fitter(fitter) { /* Nothing to do. */ } 00155 00163 GMM(const std::vector<arma::vec>& means, 00164 const std::vector<arma::mat>& covariances, 00165 const arma::vec& weights) : 00166 gaussians(means.size()), 00167 dimensionality((!means.empty()) ? 
means[0].n_elem : 0), 00168 means(means), 00169 covariances(covariances), 00170 weights(weights), 00171 localFitter(FittingType()), 00172 fitter(localFitter) { /* Nothing to do. */ } 00173 00183 GMM(const std::vector<arma::vec>& means, 00184 const std::vector<arma::mat>& covariances, 00185 const arma::vec& weights, 00186 FittingType& fitter) : 00187 gaussians(means.size()), 00188 dimensionality((!means.empty()) ? means[0].n_elem : 0), 00189 means(means), 00190 covariances(covariances), 00191 weights(weights), 00192 fitter(fitter) { /* Nothing to do. */ } 00193 00197 template<typename OtherFittingType> 00198 GMM(const GMM<OtherFittingType>& other); 00199 00204 GMM(const GMM& other); 00205 00209 template<typename OtherFittingType> 00210 GMM& operator=(const GMM<OtherFittingType>& other); 00211 00216 GMM& operator=(const GMM& other); 00217 00219 size_t Gaussians() const { return gaussians; } 00222 size_t& Gaussians() { return gaussians; } 00223 00225 size_t Dimensionality() const { return dimensionality; } 00228 size_t& Dimensionality() { return dimensionality; } 00229 00231 const std::vector<arma::vec>& Means() const { return means; } 00233 std::vector<arma::vec>& Means() { return means; } 00234 00236 const std::vector<arma::mat>& Covariances() const { return covariances; } 00238 std::vector<arma::mat>& Covariances() { return covariances; } 00239 00241 const arma::vec& Weights() const { return weights; } 00243 arma::vec& Weights() { return weights; } 00244 00246 const FittingType& Fitter() const { return fitter; } 00248 FittingType& Fitter() { return fitter; } 00249 00256 double Probability(const arma::vec& observation) const; 00257 00265 double Probability(const arma::vec& observation, 00266 const size_t component) const; 00267 00274 arma::vec Random() const; 00275 00291 double Estimate(const arma::mat& observations, 00292 const size_t trials = 1); 00293 00311 double Estimate(const arma::mat& observations, 00312 const arma::vec& probabilities, 00313 const size_t 
trials = 1); 00314 00331 void Classify(const arma::mat& observations, 00332 arma::Col<size_t>& labels) const; 00333 00334 private: 00344 double LogLikelihood(const arma::mat& dataPoints, 00345 const std::vector<arma::vec>& means, 00346 const std::vector<arma::mat>& covars, 00347 const arma::vec& weights) const; 00348 00350 FittingType localFitter; 00351 00353 FittingType& fitter; 00354 }; 00355 00356 }; // namespace gmm 00357 }; // namespace mlpack 00358 00359 // Include implementation. 00360 #include "gmm_impl.hpp" 00361 00362 #endif