java.lang.Object
  org.joone.engine.RTRL

public class RTRL
An RTRL implementation, based mostly on http://www.willamette.edu/~gorr/classes/cs449 and a few other sources.

This is a partial RTRL implementation. Network weights are optimised using an offline RTRL algorithm. The initial states of context nodes are not optimised, though support for that could easily be added; for now, initial states are simply taken to be whatever they are set to in the context layer itself. RTRL does not rely on a backpropagated error, so backpropagation can and should be turned off in order to speed things up. Functionality for this is included in the Monitor class and is turned off whenever the setMonitor message is called.

To speed things up further, this class includes an experimental lineseek approach: first the gradient is calculated using the offline RTRL algorithm, then steps are taken along the gradient for as long as the sum of squared errors decreases, in typical lineseek fashion. As soon as a step results in an increased sum of squared errors, a step back is taken, typically smaller than the step forward, and the gradient is updated again. The stepping up and down is scaled in the spirit of the RPROP algorithm, so that the learning rate is adjusted after each cycle. Weights can also be randomised at the end of each cycle, in the spirit of simulated annealing; as with the lineseek approach, see the constructor for more details. These two features were added to try to speed up convergence - if at all! Their practical benefit remains highly suspect at best.

This class has a main method which also serves as a demo of RTRL; please refer to it. A suitable net can easily be created using the GUI and then trained using the main method, with a few alterations to the code - for example the number of patterns, which, amongst others, is currently hard-coded. The main method also shows how to save and restore a network trained via RTRL.
While this class does implement the Serializable interface, serialisation of it is highly suspect and it is not meant to be serialised together with the network. This implementation is highly academic at present; any good examples where it can be applied will be much appreciated, as I am still looking for them. The initial conditions as well as the learning rate seem to have such a high impact on convergence as to make this of almost no practical use, it seems. Also, strangely, a higher rather than lower learning rate often seems better for convergence. Support for multiprocessors has now been added.
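The offline update described above - accumulate sensitivities and deltas per pattern, then apply the deltas once per cycle - can be sketched with a tiny self-contained toy network. This is not the Joone code: the network size, tanh activation, index layout, and class name `RtrlSketch` are all illustrative assumptions; only the shape of the `update`/`updateCycle` protocol mirrors this class.

```java
// A minimal sketch of the RTRL recurrence on a toy fully recurrent network
// with N units and M external inputs. NOT the Joone implementation; sizes,
// activation, and indexing are illustrative assumptions.
public class RtrlSketch {
    static final int N = 2;                   // recurrent units
    static final int M = 1;                   // external inputs
    static final int Z = N + M;               // combined input vector size

    double[][] w = new double[N][Z];          // weights w[k][l]
    double[] y = new double[N];               // unit outputs
    double[][][] p = new double[N][N][Z];     // sensitivities p[k][i][j]
    double[][] delta = new double[N][Z];      // accumulated weight deltas

    double[] z(double[] x) {                  // concatenate inputs and state
        double[] z = new double[Z];
        System.arraycopy(x, 0, z, 0, M);
        System.arraycopy(y, 0, z, M, N);
        return z;
    }

    /** One forward step plus the RTRL sensitivity (p matrix) update. */
    void step(double[] x) {
        double[] zt = z(x);
        double[] s = new double[N];
        for (int k = 0; k < N; k++)
            for (int l = 0; l < Z; l++) s[k] += w[k][l] * zt[l];

        double[][][] pNext = new double[N][N][Z];
        for (int k = 0; k < N; k++) {
            double fy = Math.tanh(s[k]);
            double fprime = 1 - fy * fy;      // d tanh / ds
            for (int i = 0; i < N; i++)
                for (int j = 0; j < Z; j++) {
                    double sum = (k == i) ? zt[j] : 0;
                    for (int l = 0; l < N; l++)
                        sum += w[k][M + l] * p[l][i][j];
                    pNext[k][i][j] = fprime * sum;
                }
            y[k] = fy;
        }
        p = pNext;
    }

    /** Accumulate deltas from the current error pattern (target - output). */
    void update(double[] error) {
        for (int i = 0; i < N; i++)
            for (int j = 0; j < Z; j++)
                for (int k = 0; k < N; k++)
                    delta[i][j] += error[k] * p[k][i][j];
    }

    /** Apply the accumulated deltas once per cycle (offline), then reset. */
    void updateCycle(double learningRate) {
        for (int i = 0; i < N; i++)
            for (int j = 0; j < Z; j++) {
                w[i][j] += learningRate * delta[i][j];
                delta[i][j] = 0;
            }
        p = new double[N][N][Z];              // reset sensitivities
    }
}
```

In this class the forward pass is handled by the Joone network itself; the sketch folds it into `step` only so the recurrence is visible in one place.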
Field Summary

protected NodesAndWeights nodesAndWeights
    The network we are training.

protected java.util.List<java.util.List<NodesAndWeights.Node>> nodesList
    List of lists of nodes that will be updated by each processor.

protected double[][] p
    The p matrix; p[k][ij] is node k's (in U) derivative with respect to weight ij.

protected int patternCount
    Pattern counter.

protected int processorCount
    Number of processors to use; 1 or less on a uniprocessor.

protected double[][] updateP
    The utility updateP matrix, used when updating the p matrix.

protected java.util.List<java.util.List<NodesAndWeights.Weight>> weightsList
    List of lists of weights that will be updated by each processor.
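The p matrix documented above is the sensitivity matrix of the standard RTRL recurrence. In textbook notation (the exact index layout in this class may differ), with f the activation function, s_k(t) the net input of unit k, z(t) the concatenation of external inputs and unit outputs, and U the set of recurrent units:

```latex
p^{k}_{ij}(t+1) \;=\; f'\!\big(s_k(t)\big)\Big[\sum_{l \in U} w_{kl}\, p^{l}_{ij}(t) \;+\; \delta_{ik}\, z_j(t)\Big],
\qquad p^{k}_{ij}(0) = 0,
```

so that the per-pattern weight delta accumulated by the offline update is \(\Delta w_{ij} = \eta \sum_{k \in U} e_k(t)\, p^{k}_{ij}(t)\), where e(t) is the error pattern passed to update.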
Constructor Summary

RTRL(NodesAndWeights nodesAndWeights)
    Create a new instance of RTRL.
Method Summary

int getProcessorCount()
    Retrieve the processor count.

protected void init()
    Initialise.

void printP(java.io.PrintStream out)
    Helper to print out the p matrix.

protected void resetP()
    Reset the p matrix in preparation for the next cycle; called at the end of a cycle.

void setProcessorCount(int processorCount)
    Set the number of processors to use.

void update(double[] error)
    Update RTRL. Call this with the most recent error pattern as soon as one becomes available.

void updateCycle(double learningRate)
    Update the weights. Call this once a full set of patterns has been presented to the network.

protected void updateDeltas(double[] error)
    Update the weights' deltas.

protected void updateDeltas(double[] error, java.util.List<NodesAndWeights.Weight> weights)
    Update the given weights' deltas.

protected void updateP()
    Update the p matrix; called after a pattern has been presented to the network.

protected void updateP(java.util.List<NodesAndWeights.Node> nodes)
    Update the p matrix for the given list of nodes.
Methods inherited from class java.lang.Object

clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Field Detail

protected NodesAndWeights nodesAndWeights

protected double[][] p

protected double[][] updateP

protected int patternCount

protected int processorCount

protected java.util.List<java.util.List<NodesAndWeights.Node>> nodesList

protected java.util.List<java.util.List<NodesAndWeights.Weight>> weightsList
Constructor Detail

public RTRL(NodesAndWeights nodesAndWeights)

Parameters:
    nodesAndWeights - the structure of the network to be optimised

Method Detail
protected void init()

public void setProcessorCount(int processorCount)

public int getProcessorCount()

protected void updateP(java.util.List<NodesAndWeights.Node> nodes)

protected void updateP()

protected void updateDeltas(double[] error, java.util.List<NodesAndWeights.Weight> weights)

Parameters:
    error - the most recently seen error pattern

protected void updateDeltas(double[] error)

Parameters:
    error - the most recently seen error pattern

protected void resetP()

public void update(double[] error)

Parameters:
    error - the most recently seen error pattern

public void updateCycle(double learningRate)

public void printP(java.io.PrintStream out)

Parameters:
    out - the stream to which to dump the matrix