00001 /*========================================================================= 00002 00003 Program: Insight Segmentation & Registration Toolkit 00004 Module: $RCSfile: itkMultilayerNeuralNetworkBase.h,v $ 00005 Language: C++ 00006 Date: $Date: 2006/04/17 21:34:31 $ 00007 Version: $Revision: 1.4 $ 00008 00009 Copyright (c) Insight Software Consortium. All rights reserved. 00010 See ITKCopyright.txt or http://www.itk.org/HTML/Copyright.htm for details. 00011 00012 This software is distributed WITHOUT ANY WARRANTY; without even 00013 the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR 00014 PURPOSE. See the above copyright notices for more information. 00015 00016 =========================================================================*/ 00017 #ifndef __MultiLayerNeuralNetworkBase_h 00018 #define __MultiLayerNeuralNetworkBase_h 00019 00020 #include "itkNeuralNetworkObject.h" 00021 #include "itkErrorBackPropagationLearningFunctionBase.h" 00022 #include "itkErrorBackPropagationLearningWithMomentum.h" 00023 #include "itkQuickPropLearningRule.h" 00024 00025 namespace itk 00026 { 00027 namespace Statistics 00028 { 00029 00030 template<class TVector, class TOutput> 00031 class MultilayerNeuralNetworkBase : public NeuralNetworkObject<TVector, TOutput> 00032 { 00033 public: 00034 00035 typedef MultilayerNeuralNetworkBase Self; 00036 typedef NeuralNetworkObject<TVector, TOutput> Superclass; 00037 typedef SmartPointer<Self> Pointer; 00038 typedef SmartPointer<const Self> ConstPointer; 00039 itkTypeMacro(MultilayerNeuralNetworkBase, NeuralNetworkObject); 00040 00042 itkNewMacro( Self ); 00043 00044 typedef typename Superclass::ValueType ValueType; 00045 typedef typename Superclass::NetworkOutputType NetworkOutputType; 00046 typedef typename Superclass::LayerType LayerType; 00047 typedef typename Superclass::WeightSetType WeightSetType; 00048 typedef typename Superclass::WeightSetPointer WeightSetPointer; 00049 typedef typename Superclass::LayerPointer LayerPointer; 00050 typedef typename Superclass::LearningFunctionType LearningFunctionType; 00051 typedef typename Superclass::LearningFunctionPointer LearningFunctionPointer; 00052 00053 typedef std::vector<WeightSetPointer> WeightVectorType; 00054 typedef std::vector<LayerPointer> LayerVectorType; 00055 00056 itkSetMacro(NumOfLayers, int); 00057 itkGetConstReferenceMacro(NumOfLayers, int); 00058 00059 itkSetMacro(NumOfWeightSets, int); 00060 itkGetConstReferenceMacro(NumOfWeightSets, int); 00061 00062 void AddLayer(LayerType*); 00063 LayerType* GetLayer(int layer_id); 00064 00065 void AddWeightSet(WeightSetType*); 00066 WeightSetType* GetWeightSet(int id); 00067 00068 void SetLearningFunction(LearningFunctionType* f); 00069 00070 // virtual ValueType* GenerateOutput(TVector samplevector); 00071 virtual NetworkOutputType GenerateOutput(TVector samplevector); 00072 00073 // virtual void BackwardPropagate(TOutput errors); 00074 virtual void BackwardPropagate(NetworkOutputType errors); 00075 00076 virtual void UpdateWeights(ValueType); 00077 00078 void SetLearningRule(LearningFunctionType*); 00079 00080 void SetLearningRate(ValueType learningrate); 00081 00082 void InitializeWeights(); 00083 00084 protected: 00085 MultilayerNeuralNetworkBase(); 00086 ~MultilayerNeuralNetworkBase(); 00087 00088 LayerVectorType m_Layers; 00089 WeightVectorType m_Weights; 00090 LearningFunctionPointer m_LearningFunction; 00091 ValueType m_LearningRate; 00092 int m_NumOfLayers; 00093 int m_NumOfWeightSets; 00095 virtual void PrintSelf( std::ostream& os, Indent indent ) const; 00096 }; 00097 00098 } // end namespace Statistics 00099 } // end namespace itk 00100 00101 #ifndef ITK_MANUAL_INSTANTIATION 00102 #include "itkMultilayerNeuralNetworkBase.txx" 00103 #endif 00104 00105 #endif 00106