00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017 #ifndef __itkKdTreeBasedKmeansEstimator_h
00018 #define __itkKdTreeBasedKmeansEstimator_h
00019
00020 #include <vector>
00021 #include "itk_hash_map.h"
00022
00023 #include "itkObject.h"
00024
00025 namespace itk {
00026 namespace Statistics {
00027
00059 template< class TKdTree >
00060 class ITK_EXPORT KdTreeBasedKmeansEstimator:
00061 public Object
00062 {
00063 public:
00065 typedef KdTreeBasedKmeansEstimator Self ;
00066 typedef Object Superclass;
00067 typedef SmartPointer<Self> Pointer;
00068 typedef SmartPointer<const Self> ConstPointer;
00069
00071 itkNewMacro(Self);
00072
00074 itkTypeMacro(KdTreeBasedKmeansEstimator, Obeject);
00075
00077 typedef typename TKdTree::KdTreeNodeType KdTreeNodeType ;
00078 typedef typename TKdTree::MeasurementType MeasurementType ;
00079 typedef typename TKdTree::MeasurementVectorType MeasurementVectorType ;
00080 typedef typename TKdTree::InstanceIdentifier InstanceIdentifier ;
00081 typedef typename TKdTree::SampleType SampleType ;
00082 typedef typename KdTreeNodeType::CentroidType CentroidType ;
00083 itkStaticConstMacro(MeasurementVectorSize, unsigned int,
00084 TKdTree::MeasurementVectorSize);
00087 typedef FixedArray< double, itkGetStaticConstMacro(MeasurementVectorSize) > ParameterType ;
00088 typedef std::vector< ParameterType > InternalParametersType;
00089 typedef Array< double > ParametersType;
00090
00092 void SetParameters(ParametersType& params)
00093 { m_Parameters = params ; }
00094
00096 ParametersType& GetParameters()
00097 { return m_Parameters ; }
00098
00100 itkSetMacro( MaximumIteration, int );
00101 itkGetConstReferenceMacro( MaximumIteration, int );
00102
00105 itkSetMacro( CentroidPositionChangesThreshold, double );
00106 itkGetConstReferenceMacro( CentroidPositionChangesThreshold, double );
00107
00109 void SetKdTree(TKdTree* tree)
00110 { m_KdTree = tree ; }
00111
00112 TKdTree* GetKdTree()
00113 { return m_KdTree.GetPointer() ; }
00114
00115 itkGetConstReferenceMacro( CurrentIteration, int) ;
00116 itkGetConstReferenceMacro( CentroidPositionChanges, double) ;
00117
00122 void StartOptimization() ;
00123
00124 typedef itk::hash_map< InstanceIdentifier, unsigned int > ClusterLabelsType ;
00125
00126 void SetUseClusterLabels(bool flag)
00127 { m_UseClusterLabels = flag ; }
00128
00129 ClusterLabelsType* GetClusterLabels()
00130 { return &m_ClusterLabels ; }
00131
00132 protected:
00133 KdTreeBasedKmeansEstimator() ;
00134 virtual ~KdTreeBasedKmeansEstimator() {}
00135
00136 void PrintSelf(std::ostream& os, Indent indent) const;
00137
00138 void FillClusterLabels(KdTreeNodeType* node, int closestIndex) ;
00139
00141 class CandidateVector
00142 {
00143 public:
00144 CandidateVector() {}
00145
00146 struct Candidate
00147 {
00148 CentroidType Centroid ;
00149 CentroidType WeightedCentroid ;
00150 int Size ;
00151 } ;
00152
00153 virtual ~CandidateVector() {}
00154
00156 int Size() const
00157 { return static_cast<int>( m_Candidates.size() ); }
00158
00161 void SetCentroids(InternalParametersType& centroids)
00162 {
00163 m_Candidates.resize(centroids.size()) ;
00164 for (unsigned int i = 0 ; i < centroids.size() ; i++)
00165 {
00166 Candidate candidate ;
00167 candidate.Centroid = centroids[i] ;
00168 candidate.WeightedCentroid.Fill(0.0) ;
00169 candidate.Size = 0 ;
00170 m_Candidates[i] = candidate ;
00171 }
00172 }
00173
00175 void GetCentroids(InternalParametersType& centroids)
00176 {
00177 unsigned int i ;
00178 centroids.resize(this->Size()) ;
00179 for (i = 0 ; i < (unsigned int)this->Size() ; i++)
00180 {
00181 centroids[i] = m_Candidates[i].Centroid ;
00182 }
00183 }
00184
00187 void UpdateCentroids()
00188 {
00189 unsigned int i, j ;
00190 for (i = 0 ; i < (unsigned int)this->Size() ; i++)
00191 {
00192 if (m_Candidates[i].Size > 0)
00193 {
00194 for (j = 0 ; j < MeasurementVectorSize ; j++)
00195 {
00196 m_Candidates[i].Centroid[j] =
00197 m_Candidates[i].WeightedCentroid[j] /
00198 double(m_Candidates[i].Size) ;
00199 }
00200 }
00201 }
00202 }
00203
00205 Candidate& operator[](int index)
00206 { return m_Candidates[index] ; }
00207
00208
00209 private:
00211 std::vector< Candidate > m_Candidates ;
00212 } ;
00213
00219 double GetSumOfSquaredPositionChanges(InternalParametersType &previous,
00220 InternalParametersType ¤t) ;
00221
00224 int GetClosestCandidate(ParameterType &measurements,
00225 std::vector< int > &validIndexes) ;
00226
00228 bool IsFarther(ParameterType &pointA,
00229 ParameterType &pointB,
00230 MeasurementVectorType &lowerBound,
00231 MeasurementVectorType &upperBound) ;
00232
00235 void Filter(KdTreeNodeType* node,
00236 std::vector< int > validIndexes,
00237 MeasurementVectorType &lowerBound,
00238 MeasurementVectorType &upperBound) ;
00239
00241 void CopyParameters(InternalParametersType &source, InternalParametersType &target) ;
00242
00244 void CopyParameters(ParametersType &source, InternalParametersType &target) ;
00245
00247 void CopyParameters(InternalParametersType &source, ParametersType &target) ;
00248
00250 void GetPoint(ParameterType &point,
00251 MeasurementVectorType measurements)
00252 {
00253 for (unsigned int i = 0 ; i < MeasurementVectorSize ; i++)
00254 {
00255 point[i] = measurements[i] ;
00256 }
00257 }
00258
00259 void PrintPoint(ParameterType &point)
00260 {
00261 std::cout << "[ " ;
00262 for (unsigned int i = 0 ; i < MeasurementVectorSize ; i++)
00263 {
00264 std::cout << point[i] << " " ;
00265 }
00266 std::cout << "]" ;
00267 }
00268
00269 private:
00271 int m_CurrentIteration ;
00273 int m_MaximumIteration ;
00275 double m_CentroidPositionChanges ;
00278 double m_CentroidPositionChangesThreshold ;
00280 typename TKdTree::Pointer m_KdTree ;
00282 typename EuclideanDistance< ParameterType >::Pointer m_DistanceMetric ;
00283
00285 ParametersType m_Parameters ;
00286
00287 CandidateVector m_CandidateVector ;
00288
00289 ParameterType m_TempVertex ;
00290
00291 bool m_UseClusterLabels ;
00292 bool m_GenerateClusterLabels ;
00293 ClusterLabelsType m_ClusterLabels ;
00294 } ;
00295
00296 }
00297 }
00298
00299 #ifndef ITK_MANUAL_INSTANTIATION
00300 #include "itkKdTreeBasedKmeansEstimator.txx"
00301 #endif
00302
00303
00304 #endif