/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.math.stat.regression;

import org.apache.commons.math.MathRuntimeException;
import org.apache.commons.math.linear.LUDecompositionImpl;
import org.apache.commons.math.linear.QRDecomposition;
import org.apache.commons.math.linear.QRDecompositionImpl;
import org.apache.commons.math.linear.RealMatrix;
import org.apache.commons.math.linear.Array2DRowRealMatrix;
import org.apache.commons.math.linear.RealVector;
import org.apache.commons.math.linear.ArrayRealVector;

/**
 * <p>Implements ordinary least squares (OLS) to estimate the parameters of a
 * multiple linear regression model.</p>
 *
 * <p>OLS assumes that the error terms are uncorrelated and have equal variance,
 * i.e. the error covariance matrix is a scalar multiple of the identity:</p>
 * <p>
 * u ~ N(0, &sigma;<sup>2</sup>I)
 * </p>
 *
 * <p>The regression coefficients, b, satisfy the normal equations:</p>
 * <p>
 * X<sup>T</sup> X b = X<sup>T</sup> y
 * </p>
 *
 * <p>To solve the normal equations, this implementation uses QR decomposition
 * of the X matrix. (See {@link QRDecompositionImpl} for details on the
 * decomposition algorithm.)
 * </p>
 * <p>X<sup>T</sup> X b = X<sup>T</sup> y <br/>
 * (QR)<sup>T</sup> (QR) b = (QR)<sup>T</sup> y <br/>
 * R<sup>T</sup> (Q<sup>T</sup>Q) R b = R<sup>T</sup> Q<sup>T</sup> y <br/>
 * R<sup>T</sup> R b = R<sup>T</sup> Q<sup>T</sup> y <br/>
 * (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> R b = (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> Q<sup>T</sup> y <br/>
 * R b = Q<sup>T</sup> y
 * </p>
 * <p>Given Q and R, the last equation is solved by back-substitution.</p>
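 *
 * <p>A minimal usage sketch (illustrative data only; the estimation methods shown
 * are inherited from {@link AbstractMultipleLinearRegression}):</p>
 * <pre>
 * OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
 * double[] y = new double[] {11.0, 12.0, 13.0, 14.0, 15.0, 16.0};
 * double[][] x = new double[][] {
 *     {1.0, 0, 0, 0, 0, 0},
 *     {1.0, 2.0, 0, 0, 0, 0},
 *     {1.0, 0, 3.0, 0, 0, 0},
 *     {1.0, 0, 0, 4.0, 0, 0},
 *     {1.0, 0, 0, 0, 5.0, 0},
 *     {1.0, 0, 0, 0, 0, 6.0}};
 * regression.newSampleData(y, x);
 * double[] beta = regression.estimateRegressionParameters();
 * double[] residuals = regression.estimateResiduals();
 * </pre>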
 *
 * @version $Revision: 783702 $ $Date: 2009-06-11 04:54:02 -0400 (Thu, 11 Jun 2009) $
 * @since 2.0
 */
public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegression {

    /** Cached QR decomposition of X matrix */
    private QRDecomposition qr = null;

    /**
     * Loads model x and y sample data, overriding any previous sample.
     *
     * Computes and caches QR decomposition of the X matrix.
     * @param y the [n,1] array representing the y sample
     * @param x the [n,k] array representing the x sample
     * @throws IllegalArgumentException if the x and y array data are not
     *             compatible for the regression
     */
    public void newSampleData(double[] y, double[][] x) {
        validateSampleData(x, y);
        newYSampleData(y);
        newXSampleData(x);
    }

    /**
     * {@inheritDoc}
     *
     * Computes and caches QR decomposition of the X matrix
     */
    @Override
    public void newSampleData(double[] data, int nobs, int nvars) {
        super.newSampleData(data, nobs, nvars);
        qr = new QRDecompositionImpl(X);
    }

    /**
     * <p>Compute the "hat" matrix.
     * </p>
     * <p>The hat matrix is defined in terms of the design matrix X
     *  by X(X<sup>T</sup>X)<sup>-1</sup>X<sup>T</sup>
     * </p>
     * <p>The implementation here uses the QR decomposition to compute the
     * hat matrix as Q I<sub>p</sub>Q<sup>T</sup> where I<sub>p</sub> is the
     * p-dimensional identity matrix augmented by 0's.  This computational
     * formula is from "The Hat Matrix in Regression and ANOVA",
     * David C. Hoaglin and Roy E. Welsch,
     * <i>The American Statistician</i>, Vol. 32, No. 1 (Feb., 1978), pp. 17-22.
     * </p>
     *
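     * <p>A usage sketch (variable names are illustrative only): the diagonal
     * entries of the hat matrix are the leverage values of the observations.</p>
     * <pre>
     * RealMatrix hat = regression.calculateHat();
     * int n = hat.getRowDimension();
     * double[] leverage = new double[n];
     * for (int i = 0; i &lt; n; i++) {
     *     leverage[i] = hat.getEntry(i, i);
     * }
     * </pre>
     *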
     * @return the hat matrix
     */
    public RealMatrix calculateHat() {
        // Create augmented identity matrix
        RealMatrix Q = qr.getQ();
        final int p = qr.getR().getColumnDimension();
        final int n = Q.getColumnDimension();
        Array2DRowRealMatrix augI = new Array2DRowRealMatrix(n, n);
        double[][] augIData = augI.getDataRef();
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                if (i == j && i < p) {
                    augIData[i][j] = 1d;
                } else {
                    augIData[i][j] = 0d;
                }
            }
        }

        // Compute and return Hat matrix
        return Q.multiply(augI).multiply(Q.transpose());
    }

    /**
     * Loads new x sample data, overriding any previous sample
     *
     * @param x the [n,k] array representing the x sample
     */
    @Override
    protected void newXSampleData(double[][] x) {
        this.X = new Array2DRowRealMatrix(x);
        qr = new QRDecompositionImpl(X);
    }

    /**
     * Calculates regression coefficients using OLS.
     *
     * @return beta
     */
    @Override
    protected RealVector calculateBeta() {
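        // Solve R b = Q^T y by back-substitution (see the derivation in the class javadoc)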
        return solveUpperTriangular(qr.getR(), qr.getQ().transpose().operate(Y));
    }

    /**
     * <p>Calculates the variance of the beta estimates by OLS.
     * </p>
     * <p>Var(b) = (X<sup>T</sup>X)<sup>-1</sup>
     * </p>
     * <p>Uses the QR decomposition to reduce (X<sup>T</sup>X)<sup>-1</sup>
     * to (R<sup>T</sup>R)<sup>-1</sup>, with only the top p rows of
     * R included, where p = the length of the beta vector.</p>
     *
     * @return The beta variance
     */
    @Override
    protected RealMatrix calculateBetaVariance() {
        int p = X.getColumnDimension();
        RealMatrix Raug = qr.getR().getSubMatrix(0, p - 1, 0, p - 1);
        RealMatrix Rinv = new LUDecompositionImpl(Raug).getSolver().getInverse();
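        // (X^T X)^{-1} = (R^T R)^{-1} = R^{-1} (R^{-1})^T, using the p x p upper block of R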
        return Rinv.multiply(Rinv.transpose());
    }

    /**
     * <p>Calculates the variance of the Y values by OLS.
     * </p>
     * <p> Var(y) = Tr(u<sup>T</sup>u)/(n - k)
     * </p>
     * @return The Y variance
     */
    @Override
    protected double calculateYVariance() {
        RealVector residuals = calculateResiduals();
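        // Error variance estimate: residual sum of squares over residual degrees of freedom (n - k)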
        return residuals.dotProduct(residuals) /
               (X.getRowDimension() - X.getColumnDimension());
    }

    /** TODO: Find a home for the following methods in the linear package */

    /**
     * <p>Uses back substitution to solve the system</p>
     *
     * <p>coefficients X = constants</p>
     *
     * <p>coefficients must be upper-triangular and constants must be a column
     * vector.  The solution is returned as a column vector.</p>
     *
     * <p>The number of columns in coefficients determines the length
     * of the returned solution vector.  If constants
     * has more rows than coefficients has columns, the excess rows are ignored.
     * Similarly, extra (zero) rows in coefficients are ignored.</p>
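     *
     * <p>For example, for the illustrative system</p>
     * <pre>
     * [ 2  1 ] [ x0 ]   [ 5 ]
     * [ 0  3 ] [ x1 ] = [ 6 ]
     * </pre>
     * <p>back substitution first gives x1 = 6 / 3 = 2, then
     * x0 = (5 - 1 * 2) / 2 = 1.5.</p>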
     *
     * @param coefficients upper-triangular coefficients matrix
     * @param constants column RHS constants vector
     * @return solution matrix as a column vector
     *
     */
    private static RealVector solveUpperTriangular(RealMatrix coefficients,
                                                   RealVector constants) {
        checkUpperTriangular(coefficients, 1E-12);
        int length = coefficients.getColumnDimension();
        double[] x = new double[length];
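        // Work backwards from the last unknown: each x[index] depends only on
        // entries x[index + 1 .. length - 1] that have already been computed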
        for (int i = 0; i < length; i++) {
            int index = length - 1 - i;
            double sum = 0;
            for (int j = index + 1; j < length; j++) {
                sum += coefficients.getEntry(index, j) * x[j];
            }
            x[index] = (constants.getEntry(index) - sum) / coefficients.getEntry(index, index);
        }
        return new ArrayRealVector(x);
    }

    /**
     * <p>Check if a matrix is upper-triangular.</p>
     *
     * <p>Makes sure all below-diagonal elements are within epsilon of 0.</p>
     *
     * @param m matrix to check
     * @param epsilon maximum allowable absolute value for elements below
     * the main diagonal
     *
     * @throws IllegalArgumentException if m is not upper-triangular
     */
    private static void checkUpperTriangular(RealMatrix m, double epsilon) {
        int nCols = m.getColumnDimension();
        int nRows = m.getRowDimension();
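        // Only entries strictly below the main diagonal (column index < row index) are checked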
        for (int r = 0; r < nRows; r++) {
            int bound = Math.min(r, nCols);
            for (int c = 0; c < bound; c++) {
                if (Math.abs(m.getEntry(r, c)) > epsilon) {
                    throw MathRuntimeException.createIllegalArgumentException(
                          "matrix is not upper-triangular, entry ({0}, {1}) = {2} is too large",
                          r, c, m.getEntry(r, c));
                }
            }
        }
    }
}