1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.math.stat.regression;
18
19 import static org.junit.Assert.assertEquals;
20
21 import org.apache.commons.math.TestUtils;
22 import org.apache.commons.math.linear.DefaultRealMatrixChangingVisitor;
23 import org.apache.commons.math.linear.MatrixUtils;
24 import org.apache.commons.math.linear.MatrixVisitorException;
25 import org.apache.commons.math.linear.RealMatrix;
26 import org.apache.commons.math.linear.Array2DRowRealMatrix;
27 import org.junit.Before;
28 import org.junit.Test;
29
30 public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbstractTest {
31
32 private double[] y;
33 private double[][] x;
34
35 @Before
36 @Override
37 public void setUp(){
38 y = new double[]{11.0, 12.0, 13.0, 14.0, 15.0, 16.0};
39 x = new double[6][];
40 x[0] = new double[]{1.0, 0, 0, 0, 0, 0};
41 x[1] = new double[]{1.0, 2.0, 0, 0, 0, 0};
42 x[2] = new double[]{1.0, 0, 3.0, 0, 0, 0};
43 x[3] = new double[]{1.0, 0, 0, 4.0, 0, 0};
44 x[4] = new double[]{1.0, 0, 0, 0, 5.0, 0};
45 x[5] = new double[]{1.0, 0, 0, 0, 0, 6.0};
46 super.setUp();
47 }
48
49 @Override
50 protected OLSMultipleLinearRegression createRegression() {
51 OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
52 regression.newSampleData(y, x);
53 return regression;
54 }
55
56 @Override
57 protected int getNumberOfRegressors() {
58 return x[0].length;
59 }
60
61 @Override
62 protected int getSampleSize() {
63 return y.length;
64 }
65
66 @Test(expected=IllegalArgumentException.class)
67 public void cannotAddXSampleData() {
68 createRegression().newSampleData(new double[]{}, null);
69 }
70
71 @Test(expected=IllegalArgumentException.class)
72 public void cannotAddNullYSampleData() {
73 createRegression().newSampleData(null, new double[][]{});
74 }
75
76 @Test(expected=IllegalArgumentException.class)
77 public void cannotAddSampleDataWithSizeMismatch() {
78 double[] y = new double[]{1.0, 2.0};
79 double[][] x = new double[1][];
80 x[0] = new double[]{1.0, 0};
81 createRegression().newSampleData(y, x);
82 }
83
84 @Test
85 public void testPerfectFit() {
86 double[] betaHat = regression.estimateRegressionParameters();
87 TestUtils.assertEquals(betaHat,
88 new double[]{ 11.0, 1.0 / 2.0, 2.0 / 3.0, 3.0 / 4.0, 4.0 / 5.0, 5.0 / 6.0 },
89 1e-14);
90 double[] residuals = regression.estimateResiduals();
91 TestUtils.assertEquals(residuals, new double[]{0d,0d,0d,0d,0d,0d},
92 1e-14);
93 RealMatrix errors =
94 new Array2DRowRealMatrix(regression.estimateRegressionParametersVariance(), false);
95 final double[] s = { 1.0, -1.0 / 2.0, -1.0 / 3.0, -1.0 / 4.0, -1.0 / 5.0, -1.0 / 6.0 };
96 RealMatrix referenceVariance = new Array2DRowRealMatrix(s.length, s.length);
97 referenceVariance.walkInOptimizedOrder(new DefaultRealMatrixChangingVisitor() {
98 @Override
99 public double visit(int row, int column, double value)
100 throws MatrixVisitorException {
101 if (row == 0) {
102 return s[column];
103 }
104 double x = s[row] * s[column];
105 return (row == column) ? 2 * x : x;
106 }
107 });
108 assertEquals(0.0,
109 errors.subtract(referenceVariance).getNorm(),
110 5.0e-16 * referenceVariance.getNorm());
111 }
112
113
114
115
116
117
118
119
120
121
122
123
124 @Test
125 public void testLongly() {
126
127
128 double[] design = new double[] {
129 60323,83.0,234289,2356,1590,107608,1947,
130 61122,88.5,259426,2325,1456,108632,1948,
131 60171,88.2,258054,3682,1616,109773,1949,
132 61187,89.5,284599,3351,1650,110929,1950,
133 63221,96.2,328975,2099,3099,112075,1951,
134 63639,98.1,346999,1932,3594,113270,1952,
135 64989,99.0,365385,1870,3547,115094,1953,
136 63761,100.0,363112,3578,3350,116219,1954,
137 66019,101.2,397469,2904,3048,117388,1955,
138 67857,104.6,419180,2822,2857,118734,1956,
139 68169,108.4,442769,2936,2798,120445,1957,
140 66513,110.8,444546,4681,2637,121950,1958,
141 68655,112.6,482704,3813,2552,123366,1959,
142 69564,114.2,502601,3931,2514,125368,1960,
143 69331,115.7,518173,4806,2572,127852,1961,
144 70551,116.9,554894,4007,2827,130081,1962
145 };
146
147
148 int nobs = 16;
149 int nvars = 6;
150
151
152 OLSMultipleLinearRegression model = new OLSMultipleLinearRegression();
153 model.newSampleData(design, nobs, nvars);
154
155
156 double[] betaHat = model.estimateRegressionParameters();
157 TestUtils.assertEquals(betaHat,
158 new double[]{-3482258.63459582, 15.0618722713733,
159 -0.358191792925910E-01,-2.02022980381683,
160 -1.03322686717359,-0.511041056535807E-01,
161 1829.15146461355}, 2E-8);
162
163
164 double[] residuals = model.estimateResiduals();
165 TestUtils.assertEquals(residuals, new double[]{
166 267.340029759711,-94.0139423988359,46.28716775752924,
167 -410.114621930906,309.7145907602313,-249.3112153297231,
168 -164.0489563956039,-13.18035686637081,14.30477260005235,
169 455.394094551857,-17.26892711483297,-39.0550425226967,
170 -155.5499735953195,-85.6713080421283,341.9315139607727,
171 -206.7578251937366},
172 1E-8);
173
174
175 double[] errors = model.estimateRegressionParametersStandardErrors();
176 TestUtils.assertEquals(new double[] {890420.383607373,
177 84.9149257747669,
178 0.334910077722432E-01,
179 0.488399681651699,
180 0.214274163161675,
181 0.226073200069370,
182 455.478499142212}, errors, 1E-6);
183 }
184
185
186
187
188
189 @Test
190 public void testSwissFertility() {
191 double[] design = new double[] {
192 80.2,17.0,15,12,9.96,
193 83.1,45.1,6,9,84.84,
194 92.5,39.7,5,5,93.40,
195 85.8,36.5,12,7,33.77,
196 76.9,43.5,17,15,5.16,
197 76.1,35.3,9,7,90.57,
198 83.8,70.2,16,7,92.85,
199 92.4,67.8,14,8,97.16,
200 82.4,53.3,12,7,97.67,
201 82.9,45.2,16,13,91.38,
202 87.1,64.5,14,6,98.61,
203 64.1,62.0,21,12,8.52,
204 66.9,67.5,14,7,2.27,
205 68.9,60.7,19,12,4.43,
206 61.7,69.3,22,5,2.82,
207 68.3,72.6,18,2,24.20,
208 71.7,34.0,17,8,3.30,
209 55.7,19.4,26,28,12.11,
210 54.3,15.2,31,20,2.15,
211 65.1,73.0,19,9,2.84,
212 65.5,59.8,22,10,5.23,
213 65.0,55.1,14,3,4.52,
214 56.6,50.9,22,12,15.14,
215 57.4,54.1,20,6,4.20,
216 72.5,71.2,12,1,2.40,
217 74.2,58.1,14,8,5.23,
218 72.0,63.5,6,3,2.56,
219 60.5,60.8,16,10,7.72,
220 58.3,26.8,25,19,18.46,
221 65.4,49.5,15,8,6.10,
222 75.5,85.9,3,2,99.71,
223 69.3,84.9,7,6,99.68,
224 77.3,89.7,5,2,100.00,
225 70.5,78.2,12,6,98.96,
226 79.4,64.9,7,3,98.22,
227 65.0,75.9,9,9,99.06,
228 92.2,84.6,3,3,99.46,
229 79.3,63.1,13,13,96.83,
230 70.4,38.4,26,12,5.62,
231 65.7,7.7,29,11,13.79,
232 72.7,16.7,22,13,11.22,
233 64.4,17.6,35,32,16.92,
234 77.6,37.6,15,7,4.97,
235 67.6,18.7,25,7,8.65,
236 35.0,1.2,37,53,42.34,
237 44.7,46.6,16,29,50.43,
238 42.8,27.7,22,29,58.33
239 };
240
241
242 int nobs = 47;
243 int nvars = 4;
244
245
246 OLSMultipleLinearRegression model = new OLSMultipleLinearRegression();
247 model.newSampleData(design, nobs, nvars);
248
249
250 double[] betaHat = model.estimateRegressionParameters();
251 TestUtils.assertEquals(betaHat,
252 new double[]{91.05542390271397,
253 -0.22064551045715,
254 -0.26058239824328,
255 -0.96161238456030,
256 0.12441843147162}, 1E-12);
257
258
259 double[] residuals = model.estimateResiduals();
260 TestUtils.assertEquals(residuals, new double[]{
261 7.1044267859730512,1.6580347433531366,
262 4.6944952770029644,8.4548022690166160,13.6547432343186212,
263 -9.3586864458500774,7.5822446330520386,15.5568995563859289,
264 0.8113090736598980,7.1186762732484308,7.4251378771228724,
265 2.6761316873234109,0.8351584810309354,7.1769991119615177,
266 -3.8746753206299553,-3.1337779476387251,-0.1412575244091504,
267 1.1186809170469780,-6.3588097346816594,3.4039270429434074,
268 2.3374058329820175,-7.9272368576900503,-7.8361010968497959,
269 -11.2597369269357070,0.9445333697827101,6.6544245101380328,
270 -0.9146136301118665,-4.3152449403848570,-4.3536932047009183,
271 -3.8907885169304661,-6.3027643926302188,-7.8308982189289091,
272 -3.1792280015332750,-6.7167298771158226,-4.8469946718041754,
273 -10.6335664353633685,11.1031134362036958,6.0084032641811733,
274 5.4326230830188482,-7.2375578629692230,2.1671550814448222,
275 15.0147574652763112,4.8625103516321015,-7.1597256413907706,
276 -0.4515205619767598,-10.2916870903837587,-15.7812984571900063},
277 1E-12);
278
279
280 double[] errors = model.estimateRegressionParametersStandardErrors();
281 TestUtils.assertEquals(new double[] {6.94881329475087,
282 0.07360008972340,
283 0.27410957467466,
284 0.19454551679325,
285 0.03726654773803}, errors, 1E-10);
286 }
287
288
289
290
291
292
293 @Test
294 public void testHat() throws Exception {
295
296
297
298
299
300
301
302 double[] design = new double[] {
303 11.14, .499, 11.1,
304 12.74, .558, 8.9,
305 13.13, .604, 8.8,
306 11.51, .441, 8.9,
307 12.38, .550, 8.8,
308 12.60, .528, 9.9,
309 11.13, .418, 10.7,
310 11.7, .480, 10.5,
311 11.02, .406, 10.5,
312 11.41, .467, 10.7
313 };
314
315 int nobs = 10;
316 int nvars = 2;
317
318
319 OLSMultipleLinearRegression model = new OLSMultipleLinearRegression();
320 model.newSampleData(design, nobs, nvars);
321
322 RealMatrix hat = model.calculateHat();
323
324
325 double[] referenceData = new double[] {
326 .418, -.002, .079, -.274, -.046, .181, .128, .222, .050, .242,
327 .242, .292, .136, .243, .128, -.041, .033, -.035, .004,
328 .417, -.019, .273, .187, -.126, .044, -.153, .004,
329 .604, .197, -.038, .168, -.022, .275, -.028,
330 .252, .111, -.030, .019, -.010, -.010,
331 .148, .042, .117, .012, .111,
332 .262, .145, .277, .174,
333 .154, .120, .168,
334 .315, .148,
335 .187
336 };
337
338
339 int k = 0;
340 for (int i = 0; i < 10; i++) {
341 for (int j = i; j < 10; j++) {
342 assertEquals(referenceData[k], hat.getEntry(i, j), 10e-3);
343 assertEquals(hat.getEntry(i, j), hat.getEntry(j, i), 10e-12);
344 k++;
345 }
346 }
347
348
349
350
351
352 double[] residuals = model.estimateResiduals();
353 RealMatrix I = MatrixUtils.createRealIdentityMatrix(10);
354 double[] hatResiduals = I.subtract(hat).operate(model.Y).getData();
355 TestUtils.assertEquals(residuals, hatResiduals, 10e-12);
356 }
357 }