java.lang.Object
  extended by java.util.Observable
      extended by org.neuroph.core.learning.LearningRule
          extended by org.neuroph.core.learning.IterativeLearning
              extended by org.neuroph.core.learning.SupervisedLearning
public abstract class SupervisedLearning
extends IterativeLearning
Base class for all supervised learning algorithms. It extends IterativeLearning and provides the general principles of supervised learning.
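The division of work described above (the base class drives epochs and patterns, subclasses supply the error and weight updates) can be sketched as a small template-method hierarchy. This is a self-contained illustration, not Neuroph's source: `TrainingPair`, `SupervisedLearningSketch`, `DeltaRuleSketch`, `errorBelowMax` and the `calculate` hook are hypothetical, and the concrete subclass assumes a single linear neuron trained with the delta (LMS) rule and a half-sum-of-squares error.

```java
import java.util.List;
import java.util.Vector;

// Hypothetical stand-in for a supervised training element: an input plus its desired output.
class TrainingPair {
    final Vector<Double> input;
    final Vector<Double> desiredOutput;

    TrainingPair(Vector<Double> input, Vector<Double> desiredOutput) {
        this.input = input;
        this.desiredOutput = desiredOutput;
    }
}

// Sketch of the template: the base class runs epochs and patterns, subclasses supply the hooks.
abstract class SupervisedLearningSketch {
    protected double maxError = 0.01;
    protected double totalNetworkError;

    // Hypothetical hook standing in for evaluating the neural network on one input.
    protected abstract Vector<Double> calculate(Vector<Double> input);

    protected abstract void updateTotalNetworkError(Vector<Double> patternError);

    protected abstract void updateNetworkWeights(Vector<Double> patternError);

    // One epoch: reset the accumulated error, then learn every pattern once.
    public void doLearningEpoch(List<TrainingPair> trainingSet) {
        totalNetworkError = 0.0;
        for (TrainingPair pair : trainingSet) {
            learnPattern(pair);
        }
    }

    // Stop condition a caller checks after each epoch (assumed to be a strict comparison).
    public boolean errorBelowMax() {
        return totalNetworkError < maxError;
    }

    protected void learnPattern(TrainingPair pair) {
        Vector<Double> output = calculate(pair.input);
        Vector<Double> patternError = new Vector<>();
        for (int i = 0; i < output.size(); i++) {
            patternError.add(pair.desiredOutput.get(i) - output.get(i)); // desired - actual
        }
        updateTotalNetworkError(patternError);
        updateNetworkWeights(patternError);
    }
}

// Hypothetical concrete rule: one linear neuron y = weight * x trained with the delta (LMS) rule.
class DeltaRuleSketch extends SupervisedLearningSketch {
    double weight = 0.0;
    double learningRate = 0.1;
    private double lastInput; // input of the pattern currently being learned

    @Override
    protected Vector<Double> calculate(Vector<Double> input) {
        lastInput = input.get(0);
        Vector<Double> output = new Vector<>();
        output.add(weight * lastInput);
        return output;
    }

    @Override
    protected void updateTotalNetworkError(Vector<Double> patternError) {
        for (double e : patternError) {
            totalNetworkError += 0.5 * e * e; // half the sum of squared errors
        }
    }

    @Override
    protected void updateNetworkWeights(Vector<Double> patternError) {
        weight += learningRate * patternError.get(0) * lastInput; // w += rate * error * input
    }
}
```

A caller would repeat `doLearningEpoch` until `errorBelowMax()` returns true or an iteration limit is reached; on inputs 1, 2, 3 with targets 2, 4, 6 the weight in this sketch converges towards 2 within a few epochs.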
Field Summary | | |
---|---|---|
protected double | maxError | Max allowed network error (condition to stop learning) |
protected double | totalNetworkError | Total network error |
Fields inherited from class org.neuroph.core.learning.IterativeLearning |
---|
currentIteration, iterationsLimited, learningRate, maxIterations |
Fields inherited from class org.neuroph.core.learning.LearningRule |
---|
neuralNetwork |
Constructor Summary | |
---|---|
SupervisedLearning() | Creates a new supervised learning rule. |
SupervisedLearning(NeuralNetwork network) | Creates a new supervised learning rule and sets the neural network to train. |
Method Summary | | |
---|---|---|
void | doLearningEpoch(TrainingSet trainingSet) | Implements the basic logic of one learning epoch for supervised learning algorithms. |
protected java.util.Vector<java.lang.Double> | getPatternError(java.util.Vector<java.lang.Double> output, java.util.Vector<java.lang.Double> desiredOutput) | Calculates the network error for the current pattern: the difference between the desired and the actual output. |
java.lang.Double | getTotalNetworkError() | Returns the total network error. |
void | learn(TrainingSet trainingSet, double maxError) | Trains the network on the specified training set, using maxError as the stopping condition. |
void | learn(TrainingSet trainingSet, double maxError, int maxIterations) | Trains the network on the specified training set, using maxError as the stopping condition and running at most maxIterations iterations. |
protected void | learnPattern(SupervisedTrainingElement trainingElement) | Trains the network with the pattern from the specified training element. |
void | setMaxError(java.lang.Double maxError) | Sets the maximum allowed network error, which determines when to stop training. |
protected abstract void | updateNetworkWeights(java.util.Vector<java.lang.Double> patternError) | Subclasses implement the weight update procedure in this method. |
protected abstract void | updateTotalNetworkError(java.util.Vector<java.lang.Double> patternError) | Subclasses update the total network error for each training pattern with this method. |
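The two `learn` overloads differ only in their stopping conditions. The sketch below shows one way the three-argument form can combine them, assuming training ends as soon as an epoch's error falls below maxError or maxIterations epochs have run, whichever comes first. `StopConditionSketch` and the `runEpoch` callback are hypothetical, and whether Neuroph compares with `<` or `<=` is not stated on this page.

```java
import java.util.function.DoubleSupplier;

// Sketch of the stopping logic for learn(trainingSet, maxError, maxIterations):
// run epochs until the epoch error drops below maxError or the iteration limit is hit.
final class StopConditionSketch {
    private StopConditionSketch() {}

    /**
     * @param runEpoch      hypothetical callback that performs one epoch and returns its total error
     * @param maxError      maximum allowed network error
     * @param maxIterations maximum number of epochs to run
     * @return the number of epochs actually run
     */
    static int learn(DoubleSupplier runEpoch, double maxError, int maxIterations) {
        int iteration = 0;
        while (iteration < maxIterations) {
            double totalNetworkError = runEpoch.getAsDouble();
            iteration++;
            if (totalNetworkError < maxError) {
                break; // error goal reached before the iteration limit
            }
        }
        return iteration;
    }
}
```

With an epoch error that never improves, the sketch runs exactly maxIterations epochs; with an error that quarters every epoch (0.25, 0.0625, 0.015625, 0.00390625) and maxError = 0.01, it stops after the fourth epoch.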
Methods inherited from class org.neuroph.core.learning.IterativeLearning |
---|
doOneLearningIteration, getCurrentIteration, getLearningRate, isPausedLearning, learn, learn, pause, resume, setLearningRate, setMaxIterations |
Methods inherited from class org.neuroph.core.learning.LearningRule |
---|
getNeuralNetwork, getTrainingSet, isStopped, notifyChange, run, setNeuralNetwork, setTrainingSet, stopLearning |
Methods inherited from class java.util.Observable |
---|
addObserver, clearChanged, countObservers, deleteObserver, deleteObservers, hasChanged, notifyObservers, notifyObservers, setChanged |
Methods inherited from class java.lang.Object |
---|
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait |
Field Detail |
---|
protected double totalNetworkError
protected double maxError
Constructor Detail |
---|
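public SupervisedLearning()
Creates a new supervised learning rule.
public SupervisedLearning(NeuralNetwork network)
Creates a new supervised learning rule and sets the neural network to train.
Parameters:
network - network to train
Method Detail |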
---|
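public void learn(TrainingSet trainingSet, double maxError)
Trains the network on the specified training set, using maxError as the stopping condition.
Parameters:
trainingSet - training set to learn
maxError - maximum allowed network error (condition to stop learning)
public void learn(TrainingSet trainingSet, double maxError, int maxIterations)
Trains the network on the specified training set, using maxError as the stopping condition and running at most maxIterations iterations.
Parameters:
trainingSet - training set to learn
maxError - maximum allowed network error (condition to stop learning)
maxIterations - maximum number of iterations to learn
public void doLearningEpoch(TrainingSet trainingSet)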
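Implements the basic logic of one learning epoch for supervised learning algorithms.
Specified by:
doLearningEpoch in class IterativeLearning
Parameters:
trainingSet - training set for training the network
protected void learnPattern(SupervisedTrainingElement trainingElement)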
Trains the network with the pattern from the specified training element.
Parameters:
trainingElement - supervised training element that contains the input and the desired output
protected java.util.Vector<java.lang.Double> getPatternError(java.util.Vector<java.lang.Double> output, java.util.Vector<java.lang.Double> desiredOutput)
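Calculates the network error for the current pattern: the difference between the desired and the actual output.
Parameters:
output - actual network output
desiredOutput - desired network output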
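The pattern error is defined here as the difference between the desired and the actual output. A minimal sketch of that computation on the documented `Vector<Double>` types follows; `PatternErrorSketch` is a hypothetical name, and the length check is an added assumption, since this page does not say how mismatched vectors are handled.

```java
import java.util.Vector;

// Sketch of the documented pattern error: desired output minus actual output, element by element.
final class PatternErrorSketch {
    private PatternErrorSketch() {}

    static Vector<Double> getPatternError(Vector<Double> output, Vector<Double> desiredOutput) {
        // Assumption: reject vectors of different lengths instead of failing mid-loop.
        if (output.size() != desiredOutput.size()) {
            throw new IllegalArgumentException("output and desiredOutput must have the same length");
        }
        Vector<Double> patternError = new Vector<>(output.size());
        for (int i = 0; i < output.size(); i++) {
            patternError.add(desiredOutput.get(i) - output.get(i));
        }
        return patternError;
    }
}
```

For an actual output of (0.25, 0.5) and a desired output of (1.0, 0.0), the sketch returns (0.75, -0.5).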
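public void setMaxError(java.lang.Double maxError)
Sets the maximum allowed network error, which determines when to stop training.
Parameters:
maxError - maximum allowed network error
public java.lang.Double getTotalNetworkError()
Returns the total network error.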
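Because SupervisedLearning ultimately extends java.util.Observable, and LearningRule lists notifyChange among its methods, a client can watch the total network error through the standard Observer interface. This page does not say when notifications are sent, so the sketch below simulates one notification per epoch; `ObservableTrainer`, `finishEpoch` and `ErrorLogger` are hypothetical names. Note that java.util.Observable has been deprecated since Java 9, although it still compiles.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Observable;
import java.util.Observer;

// Hypothetical trainer that publishes its error after each epoch
// (assumption: Neuroph's own notification points may differ).
class ObservableTrainer extends Observable {
    private double totalNetworkError;

    public Double getTotalNetworkError() {
        return totalNetworkError;
    }

    // Simulated end of an epoch: a real rule would compute this value in doLearningEpoch.
    void finishEpoch(double epochError) {
        totalNetworkError = epochError;
        setChanged();      // without this, notifyObservers() does nothing
        notifyObservers(); // calls update(this, null) on every registered observer
    }
}

// Records the total network error reported after every epoch.
class ErrorLogger implements Observer {
    final List<Double> history = new ArrayList<>();

    @Override
    public void update(Observable source, Object arg) {
        history.add(((ObservableTrainer) source).getTotalNetworkError());
    }
}
```

Registering the logger with `addObserver` before training gives a per-epoch error history that can be plotted or used to detect stalled learning.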
protected abstract void updateTotalNetworkError(java.util.Vector<java.lang.Double> patternError)
Subclasses update the total network error for each training pattern with this method.
Parameters:
patternError - pattern error vector
protected abstract void updateNetworkWeights(java.util.Vector<java.lang.Double> patternError)
Subclasses implement the weight update procedure in this method.
Parameters:
patternError - pattern error vector