View Javadoc

1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.math.stat.regression;
18  
19  import org.apache.commons.math.MathRuntimeException;
20  import org.apache.commons.math.linear.LUDecompositionImpl;
21  import org.apache.commons.math.linear.QRDecomposition;
22  import org.apache.commons.math.linear.QRDecompositionImpl;
23  import org.apache.commons.math.linear.RealMatrix;
24  import org.apache.commons.math.linear.Array2DRowRealMatrix;
25  import org.apache.commons.math.linear.RealVector;
26  import org.apache.commons.math.linear.ArrayRealVector;
27  
28  /**
29   * <p>Implements ordinary least squares (OLS) to estimate the parameters of a 
30   * multiple linear regression model.</p>
31   * 
32   * <p>OLS assumes the covariance matrix of the error to be diagonal and with
33   * equal variance.</p>
34   * <p>
35   * u ~ N(0, &sigma;<sup>2</sup>I)
36   * </p>
37   * 
38   * <p>The regression coefficients, b, satisfy the normal equations:
39   * <p>
40   * X<sup>T</sup> X b = X<sup>T</sup> y
41   * </p>
42   * 
43   * <p>To solve the normal equations, this implementation uses QR decomposition
44   * of the X matrix. (See {@link QRDecompositionImpl} for details on the
45   * decomposition algorithm.)
46   * </p>
47   * <p>X<sup>T</sup>X b = X<sup>T</sup> y <br/>
48   * (QR)<sup>T</sup> (QR) b = (QR)<sup>T</sup>y <br/>
49   * R<sup>T</sup> (Q<sup>T</sup>Q) R b = R<sup>T</sup> Q<sup>T</sup> y <br/>
50   * R<sup>T</sup> R b = R<sup>T</sup> Q<sup>T</sup> y <br/>
51   * (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> R b = (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> Q<sup>T</sup> y <br/>
52   * R b = Q<sup>T</sup> y
53   * </p>
54   * Given Q and R, the last equation is solved by back-subsitution.</p>
55   * 
56   * @version $Revision: 783702 $ $Date: 2009-06-11 04:54:02 -0400 (Thu, 11 Jun 2009) $
57   * @since 2.0
58   */
59  public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegression {
60      
61      /** Cached QR decomposition of X matrix */
62      private QRDecomposition qr = null;
63  
64      /**
65       * Loads model x and y sample data, overriding any previous sample.
66       * 
67       * Computes and caches QR decomposition of the X matrix.
68       * @param y the [n,1] array representing the y sample
69       * @param x the [n,k] array representing the x sample
70       * @throws IllegalArgumentException if the x and y array data are not
71       *             compatible for the regression
72       */
73      public void newSampleData(double[] y, double[][] x) {
74          validateSampleData(x, y);
75          newYSampleData(y);
76          newXSampleData(x);
77      }
78      
79      /**
80       * {@inheritDoc}
81       * 
82       * Computes and caches QR decomposition of the X matrix
83       */
84      @Override
85      public void newSampleData(double[] data, int nobs, int nvars) {
86          super.newSampleData(data, nobs, nvars);
87          qr = new QRDecompositionImpl(X);
88      }
89      
90      /**
91       * <p>Compute the "hat" matrix.
92       * </p>
93       * <p>The hat matrix is defined in terms of the design matrix X
94       *  by X(X<sup>T</sup>X)<sup>-1</sup>X<sup>T</sup>
95       * </p>
96       * <p>The implementation here uses the QR decomposition to compute the
97       * hat matrix as Q I<sub>p</sub>Q<sup>T</sup> where I<sub>p</sub> is the
98       * p-dimensional identity matrix augmented by 0's.  This computational
99       * formula is from "The Hat Matrix in Regression and ANOVA",
100      * David C. Hoaglin and Roy E. Welsch, 
101      * <i>The American Statistician</i>, Vol. 32, No. 1 (Feb., 1978), pp. 17-22.
102      * 
103      * @return the hat matrix
104      */
105     public RealMatrix calculateHat() {
106         // Create augmented identity matrix
107         RealMatrix Q = qr.getQ();
108         final int p = qr.getR().getColumnDimension();
109         final int n = Q.getColumnDimension();
110         Array2DRowRealMatrix augI = new Array2DRowRealMatrix(n, n);
111         double[][] augIData = augI.getDataRef();
112         for (int i = 0; i < n; i++) {
113             for (int j =0; j < n; j++) {
114                 if (i == j && i < p) {
115                     augIData[i][j] = 1d;
116                 } else {
117                     augIData[i][j] = 0d;
118                 }
119             }
120         }
121         
122         // Compute and return Hat matrix
123         return Q.multiply(augI).multiply(Q.transpose());
124     }
125    
126     /**
127      * Loads new x sample data, overriding any previous sample
128      * 
129      * @param x the [n,k] array representing the x sample
130      */
131     @Override
132     protected void newXSampleData(double[][] x) {
133         this.X = new Array2DRowRealMatrix(x);
134         qr = new QRDecompositionImpl(X);
135     }
136     
137     /**
138      * Calculates regression coefficients using OLS.
139      * 
140      * @return beta
141      */
142     @Override
143     protected RealVector calculateBeta() {
144         return solveUpperTriangular(qr.getR(), qr.getQ().transpose().operate(Y));
145     }
146 
147     /**
148      * <p>Calculates the variance on the beta by OLS.
149      * </p>
150      * <p>Var(b) = (X<sup>T</sup>X)<sup>-1</sup>
151      * </p>
152      * <p>Uses QR decomposition to reduce (X<sup>T</sup>X)<sup>-1</sup>
153      * to (R<sup>T</sup>R)<sup>-1</sup>, with only the top p rows of
154      * R included, where p = the length of the beta vector.</p> 
155      * 
156      * @return The beta variance
157      */
158     @Override
159     protected RealMatrix calculateBetaVariance() {
160         int p = X.getColumnDimension();
161         RealMatrix Raug = qr.getR().getSubMatrix(0, p - 1 , 0, p - 1);
162         RealMatrix Rinv = new LUDecompositionImpl(Raug).getSolver().getInverse();
163         return Rinv.multiply(Rinv.transpose());
164     }
165     
166 
167     /**
168      * <p>Calculates the variance on the Y by OLS.
169      * </p>
170      * <p> Var(y) = Tr(u<sup>T</sup>u)/(n - k)
171      * </p>
172      * @return The Y variance
173      */
174     @Override
175     protected double calculateYVariance() {
176         RealVector residuals = calculateResiduals();
177         return residuals.dotProduct(residuals) /
178                (X.getRowDimension() - X.getColumnDimension());
179     }
180     
181     /** TODO:  Find a home for the following methods in the linear package */   
182     
183     /**
184      * <p>Uses back substitution to solve the system</p>
185      * 
186      * <p>coefficients X = constants</p>
187      * 
188      * <p>coefficients must upper-triangular and constants must be a column 
189      * matrix.  The solution is returned as a column matrix.</p>
190      * 
191      * <p>The number of columns in coefficients determines the length
192      * of the returned solution vector (column matrix).  If constants
193      * has more rows than coefficients has columns, excess rows are ignored.
194      * Similarly, extra (zero) rows in coefficients are ignored</p>
195      * 
196      * @param coefficients upper-triangular coefficients matrix
197      * @param constants column RHS constants vector
198      * @return solution matrix as a column vector
199      * 
200      */
201     private static RealVector solveUpperTriangular(RealMatrix coefficients,
202                                                    RealVector constants) {
203         checkUpperTriangular(coefficients, 1E-12);
204         int length = coefficients.getColumnDimension();
205         double x[] = new double[length];
206         for (int i = 0; i < length; i++) {
207             int index = length - 1 - i;
208             double sum = 0;
209             for (int j = index + 1; j < length; j++) {
210                 sum += coefficients.getEntry(index, j) * x[j];
211             }
212             x[index] = (constants.getEntry(index) - sum) / coefficients.getEntry(index, index);
213         } 
214         return new ArrayRealVector(x);
215     }
216     
217     /**
218      * <p>Check if a matrix is upper-triangular.</p>
219      * 
220      * <p>Makes sure all below-diagonal elements are within epsilon of 0.</p>
221      * 
222      * @param m matrix to check
223      * @param epsilon maximum allowable absolute value for elements below
224      * the main diagonal
225      * 
226      * @throws IllegalArgumentException if m is not upper-triangular
227      */
228     private static void checkUpperTriangular(RealMatrix m, double epsilon) {
229         int nCols = m.getColumnDimension();
230         int nRows = m.getRowDimension();
231         for (int r = 0; r < nRows; r++) {
232             int bound = Math.min(r, nCols);
233             for (int c = 0; c < bound; c++) {
234                 if (Math.abs(m.getEntry(r, c)) > epsilon) {
235                     throw MathRuntimeException.createIllegalArgumentException(
236                           "matrix is not upper-triangular, entry ({0}, {1}) = {2} is too large",
237                           r, c, m.getEntry(r, c));
238                 }
239             }
240         }
241     }
242 }