001    //Licensed to the Apache Software Foundation (ASF) under one
002    //or more contributor license agreements.  See the NOTICE file
003    //distributed with this work for additional information
004    //regarding copyright ownership.  The ASF licenses this file
005    //to you under the Apache License, Version 2.0 (the
006    //"License"); you may not use this file except in compliance
007    //with the License.  You may obtain a copy of the License at
008    
009    //http://www.apache.org/licenses/LICENSE-2.0
010    
011    //Unless required by applicable law or agreed to in writing,
012    //software distributed under the License is distributed on an
013    //"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
014    //KIND, either express or implied.  See the License for the
015    //specific language governing permissions and limitations
016    //under the License.
017    
018    package org.apache.commons.math.random;
019    
020    import junit.framework.Test;
021    import junit.framework.TestCase;
022    import junit.framework.TestSuite;
023    
024    import org.apache.commons.math.DimensionMismatchException;
025    import org.apache.commons.math.linear.MatrixUtils;
026    import org.apache.commons.math.linear.NotPositiveDefiniteMatrixException;
027    import org.apache.commons.math.linear.RealMatrix;
028    import org.apache.commons.math.stat.descriptive.moment.VectorialCovariance;
029    import org.apache.commons.math.stat.descriptive.moment.VectorialMean;
030    
031    public class CorrelatedRandomVectorGeneratorTest
032    extends TestCase {
033    
034        public CorrelatedRandomVectorGeneratorTest(String name) {
035            super(name);
036            mean       = null;
037            covariance = null;
038            generator  = null;
039        }
040    
041        public void testRank() {
042            assertEquals(3, generator.getRank());
043        }
044    
045        public void testMath226()
046            throws DimensionMismatchException, NotPositiveDefiniteMatrixException {
047            double[] mean = { 1, 1, 10, 1 };
048            double[][] cov = {
049                    { 1, 3, 2, 6 },
050                    { 3, 13, 16, 2 },
051                    { 2, 16, 38, -1 },
052                    { 6, 2, -1, 197 }
053            };
054            RealMatrix covRM = MatrixUtils.createRealMatrix(cov);
055            JDKRandomGenerator jg = new JDKRandomGenerator();
056            jg.setSeed(5322145245211l);
057            NormalizedRandomGenerator rg = new GaussianRandomGenerator(jg);
058            CorrelatedRandomVectorGenerator sg =
059                new CorrelatedRandomVectorGenerator(mean, covRM, 0.00001, rg);
060    
061            for (int i = 0; i < 10; i++) {
062                double[] generated = sg.nextVector();
063                assertTrue(Math.abs(generated[0] - 1) > 0.1);
064            }
065    
066        }
067    
068        public void testRootMatrix() {
069            RealMatrix b = generator.getRootMatrix();
070            RealMatrix bbt = b.multiply(b.transpose());
071            for (int i = 0; i < covariance.getRowDimension(); ++i) {
072                for (int j = 0; j < covariance.getColumnDimension(); ++j) {
073                    assertEquals(covariance.getEntry(i, j), bbt.getEntry(i, j), 1.0e-12);
074                }
075            }
076        }
077    
078        public void testMeanAndCovariance() throws DimensionMismatchException {
079    
080            VectorialMean meanStat = new VectorialMean(mean.length);
081            VectorialCovariance covStat = new VectorialCovariance(mean.length, true);
082            for (int i = 0; i < 5000; ++i) {
083                double[] v = generator.nextVector();
084                meanStat.increment(v);
085                covStat.increment(v);
086            }
087    
088            double[] estimatedMean = meanStat.getResult();
089            RealMatrix estimatedCovariance = covStat.getResult();
090            for (int i = 0; i < estimatedMean.length; ++i) {
091                assertEquals(mean[i], estimatedMean[i], 0.07);
092                for (int j = 0; j <= i; ++j) {
093                    assertEquals(covariance.getEntry(i, j),
094                            estimatedCovariance.getEntry(i, j),
095                            0.1 * (1.0 + Math.abs(mean[i])) * (1.0 + Math.abs(mean[j])));
096                }
097            }
098    
099        }
100    
101        @Override
102        public void setUp() {
103            try {
104                mean = new double[] { 0.0, 1.0, -3.0, 2.3};
105    
106                RealMatrix b = MatrixUtils.createRealMatrix(4, 3);
107                int counter = 0;
108                for (int i = 0; i < b.getRowDimension(); ++i) {
109                    for (int j = 0; j < b.getColumnDimension(); ++j) {
110                        b.setEntry(i, j, 1.0 + 0.1 * ++counter);
111                    }
112                }
113                RealMatrix bbt = b.multiply(b.transpose());
114                covariance = MatrixUtils.createRealMatrix(mean.length, mean.length);
115                for (int i = 0; i < covariance.getRowDimension(); ++i) {
116                    covariance.setEntry(i, i, bbt.getEntry(i, i));
117                    for (int j = 0; j < covariance.getColumnDimension(); ++j) {
118                        double s = bbt.getEntry(i, j);
119                        covariance.setEntry(i, j, s);
120                        covariance.setEntry(j, i, s);
121                    }
122                }
123    
124                RandomGenerator rg = new JDKRandomGenerator();
125                rg.setSeed(17399225432l);
126                GaussianRandomGenerator rawGenerator = new GaussianRandomGenerator(rg);
127                generator = new CorrelatedRandomVectorGenerator(mean,
128                                                                covariance,
129                                                                1.0e-12 * covariance.getNorm(),
130                                                                rawGenerator);
131            } catch (DimensionMismatchException e) {
132                fail(e.getMessage());
133            } catch (NotPositiveDefiniteMatrixException e) {
134                fail("not positive definite matrix");
135            }
136        }
137    
138        @Override
139        public void tearDown() {
140            mean       = null;
141            covariance = null;
142            generator  = null;
143        }
144    
145        public static Test suite() {
146            return new TestSuite(CorrelatedRandomVectorGeneratorTest.class);
147        }
148    
149        private double[] mean;
150        private RealMatrix covariance;
151        private CorrelatedRandomVectorGenerator generator;
152    
153    }