00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017 #ifndef __itkKdTree_h
00018 #define __itkKdTree_h
00019
00020 #include <queue>
00021 #include <vector>
00022
00023 #include "itkMacro.h"
00024 #include "itkPoint.h"
00025 #include "itkSize.h"
00026 #include "itkObject.h"
00027 #include "itkNumericTraits.h"
00028 #include "itkArray.h"
00029
00030 #include "itkSample.h"
00031 #include "itkSubsample.h"
00032
00033 #include "itkEuclideanDistance.h"
00034
00035 namespace itk{
00036 namespace Statistics{
00037
00056 template< class TSample >
00057 struct KdTreeNode
00058 {
00060 typedef KdTreeNode< TSample> Self ;
00061
00063 typedef typename TSample::MeasurementType MeasurementType ;
00064
00066 itkStaticConstMacro(MeasurementVectorSize, unsigned int,
00067 TSample::MeasurementVectorSize) ;
00068
00070 typedef FixedArray< double,
00071 itkGetStaticConstMacro(MeasurementVectorSize) > CentroidType ;
00072
00075 typedef typename TSample::InstanceIdentifier InstanceIdentifier ;
00076
00079 virtual bool IsTerminal() const = 0 ;
00080
00086 virtual void GetParameters(unsigned int &partitionDimension,
00087 MeasurementType &partitionValue) const = 0 ;
00088
00090 virtual Self* Left() = 0 ;
00091 virtual const Self* Left() const = 0 ;
00092
00094 virtual Self* Right() = 0 ;
00095 virtual const Self* Right() const = 0 ;
00096
00099 virtual unsigned int Size() const = 0 ;
00100
00102 virtual void GetWeightedCentroid(CentroidType ¢roid) = 0 ;
00103
00105 virtual void GetCentroid(CentroidType ¢roid) = 0 ;
00106
00108 virtual InstanceIdentifier GetInstanceIdentifier(size_t index) const = 0 ;
00109
00111 virtual void AddInstanceIdentifier(InstanceIdentifier id) = 0 ;
00112
00114 virtual ~KdTreeNode() {};
00115 } ;
00116
00128 template< class TSample >
00129 struct KdTreeNonterminalNode: public KdTreeNode< TSample >
00130 {
00131 typedef KdTreeNode< TSample > Superclass ;
00132 typedef typename Superclass::MeasurementType MeasurementType ;
00133 typedef typename Superclass::CentroidType CentroidType ;
00134 typedef typename Superclass::InstanceIdentifier InstanceIdentifier ;
00135
00136 KdTreeNonterminalNode(unsigned int partitionDimension,
00137 MeasurementType partitionValue,
00138 Superclass* left,
00139 Superclass* right) ;
00140
00141 virtual ~KdTreeNonterminalNode() {}
00142
00143 virtual bool IsTerminal() const
00144 { return false ; }
00145
00146 void GetParameters(unsigned int &partitionDimension,
00147 MeasurementType &partitionValue) const;
00148
00149 Superclass* Left()
00150 { return m_Left ; }
00151
00152 Superclass* Right()
00153 { return m_Right ; }
00154
00155 const Superclass* Left() const
00156 { return m_Left ; }
00157
00158 const Superclass* Right() const
00159 { return m_Right ; }
00160
00161 unsigned int Size() const
00162 { return 0 ; }
00163
00164 void GetWeightedCentroid(CentroidType &)
00165 { }
00166
00167 void GetCentroid(CentroidType &)
00168 { }
00169
00170 InstanceIdentifier GetInstanceIdentifier(size_t) const
00171 { return 0 ; }
00172
00173 void AddInstanceIdentifier(InstanceIdentifier) {}
00174
00175 private:
00176 unsigned int m_PartitionDimension ;
00177 MeasurementType m_PartitionValue ;
00178 Superclass* m_Left ;
00179 Superclass* m_Right ;
00180 } ;
00181
00196 template< class TSample >
00197 struct KdTreeWeightedCentroidNonterminalNode: public KdTreeNode< TSample >
00198 {
00199 typedef KdTreeNode< TSample > Superclass ;
00200 typedef typename Superclass::MeasurementType MeasurementType ;
00201 typedef typename Superclass::CentroidType CentroidType ;
00202 typedef typename Superclass::InstanceIdentifier InstanceIdentifier ;
00203
00204 KdTreeWeightedCentroidNonterminalNode(unsigned int partitionDimension,
00205 MeasurementType partitionValue,
00206 Superclass* left,
00207 Superclass* right,
00208 CentroidType ¢roid,
00209 unsigned int size) ;
00210 virtual ~KdTreeWeightedCentroidNonterminalNode() {}
00211
00212 virtual bool IsTerminal() const
00213 { return false ; }
00214
00215 void GetParameters(unsigned int &partitionDimension,
00216 MeasurementType &partitionValue) const ;
00217
00218 Superclass* Left()
00219 { return m_Left ; }
00220
00221 Superclass* Right()
00222 { return m_Right ; }
00223
00224
00225 const Superclass* Left() const
00226 { return m_Left ; }
00227
00228 const Superclass* Right() const
00229 { return m_Right ; }
00230
00231 unsigned int Size() const
00232 { return m_Size ; }
00233
00234 void GetWeightedCentroid(CentroidType ¢roid)
00235 { centroid = m_WeightedCentroid ; }
00236
00237 void GetCentroid(CentroidType ¢roid)
00238 { centroid = m_Centroid ; }
00239
00240 InstanceIdentifier GetInstanceIdentifier(size_t) const
00241 { return 0 ; }
00242
00243 void AddInstanceIdentifier(InstanceIdentifier) {}
00244
00245 private:
00246 unsigned int m_PartitionDimension ;
00247 MeasurementType m_PartitionValue ;
00248 CentroidType m_WeightedCentroid ;
00249 CentroidType m_Centroid ;
00250 unsigned int m_Size ;
00251 Superclass* m_Left ;
00252 Superclass* m_Right ;
00253 } ;
00254
00255
00267 template< class TSample >
00268 struct KdTreeTerminalNode: public KdTreeNode< TSample >
00269 {
00270 typedef KdTreeNode< TSample > Superclass ;
00271 typedef typename Superclass::MeasurementType MeasurementType ;
00272 typedef typename Superclass::CentroidType CentroidType ;
00273 typedef typename Superclass::InstanceIdentifier InstanceIdentifier ;
00274
00275 KdTreeTerminalNode() {}
00276
00277 virtual ~KdTreeTerminalNode() {}
00278
00279 bool IsTerminal() const
00280 { return true ; }
00281
00282 void GetParameters(unsigned int &,
00283 MeasurementType &) const {}
00284
00285 Superclass* Left()
00286 { return 0 ; }
00287
00288 Superclass* Right()
00289 { return 0 ; }
00290
00291
00292 const Superclass* Left() const
00293 { return 0 ; }
00294
00295 const Superclass* Right() const
00296 { return 0 ; }
00297
00298 unsigned int Size() const
00299 { return static_cast<unsigned int>( m_InstanceIdentifiers.size() ); }
00300
00301 void GetWeightedCentroid(CentroidType &)
00302 { }
00303
00304 void GetCentroid(CentroidType &)
00305 { }
00306
00307 InstanceIdentifier GetInstanceIdentifier(size_t index) const
00308 { return m_InstanceIdentifiers[index] ; }
00309
00310 void AddInstanceIdentifier(InstanceIdentifier id)
00311 { m_InstanceIdentifiers.push_back(id) ;}
00312
00313 private:
00314 std::vector< InstanceIdentifier > m_InstanceIdentifiers ;
00315 } ;
00316
00343 template < class TSample >
00344 class ITK_EXPORT KdTree : public Object
00345 {
00346 public:
00348 typedef KdTree Self ;
00349 typedef Object Superclass ;
00350 typedef SmartPointer<Self> Pointer;
00351 typedef SmartPointer<const Self> ConstPointer;
00352
00354 itkTypeMacro(KdTree, Object);
00355
00357 itkNewMacro(Self) ;
00358
00360 typedef TSample SampleType ;
00361 typedef typename TSample::MeasurementVectorType MeasurementVectorType ;
00362 typedef typename TSample::MeasurementType MeasurementType ;
00363 typedef typename TSample::InstanceIdentifier InstanceIdentifier ;
00364 typedef typename TSample::FrequencyType FrequencyType ;
00365
00367 itkStaticConstMacro(MeasurementVectorSize, unsigned int,
00368 TSample::MeasurementVectorSize) ;
00369
00371 typedef EuclideanDistance< MeasurementVectorType > DistanceMetricType ;
00372
00374 typedef KdTreeNode< TSample > KdTreeNodeType ;
00375
00379 typedef std::pair< InstanceIdentifier, double > NeighborType ;
00380
00381 typedef std::vector< InstanceIdentifier > InstanceIdentifierVectorType ;
00382
00392 class NearestNeighbors
00393 {
00394 public:
00396 NearestNeighbors() {}
00397
00399 ~NearestNeighbors() {}
00400
00403 void resize(unsigned int k)
00404 {
00405 m_Identifiers.clear() ;
00406 m_Identifiers.resize(k, NumericTraits< unsigned long >::max()) ;
00407 m_Distances.clear() ;
00408 m_Distances.resize(k, NumericTraits< double >::max()) ;
00409 m_FarthestNeighborIndex = 0 ;
00410 }
00411
00413 double GetLargestDistance()
00414 { return m_Distances[m_FarthestNeighborIndex] ; }
00415
00418 void ReplaceFarthestNeighbor(InstanceIdentifier id, double distance)
00419 {
00420 m_Identifiers[m_FarthestNeighborIndex] = id ;
00421 m_Distances[m_FarthestNeighborIndex] = distance ;
00422 double farthestDistance = NumericTraits< double >::min() ;
00423 const unsigned int size = static_cast<unsigned int>( m_Distances.size() );
00424 for ( unsigned int i = 0 ; i < size; i++ )
00425 {
00426 if ( m_Distances[i] > farthestDistance )
00427 {
00428 farthestDistance = m_Distances[i] ;
00429 m_FarthestNeighborIndex = i ;
00430 }
00431 }
00432 }
00433
00435 InstanceIdentifierVectorType GetNeighbors()
00436 { return m_Identifiers ; }
00437
00440 InstanceIdentifier GetNeighbor(unsigned int index)
00441 { return m_Identifiers[index] ; }
00442
00444 std::vector< double >& GetDistances()
00445 { return m_Distances ; }
00446
00447 private:
00449 unsigned int m_FarthestNeighborIndex ;
00450
00452 InstanceIdentifierVectorType m_Identifiers ;
00453
00456 std::vector< double > m_Distances ;
00457 } ;
00458
00461 void SetBucketSize(unsigned int size) ;
00462
00465 void SetSample(const TSample* sample) ;
00466
00468 const TSample* GetSample() const
00469 { return m_Sample ; }
00470
00471 unsigned long Size() const
00472 { return m_Sample->Size() ; }
00473
00478 KdTreeNodeType* GetEmptyTerminalNode()
00479 { return m_EmptyTerminalNode ; }
00480
00483 void SetRoot(KdTreeNodeType* root)
00484 { m_Root = root ; }
00485
00487 KdTreeNodeType* GetRoot()
00488 { return m_Root ; }
00489
00492 const MeasurementVectorType & GetMeasurementVector(InstanceIdentifier id) const
00493 { return m_Sample->GetMeasurementVector(id) ; }
00494
00497 FrequencyType GetFrequency(InstanceIdentifier id) const
00498 { return m_Sample->GetFrequency( id ) ; }
00499
00501 DistanceMetricType* GetDistanceMetric()
00502 { return m_DistanceMetric.GetPointer() ; }
00503
00505 void Search(MeasurementVectorType &query,
00506 unsigned int k,
00507 InstanceIdentifierVectorType& result) const;
00508
00510 void Search(MeasurementVectorType &query,
00511 double radius,
00512 InstanceIdentifierVectorType& result) const;
00513
00516 int GetNumberOfVisits() const
00517 { return m_NumberOfVisits ; }
00518
00524 bool BallWithinBounds(MeasurementVectorType &query,
00525 MeasurementVectorType &lowerBound,
00526 MeasurementVectorType &upperBound,
00527 double radius) const ;
00528
00532 bool BoundsOverlapBall(MeasurementVectorType &query,
00533 MeasurementVectorType &lowerBound,
00534 MeasurementVectorType &upperBound,
00535 double radius) const ;
00536
00538 void DeleteNode(KdTreeNodeType *node) ;
00539
00541 void PrintTree(KdTreeNodeType *node, int level,
00542 unsigned int activeDimension) ;
00543
00544 typedef typename TSample::Iterator Iterator ;
00545 typedef typename TSample::ConstIterator ConstIterator ;
00546
00547 Iterator Begin()
00548 {
00549 typename TSample::ConstIterator iter = m_Sample->Begin() ;
00550 return iter;
00551 }
00552
00553 Iterator End()
00554 {
00555 Iterator iter = m_Sample->End() ;
00556 return iter;
00557 }
00558
00559 ConstIterator Begin() const
00560 {
00561 typename TSample::ConstIterator iter = m_Sample->Begin() ;
00562 return iter;
00563 }
00564
00565 ConstIterator End() const
00566 {
00567 ConstIterator iter = m_Sample->End() ;
00568 return iter;
00569 }
00570
00571
00572 protected:
00574 KdTree() ;
00575
00577 virtual ~KdTree() ;
00578
00579 void PrintSelf(std::ostream& os, Indent indent) const ;
00580
00582 int NearestNeighborSearchLoop(const KdTreeNodeType* node,
00583 MeasurementVectorType &query,
00584 MeasurementVectorType &lowerBound,
00585 MeasurementVectorType &upperBound) const;
00586
00588 int SearchLoop(const KdTreeNodeType* node, MeasurementVectorType &query,
00589 MeasurementVectorType &lowerBound,
00590 MeasurementVectorType &upperBound) const ;
00591 private:
00592 KdTree(const Self&) ;
00593 void operator=(const Self&) ;
00594
00596 const TSample* m_Sample ;
00597
00599 int m_BucketSize ;
00600
00602 KdTreeNodeType* m_Root ;
00603
00605 KdTreeNodeType* m_EmptyTerminalNode ;
00606
00608 typename DistanceMetricType::Pointer m_DistanceMetric ;
00609
00610 mutable bool m_IsNearestNeighborSearch ;
00611
00612 mutable double m_SearchRadius ;
00613
00614 mutable InstanceIdentifierVectorType m_Neighbors ;
00615
00617 mutable NearestNeighbors m_NearestNeighbors ;
00618
00620 mutable MeasurementVectorType m_LowerBound ;
00621
00623 mutable MeasurementVectorType m_UpperBound ;
00624
00626 mutable int m_NumberOfVisits ;
00627
00629 mutable bool m_StopSearch ;
00630
00632 mutable NeighborType m_TempNeighbor ;
00633 } ;
00634
00635 }
00636 }
00637
00638 #ifndef ITK_MANUAL_INSTANTIATION
00639 #include "itkKdTree.txx"
00640 #endif
00641
00642 #endif
00643
00644
00645