org.neuroph.core.learning
Class SupervisedLearning

java.lang.Object
  extended by java.util.Observable
      extended by org.neuroph.core.learning.LearningRule
          extended by org.neuroph.core.learning.IterativeLearning
              extended by org.neuroph.core.learning.SupervisedLearning
All Implemented Interfaces:
java.io.Serializable, java.lang.Runnable
Direct Known Subclasses:
LMS

public abstract class SupervisedLearning
extends IterativeLearning
implements java.io.Serializable

Base class for all supervised learning algorithms. It extends IterativeLearning and provides the general logic shared by supervised learning rules.

Author:
Zoran Sevarac
See Also:
Serialized Form

Field Summary
protected  double maxError
          Max allowed network error (condition to stop learning)
protected  double totalNetworkError
          Total network error
 
Fields inherited from class org.neuroph.core.learning.IterativeLearning
currentIteration, iterationsLimited, learningRate, maxIterations
 
Fields inherited from class org.neuroph.core.learning.LearningRule
neuralNetwork
 
Constructor Summary
SupervisedLearning()
          Creates new supervised learning rule
SupervisedLearning(NeuralNetwork network)
          Creates new supervised learning rule and sets the neural network to train
 
Method Summary
 void doLearningEpoch(TrainingSet trainingSet)
          Implements the basic logic of one learning epoch for supervised learning algorithms.
protected  java.util.Vector<java.lang.Double> getPatternError(java.util.Vector<java.lang.Double> output, java.util.Vector<java.lang.Double> desiredOutput)
          Calculates the network error for the current pattern: the difference between desired and actual output
 java.lang.Double getTotalNetworkError()
          Returns total network error
 void learn(TrainingSet trainingSet, double maxError)
          Trains network for the specified training set until the total network error falls below maxError
 void learn(TrainingSet trainingSet, double maxError, int maxIterations)
          Trains network for the specified training set until the total network error falls below maxError or maxIterations is reached
protected  void learnPattern(SupervisedTrainingElement trainingElement)
          Trains network with the pattern from the specified training element
 void setMaxError(java.lang.Double maxError)
          Sets the maximum allowed network error, which determines when to stop training
protected abstract  void updateNetworkWeights(java.util.Vector<java.lang.Double> patternError)
          This method should implement the weight update procedure
protected abstract  void updateTotalNetworkError(java.util.Vector<java.lang.Double> patternError)
          Subclasses update total network error for each training pattern with this method.
 
Methods inherited from class org.neuroph.core.learning.IterativeLearning
doOneLearningIteration, getCurrentIteration, getLearningRate, isPausedLearning, learn, learn, pause, resume, setLearningRate, setMaxIterations
 
Methods inherited from class org.neuroph.core.learning.LearningRule
getNeuralNetwork, getTrainingSet, isStopped, notifyChange, run, setNeuralNetwork, setTrainingSet, stopLearning
 
Methods inherited from class java.util.Observable
addObserver, clearChanged, countObservers, deleteObserver, deleteObservers, hasChanged, notifyObservers, notifyObservers, setChanged
 
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
 

Field Detail

totalNetworkError

protected double totalNetworkError
Total network error


maxError

protected double maxError
Max allowed network error (condition to stop learning)

Constructor Detail

SupervisedLearning

public SupervisedLearning()
Creates new supervised learning rule


SupervisedLearning

public SupervisedLearning(NeuralNetwork network)
Creates new supervised learning rule and sets the neural network to train

Parameters:
network - network to train
Method Detail

learn

public void learn(TrainingSet trainingSet,
                  double maxError)
Trains network for the specified training set until the total network error falls below maxError

Parameters:
trainingSet - training set to learn
maxError - maximum allowed network error (condition to stop learning)

learn

public void learn(TrainingSet trainingSet,
                  double maxError,
                  int maxIterations)
Trains network for the specified training set until the total network error falls below maxError or maxIterations is reached

Parameters:
trainingSet - training set to learn
maxError - maximum allowed network error (condition to stop learning)
maxIterations - maximum number of iterations to learn
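The two stop conditions of learn(trainingSet, maxError, maxIterations) can be sketched as a plain loop. This is a minimal sketch, not Neuroph's implementation: the class TrainingLoopSketch, the method train, and modelling one epoch as a java.util.function.DoubleSupplier that returns the epoch's total network error are all assumptions. It also ignores the iterationsLimited field and always enforces the iteration limit.

```java
import java.util.function.DoubleSupplier;

// Hypothetical sketch of the stop logic behind learn(trainingSet, maxError, maxIterations).
// One epoch is modelled as a DoubleSupplier that makes a single pass over the
// training set and returns the resulting total network error (an assumption).
public class TrainingLoopSketch {

    /** Runs epochs until error < maxError or maxIterations epochs have run; returns the epoch count. */
    static int train(DoubleSupplier runEpoch, double maxError, int maxIterations) {
        int iteration = 0;
        while (iteration < maxIterations) {      // stop condition 2: iteration limit
            double totalError = runEpoch.getAsDouble();
            iteration++;
            if (totalError < maxError) {
                break;                           // stop condition 1: error below the allowed maximum
            }
        }
        return iteration;
    }

    public static void main(String[] args) {
        // Fake epochs whose error halves each time: 1.0, 0.5, 0.25, 0.125, 0.0625, ...
        double[] error = {2.0};
        DoubleSupplier halvingEpoch = () -> error[0] /= 2;
        System.out.println(train(halvingEpoch, 0.1, 100)); // prints 5
    }
}
```

With maxIterations set to 3 instead, the same fake epochs stop after three passes even though the error is still 0.25.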

doLearningEpoch

public void doLearningEpoch(TrainingSet trainingSet)
This method implements the basic logic of one learning epoch for supervised learning algorithms. An epoch is one pass through the training set: the method iterates through the training set and trains the network on each element. It also sets a flag when a condition to stop learning has been reached: the network error is below the allowed value, or the maximum iteration count has been reached.

Specified by:
doLearningEpoch in class IterativeLearning
Parameters:
trainingSet - training set for training network
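The epoch logic described above can be sketched as follows. This is a hypothetical sketch: the training set is modelled as a List of illustrative Pair objects instead of TrainingSet, and SupervisedRule is a simplified stand-in whose fields mirror totalNetworkError, maxError, currentIteration and maxIterations. Where Neuroph actually increments the iteration counter and checks the stop conditions may differ.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical sketch of one supervised learning epoch; Pair and SupervisedRule
// are illustrative names, not Neuroph classes.
public class EpochSketch {

    static final class Pair {
        final double[] input;
        final double[] desired;
        Pair(double[] input, double[] desired) { this.input = input; this.desired = desired; }
    }

    abstract static class SupervisedRule {
        protected double totalNetworkError;
        protected double maxError;
        protected int currentIteration;
        protected int maxIterations;
        protected boolean stopped;

        /** One pass (epoch) over the training set, then a check of both stop conditions. */
        void doLearningEpoch(List<Pair> trainingSet) {
            totalNetworkError = 0;               // every epoch accumulates a fresh error
            for (Pair pair : trainingSet) {
                if (stopped) {
                    break;                       // honour a stop request made mid-epoch
                }
                learnPattern(pair);              // subclass hook; adds to totalNetworkError
            }
            currentIteration++;
            if (totalNetworkError < maxError || currentIteration >= maxIterations) {
                stopped = true;                  // flag read by the outer training loop
            }
        }

        protected abstract void learnPattern(Pair pair);
    }

    // Toy rule: each pattern contributes a fixed error that halves after every epoch.
    static class ShrinkingErrorRule extends SupervisedRule {
        double perPatternError = 0.5;
        @Override protected void learnPattern(Pair pair) { totalNetworkError += perPatternError; }
        @Override void doLearningEpoch(List<Pair> set) { super.doLearningEpoch(set); perPatternError /= 2; }
    }

    public static void main(String[] args) {
        List<Pair> set = Arrays.asList(
                new Pair(new double[] {0, 1}, new double[] {1}),
                new Pair(new double[] {1, 0}, new double[] {1}));
        ShrinkingErrorRule rule = new ShrinkingErrorRule();
        rule.maxError = 0.1;
        rule.maxIterations = 100;
        while (!rule.stopped) {
            rule.doLearningEpoch(set);
        }
        System.out.println(rule.currentIteration + " " + rule.totalNetworkError); // prints 5 0.0625
    }
}
```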

learnPattern

protected void learnPattern(SupervisedTrainingElement trainingElement)
Trains network with the pattern from the specified training element

Parameters:
trainingElement - supervised training element which contains input and desired output
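The work done per training element can be sketched as four steps: forward pass, pattern error, total error update, weight update. This is a hypothetical sketch rather than the library source: the network is modelled as a java.util.function.Function from input vector to output vector, the order of the two abstract hook calls is assumed, and RecordingRule is a test double that only records which hooks ran.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Vector;
import java.util.function.Function;

// Hypothetical sketch of the steps inside learnPattern.
public class LearnPatternSketch {

    static abstract class Rule {
        final Function<Vector<Double>, Vector<Double>> network; // stand-in for the neural network
        Rule(Function<Vector<Double>, Vector<Double>> network) { this.network = network; }

        void learnPattern(Vector<Double> input, Vector<Double> desiredOutput) {
            Vector<Double> output = network.apply(input);                        // 1. forward pass
            Vector<Double> patternError = getPatternError(output, desiredOutput); // 2. pattern error
            updateTotalNetworkError(patternError);                               // 3. accumulate error
            updateNetworkWeights(patternError);                                  // 4. adjust weights
        }

        // Element-wise desired minus actual output.
        Vector<Double> getPatternError(Vector<Double> output, Vector<Double> desired) {
            Vector<Double> error = new Vector<>(output.size());
            for (int i = 0; i < output.size(); i++) {
                error.add(desired.get(i) - output.get(i));
            }
            return error;
        }

        abstract void updateTotalNetworkError(Vector<Double> patternError);
        abstract void updateNetworkWeights(Vector<Double> patternError);
    }

    // Test double that records the order in which the hooks run.
    static class RecordingRule extends Rule {
        final List<String> calls = new ArrayList<>();
        Vector<Double> lastError;
        RecordingRule(Function<Vector<Double>, Vector<Double>> network) { super(network); }
        @Override void updateTotalNetworkError(Vector<Double> e) { calls.add("error"); lastError = e; }
        @Override void updateNetworkWeights(Vector<Double> e) { calls.add("weights"); }
    }

    public static void main(String[] args) {
        // A fake "network" that doubles each input value.
        Function<Vector<Double>, Vector<Double>> doubler = in -> {
            Vector<Double> out = new Vector<>();
            for (double x : in) out.add(2 * x);
            return out;
        };
        RecordingRule rule = new RecordingRule(doubler);
        rule.learnPattern(new Vector<>(List.of(1.0, 3.0)), new Vector<>(List.of(2.0, 5.0)));
        System.out.println(rule.calls + " " + rule.lastError); // prints [error, weights] [0.0, -1.0]
    }
}
```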

getPatternError

protected java.util.Vector<java.lang.Double> getPatternError(java.util.Vector<java.lang.Double> output,
                                                             java.util.Vector<java.lang.Double> desiredOutput)
Calculates the network error for the current pattern: the difference between desired and actual output

Parameters:
output - actual network output
desiredOutput - desired network output
Returns:
pattern error
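Written out for a network with n outputs, desired output d and actual output y, the pattern error is the element-wise difference below. The sign convention (desired minus actual) is assumed from the wording "difference between desired and actual output"; the example values are illustrative.

```latex
e_i = d_i - y_i, \qquad i = 1, \dots, n

% Example: d = (1, 0), \; y = (0.8, 0.3) \;\Rightarrow\; e = (0.2, -0.3)
```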

setMaxError

public void setMaxError(java.lang.Double maxError)
Sets the maximum allowed network error, which determines when to stop training

Parameters:
maxError - maximum allowed network error

getTotalNetworkError

public java.lang.Double getTotalNetworkError()
Returns total network error

Returns:
total network error

updateTotalNetworkError

protected abstract void updateTotalNetworkError(java.util.Vector<java.lang.Double> patternError)
Subclasses update total network error for each training pattern with this method. Error update formula is learning rule specific.

Parameters:
patternError - pattern error vector
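Because the error formula is learning rule specific, the following is only one hypothetical choice: each pattern adds its squared error, averaged over the n outputs and halved, E += (1 / (2n)) * sum of e_i squared. TotalErrorSketch is an illustrative name, and this formula is an assumption, not necessarily the one LMS uses.

```java
import java.util.List;
import java.util.Vector;

// Hypothetical sketch of one possible updateTotalNetworkError implementation.
// The halved mean-squared-error formula is an assumption for illustration.
public class TotalErrorSketch {

    private double totalNetworkError = 0;

    /** Adds this pattern's squared error, averaged over the outputs and halved. */
    void updateTotalNetworkError(Vector<Double> patternError) {
        if (patternError.isEmpty()) {
            return;                              // nothing to add; avoids 0 / 0
        }
        double sqrErrorSum = 0;
        for (double e : patternError) {
            sqrErrorSum += e * e;
        }
        totalNetworkError += sqrErrorSum / (2 * patternError.size());
    }

    double getTotalNetworkError() { return totalNetworkError; }

    public static void main(String[] args) {
        TotalErrorSketch s = new TotalErrorSketch();
        s.updateTotalNetworkError(new Vector<>(List.of(0.0, -1.0))); // adds 1.0 / 4 = 0.25
        s.updateTotalNetworkError(new Vector<>(List.of(2.0)));       // adds 4.0 / 2 = 2.0
        System.out.println(s.getTotalNetworkError());                 // prints 2.25
    }
}
```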

updateNetworkWeights

protected abstract void updateNetworkWeights(java.util.Vector<java.lang.Double> patternError)
This method should implement the weight update procedure

Parameters:
patternError - pattern error vector
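The weight update is likewise left to subclasses. As an illustration only, the classic Widrow-Hoff (delta) rule for a single-layer linear network is w[j][i] += learningRate * e[j] * x[i], which moves each output toward its desired value when e = desired - actual. The sketch assumes weights stored in a plain matrix, ignores bias terms, and keeps the current pattern's input in a hypothetical lastInput field because the method only receives the error vector; it is not Neuroph's LMS code.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Vector;

// Hypothetical sketch of updateNetworkWeights using the Widrow-Hoff (delta) rule
// for a single-layer linear network without bias.
public class DeltaRuleSketch {

    final double[][] weights;   // weights[j][i]: connection from input i to output j
    final double learningRate;
    Vector<Double> lastInput;   // input of the pattern currently being learned (assumed field)

    DeltaRuleSketch(int inputs, int outputs, double learningRate) {
        this.weights = new double[outputs][inputs];
        this.learningRate = learningRate;
    }

    void updateNetworkWeights(Vector<Double> patternError) {
        for (int j = 0; j < weights.length; j++) {
            for (int i = 0; i < weights[j].length; i++) {
                // Step proportional to this output's error and this input's value.
                weights[j][i] += learningRate * patternError.get(j) * lastInput.get(i);
            }
        }
    }

    public static void main(String[] args) {
        DeltaRuleSketch rule = new DeltaRuleSketch(2, 1, 0.5);
        rule.lastInput = new Vector<>(List.of(1.0, 2.0));
        rule.updateNetworkWeights(new Vector<>(List.of(1.0))); // error = desired - actual = 1.0
        System.out.println(Arrays.toString(rule.weights[0]));  // prints [0.5, 1.0]
    }
}
```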