001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    package org.apache.commons.math.stat.regression;
018    
019    import static org.junit.Assert.assertEquals;
020    
021    import org.apache.commons.math.TestUtils;
022    import org.apache.commons.math.linear.DefaultRealMatrixChangingVisitor;
023    import org.apache.commons.math.linear.MatrixUtils;
024    import org.apache.commons.math.linear.MatrixVisitorException;
025    import org.apache.commons.math.linear.RealMatrix;
026    import org.apache.commons.math.linear.Array2DRowRealMatrix;
027    import org.junit.Before;
028    import org.junit.Test;
029    
030    public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbstractTest {
031    
032        private double[] y;
033        private double[][] x;
034        
035        @Before
036        @Override
037        public void setUp(){
038            y = new double[]{11.0, 12.0, 13.0, 14.0, 15.0, 16.0};
039            x = new double[6][];
040            x[0] = new double[]{1.0, 0, 0, 0, 0, 0};
041            x[1] = new double[]{1.0, 2.0, 0, 0, 0, 0};
042            x[2] = new double[]{1.0, 0, 3.0, 0, 0, 0};
043            x[3] = new double[]{1.0, 0, 0, 4.0, 0, 0};
044            x[4] = new double[]{1.0, 0, 0, 0, 5.0, 0};
045            x[5] = new double[]{1.0, 0, 0, 0, 0, 6.0};
046            super.setUp();
047        }
048    
049        @Override
050        protected OLSMultipleLinearRegression createRegression() {
051            OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
052            regression.newSampleData(y, x);
053            return regression;
054        }
055    
056        @Override
057        protected int getNumberOfRegressors() {
058            return x[0].length;
059        }
060    
061        @Override
062        protected int getSampleSize() {
063            return y.length;
064        }
065        
066        @Test(expected=IllegalArgumentException.class)
067        public void cannotAddXSampleData() {
068            createRegression().newSampleData(new double[]{}, null);
069        }
070    
071        @Test(expected=IllegalArgumentException.class)
072        public void cannotAddNullYSampleData() {
073            createRegression().newSampleData(null, new double[][]{});
074        }
075        
076        @Test(expected=IllegalArgumentException.class)
077        public void cannotAddSampleDataWithSizeMismatch() {
078            double[] y = new double[]{1.0, 2.0};
079            double[][] x = new double[1][];
080            x[0] = new double[]{1.0, 0};
081            createRegression().newSampleData(y, x);
082        }
083        
084        @Test
085        public void testPerfectFit() {
086            double[] betaHat = regression.estimateRegressionParameters();
087            TestUtils.assertEquals(betaHat, 
088                                   new double[]{ 11.0, 1.0 / 2.0, 2.0 / 3.0, 3.0 / 4.0, 4.0 / 5.0, 5.0 / 6.0 },
089                                   1e-14);
090            double[] residuals = regression.estimateResiduals();
091            TestUtils.assertEquals(residuals, new double[]{0d,0d,0d,0d,0d,0d},
092                                   1e-14);
093            RealMatrix errors =
094                new Array2DRowRealMatrix(regression.estimateRegressionParametersVariance(), false);
095            final double[] s = { 1.0, -1.0 /  2.0, -1.0 /  3.0, -1.0 /  4.0, -1.0 /  5.0, -1.0 /  6.0 };
096            RealMatrix referenceVariance = new Array2DRowRealMatrix(s.length, s.length);
097            referenceVariance.walkInOptimizedOrder(new DefaultRealMatrixChangingVisitor() {
098                @Override
099                public double visit(int row, int column, double value)
100                    throws MatrixVisitorException {
101                    if (row == 0) {
102                        return s[column];
103                    }
104                    double x = s[row] * s[column];
105                    return (row == column) ? 2 * x : x;
106                }
107            });
108           assertEquals(0.0,
109                         errors.subtract(referenceVariance).getNorm(),
110                         5.0e-16 * referenceVariance.getNorm());
111        }
112        
113        
114        /**
115         * Test Longley dataset against certified values provided by NIST.
116         * Data Source: J. Longley (1967) "An Appraisal of Least Squares
117         * Programs for the Electronic Computer from the Point of View of the User"
118         * Journal of the American Statistical Association, vol. 62. September,
119         * pp. 819-841.
120         * 
121         * Certified values (and data) are from NIST:
122         * http://www.itl.nist.gov/div898/strd/lls/data/LINKS/DATA/Longley.dat
123         */
124        @Test
125        public void testLongly() {
126            // Y values are first, then independent vars
127            // Each row is one observation
128            double[] design = new double[] {
129                60323,83.0,234289,2356,1590,107608,1947,
130                61122,88.5,259426,2325,1456,108632,1948,
131                60171,88.2,258054,3682,1616,109773,1949,
132                61187,89.5,284599,3351,1650,110929,1950,
133                63221,96.2,328975,2099,3099,112075,1951,
134                63639,98.1,346999,1932,3594,113270,1952,
135                64989,99.0,365385,1870,3547,115094,1953,
136                63761,100.0,363112,3578,3350,116219,1954,
137                66019,101.2,397469,2904,3048,117388,1955,
138                67857,104.6,419180,2822,2857,118734,1956,
139                68169,108.4,442769,2936,2798,120445,1957,
140                66513,110.8,444546,4681,2637,121950,1958,
141                68655,112.6,482704,3813,2552,123366,1959,
142                69564,114.2,502601,3931,2514,125368,1960,
143                69331,115.7,518173,4806,2572,127852,1961,
144                70551,116.9,554894,4007,2827,130081,1962
145            };
146            
147            // Transform to Y and X required by interface
148            int nobs = 16;
149            int nvars = 6;
150            
151            // Estimate the model
152            OLSMultipleLinearRegression model = new OLSMultipleLinearRegression();
153            model.newSampleData(design, nobs, nvars);
154            
155            // Check expected beta values from NIST
156            double[] betaHat = model.estimateRegressionParameters();
157            TestUtils.assertEquals(betaHat, 
158              new double[]{-3482258.63459582, 15.0618722713733,
159                    -0.358191792925910E-01,-2.02022980381683,
160                    -1.03322686717359,-0.511041056535807E-01,
161                     1829.15146461355}, 2E-8); // 
162            
163            // Check expected residuals from R
164            double[] residuals = model.estimateResiduals();
165            TestUtils.assertEquals(residuals, new double[]{
166                    267.340029759711,-94.0139423988359,46.28716775752924,
167                    -410.114621930906,309.7145907602313,-249.3112153297231,
168                    -164.0489563956039,-13.18035686637081,14.30477260005235,
169                     455.394094551857,-17.26892711483297,-39.0550425226967,
170                    -155.5499735953195,-85.6713080421283,341.9315139607727,
171                    -206.7578251937366},
172                          1E-8);
173            
174            // Check standard errors from NIST
175            double[] errors = model.estimateRegressionParametersStandardErrors();
176            TestUtils.assertEquals(new double[] {890420.383607373,
177                           84.9149257747669,
178                           0.334910077722432E-01,
179                           0.488399681651699,
180                           0.214274163161675,
181                           0.226073200069370,
182                           455.478499142212}, errors, 1E-6); 
183        }
184        
185        /**
186         * Test R Swiss fertility dataset against R.
187         * Data Source: R datasets package
188         */
189        @Test
190        public void testSwissFertility() {
191            double[] design = new double[] {
192                80.2,17.0,15,12,9.96,
193                83.1,45.1,6,9,84.84,
194                92.5,39.7,5,5,93.40,
195                85.8,36.5,12,7,33.77,
196                76.9,43.5,17,15,5.16,
197                76.1,35.3,9,7,90.57,
198                83.8,70.2,16,7,92.85,
199                92.4,67.8,14,8,97.16,
200                82.4,53.3,12,7,97.67,
201                82.9,45.2,16,13,91.38,
202                87.1,64.5,14,6,98.61,
203                64.1,62.0,21,12,8.52,
204                66.9,67.5,14,7,2.27,
205                68.9,60.7,19,12,4.43,
206                61.7,69.3,22,5,2.82,
207                68.3,72.6,18,2,24.20,
208                71.7,34.0,17,8,3.30,
209                55.7,19.4,26,28,12.11,
210                54.3,15.2,31,20,2.15,
211                65.1,73.0,19,9,2.84,
212                65.5,59.8,22,10,5.23,
213                65.0,55.1,14,3,4.52,
214                56.6,50.9,22,12,15.14,
215                57.4,54.1,20,6,4.20,
216                72.5,71.2,12,1,2.40,
217                74.2,58.1,14,8,5.23,
218                72.0,63.5,6,3,2.56,
219                60.5,60.8,16,10,7.72,
220                58.3,26.8,25,19,18.46,
221                65.4,49.5,15,8,6.10,
222                75.5,85.9,3,2,99.71,
223                69.3,84.9,7,6,99.68,
224                77.3,89.7,5,2,100.00,
225                70.5,78.2,12,6,98.96,
226                79.4,64.9,7,3,98.22,
227                65.0,75.9,9,9,99.06,
228                92.2,84.6,3,3,99.46,
229                79.3,63.1,13,13,96.83,
230                70.4,38.4,26,12,5.62,
231                65.7,7.7,29,11,13.79,
232                72.7,16.7,22,13,11.22,
233                64.4,17.6,35,32,16.92,
234                77.6,37.6,15,7,4.97,
235                67.6,18.7,25,7,8.65,
236                35.0,1.2,37,53,42.34,
237                44.7,46.6,16,29,50.43,
238                42.8,27.7,22,29,58.33
239            };
240    
241            // Transform to Y and X required by interface
242            int nobs = 47;
243            int nvars = 4;
244    
245            // Estimate the model
246            OLSMultipleLinearRegression model = new OLSMultipleLinearRegression();
247            model.newSampleData(design, nobs, nvars);
248    
249            // Check expected beta values from R
250            double[] betaHat = model.estimateRegressionParameters();
251            TestUtils.assertEquals(betaHat, 
252                    new double[]{91.05542390271397,
253                    -0.22064551045715,
254                    -0.26058239824328,
255                    -0.96161238456030,
256                     0.12441843147162}, 1E-12);
257    
258            // Check expected residuals from R
259            double[] residuals = model.estimateResiduals();
260            TestUtils.assertEquals(residuals, new double[]{
261                    7.1044267859730512,1.6580347433531366,
262                    4.6944952770029644,8.4548022690166160,13.6547432343186212,
263                   -9.3586864458500774,7.5822446330520386,15.5568995563859289,
264                    0.8113090736598980,7.1186762732484308,7.4251378771228724,
265                    2.6761316873234109,0.8351584810309354,7.1769991119615177,
266                   -3.8746753206299553,-3.1337779476387251,-0.1412575244091504,
267                    1.1186809170469780,-6.3588097346816594,3.4039270429434074,
268                    2.3374058329820175,-7.9272368576900503,-7.8361010968497959,
269                   -11.2597369269357070,0.9445333697827101,6.6544245101380328,
270                   -0.9146136301118665,-4.3152449403848570,-4.3536932047009183,
271                   -3.8907885169304661,-6.3027643926302188,-7.8308982189289091,
272                   -3.1792280015332750,-6.7167298771158226,-4.8469946718041754,
273                   -10.6335664353633685,11.1031134362036958,6.0084032641811733,
274                    5.4326230830188482,-7.2375578629692230,2.1671550814448222,
275                    15.0147574652763112,4.8625103516321015,-7.1597256413907706,
276                    -0.4515205619767598,-10.2916870903837587,-15.7812984571900063},
277                    1E-12); 
278            
279            // Check standard errors from R
280            double[] errors = model.estimateRegressionParametersStandardErrors();
281            TestUtils.assertEquals(new double[] {6.94881329475087,
282                    0.07360008972340,
283                    0.27410957467466,
284                    0.19454551679325,
285                    0.03726654773803}, errors, 1E-10); 
286        }
287        
288        /**
289         * Test hat matrix computation
290         * 
291         * @throws Exception
292         */
293        @Test
294        public void testHat() throws Exception {
295            
296            /*
297             * This example is from "The Hat Matrix in Regression and ANOVA", 
298             * David C. Hoaglin and Roy E. Welsch, 
299             * The American Statistician, Vol. 32, No. 1 (Feb., 1978), pp. 17-22.
300             * 
301             */
302            double[] design = new double[] {
303                    11.14, .499, 11.1,
304                    12.74, .558, 8.9,
305                    13.13, .604, 8.8,
306                    11.51, .441, 8.9,
307                    12.38, .550, 8.8,
308                    12.60, .528, 9.9,
309                    11.13, .418, 10.7,
310                    11.7, .480, 10.5,
311                    11.02, .406, 10.5,
312                    11.41, .467, 10.7
313            };
314            
315            int nobs = 10;
316            int nvars = 2;
317            
318            // Estimate the model
319            OLSMultipleLinearRegression model = new OLSMultipleLinearRegression();
320            model.newSampleData(design, nobs, nvars);
321            
322            RealMatrix hat = model.calculateHat();
323            
324            // Reference data is upper half of symmetric hat matrix
325            double[] referenceData = new double[] {
326                    .418, -.002,  .079, -.274, -.046,  .181,  .128,  .222,  .050,  .242,
327                           .242,  .292,  .136,  .243,  .128, -.041,  .033, -.035,  .004,
328                                  .417, -.019,  .273,  .187, -.126,  .044, -.153,  .004,
329                                         .604,  .197, -.038,  .168, -.022,  .275, -.028,
330                                                .252,  .111, -.030,  .019, -.010, -.010,
331                                                       .148,  .042,  .117,  .012,  .111,
332                                                              .262,  .145,  .277,  .174,
333                                                                     .154,  .120,  .168,
334                                                                            .315,  .148,
335                                                                                   .187
336            };
337            
338            // Check against reference data and verify symmetry
339            int k = 0;
340            for (int i = 0; i < 10; i++) {
341                for (int j = i; j < 10; j++) {
342                    assertEquals(referenceData[k], hat.getEntry(i, j), 10e-3);
343                    assertEquals(hat.getEntry(i, j), hat.getEntry(j, i), 10e-12);
344                    k++;  
345                }
346            }
347            
348            /* 
349             * Verify that residuals computed using the hat matrix are close to 
350             * what we get from direct computation, i.e. r = (I - H) y
351             */
352            double[] residuals = model.estimateResiduals();
353            RealMatrix I = MatrixUtils.createRealIdentityMatrix(10);
354            double[] hatResiduals = I.subtract(hat).operate(model.Y).getData();
355            TestUtils.assertEquals(residuals, hatResiduals, 10e-12);    
356        }
357    }