001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    
018    package org.apache.commons.math.linear;
019    
020    import org.apache.commons.math.MathRuntimeException;
021    
022    /**
023     * Calculates the LUP-decomposition of a square matrix.
024     * <p>The LUP-decomposition of a matrix A consists of three matrices
025     * L, U and P that satisfy: PA = LU, L is lower triangular, and U is
026     * upper triangular and P is a permutation matrix. All matrices are
027     * m&times;m.</p>
028     * <p>As shown by the presence of the P matrix, this decomposition is
029     * implemented using partial pivoting.</p>
030     *
031     * @version $Revision: 799857 $ $Date: 2009-08-01 09:07:12 -0400 (Sat, 01 Aug 2009) $
032     * @since 2.0
033     */
034    public class LUDecompositionImpl implements LUDecomposition {
035    
036        /** Entries of LU decomposition. */
037        private double lu[][];
038    
039        /** Pivot permutation associated with LU decomposition */
040        private int[] pivot;
041    
042        /** Parity of the permutation associated with the LU decomposition */
043        private boolean even;
044    
045        /** Singularity indicator. */
046        private boolean singular;
047    
048        /** Cached value of L. */
049        private RealMatrix cachedL;
050    
051        /** Cached value of U. */
052        private RealMatrix cachedU;
053    
054        /** Cached value of P. */
055        private RealMatrix cachedP;
056    
057        /** Default bound to determine effective singularity in LU decomposition */
058        private static final double DEFAULT_TOO_SMALL = 10E-12;
059    
060        /**
061         * Calculates the LU-decomposition of the given matrix. 
062         * @param matrix The matrix to decompose.
063         * @exception InvalidMatrixException if matrix is not square
064         */
065        public LUDecompositionImpl(RealMatrix matrix)
066            throws InvalidMatrixException {
067            this(matrix, DEFAULT_TOO_SMALL);
068        }
069    
070        /**
071         * Calculates the LU-decomposition of the given matrix. 
072         * @param matrix The matrix to decompose.
073         * @param singularityThreshold threshold (based on partial row norm)
074         * under which a matrix is considered singular
075         * @exception NonSquareMatrixException if matrix is not square
076         */
077        public LUDecompositionImpl(RealMatrix matrix, double singularityThreshold)
078            throws NonSquareMatrixException {
079    
080            if (!matrix.isSquare()) {
081                throw new NonSquareMatrixException(matrix.getRowDimension(), matrix.getColumnDimension());
082            }
083    
084            final int m = matrix.getColumnDimension();
085            lu = matrix.getData();
086            pivot = new int[m];
087            cachedL = null;
088            cachedU = null;
089            cachedP = null;
090    
091            // Initialize permutation array and parity
092            for (int row = 0; row < m; row++) {
093                pivot[row] = row;
094            }
095            even     = true;
096            singular = false;
097    
098            // Loop over columns
099            for (int col = 0; col < m; col++) {
100    
101                double sum = 0;
102    
103                // upper
104                for (int row = 0; row < col; row++) {
105                    final double[] luRow = lu[row];
106                    sum = luRow[col];
107                    for (int i = 0; i < row; i++) {
108                        sum -= luRow[i] * lu[i][col];
109                    }
110                    luRow[col] = sum;
111                }
112    
113                // lower
114                int max = col; // permutation row
115                double largest = Double.NEGATIVE_INFINITY;
116                for (int row = col; row < m; row++) {
117                    final double[] luRow = lu[row];
118                    sum = luRow[col];
119                    for (int i = 0; i < col; i++) {
120                        sum -= luRow[i] * lu[i][col];
121                    }
122                    luRow[col] = sum;
123    
124                    // maintain best permutation choice
125                    if (Math.abs(sum) > largest) {
126                        largest = Math.abs(sum);
127                        max = row;
128                    }
129                }
130    
131                // Singularity check
132                if (Math.abs(lu[max][col]) < singularityThreshold) {
133                    singular = true;
134                    return;
135                }
136    
137                // Pivot if necessary
138                if (max != col) {
139                    double tmp = 0;
140                    final double[] luMax = lu[max];
141                    final double[] luCol = lu[col];
142                    for (int i = 0; i < m; i++) {
143                        tmp = luMax[i];
144                        luMax[i] = luCol[i];
145                        luCol[i] = tmp;
146                    }
147                    int temp = pivot[max];
148                    pivot[max] = pivot[col];
149                    pivot[col] = temp;
150                    even = !even;
151                }
152    
153                // Divide the lower elements by the "winning" diagonal elt.
154                final double luDiag = lu[col][col];
155                for (int row = col + 1; row < m; row++) {
156                    lu[row][col] /= luDiag;
157                }
158            }
159    
160        }
161    
162        /** {@inheritDoc} */
163        public RealMatrix getL() {
164            if ((cachedL == null) && !singular) {
165                final int m = pivot.length;
166                cachedL = MatrixUtils.createRealMatrix(m, m);
167                for (int i = 0; i < m; ++i) {
168                    final double[] luI = lu[i];
169                    for (int j = 0; j < i; ++j) {
170                        cachedL.setEntry(i, j, luI[j]);
171                    }
172                    cachedL.setEntry(i, i, 1.0);
173                }
174            }
175            return cachedL;
176        }
177    
178        /** {@inheritDoc} */
179        public RealMatrix getU() {
180            if ((cachedU == null) && !singular) {
181                final int m = pivot.length;
182                cachedU = MatrixUtils.createRealMatrix(m, m);
183                for (int i = 0; i < m; ++i) {
184                    final double[] luI = lu[i];
185                    for (int j = i; j < m; ++j) {
186                        cachedU.setEntry(i, j, luI[j]);
187                    }
188                }
189            }
190            return cachedU;
191        }
192    
193        /** {@inheritDoc} */
194        public RealMatrix getP() {
195            if ((cachedP == null) && !singular) {
196                final int m = pivot.length;
197                cachedP = MatrixUtils.createRealMatrix(m, m);
198                for (int i = 0; i < m; ++i) {
199                    cachedP.setEntry(i, pivot[i], 1.0);
200                }
201            }
202            return cachedP;
203        }
204    
205        /** {@inheritDoc} */
206        public int[] getPivot() {
207            return pivot.clone();
208        }
209    
210        /** {@inheritDoc} */
211        public double getDeterminant() {
212            if (singular) {
213                return 0;
214            } else {
215                final int m = pivot.length;
216                double determinant = even ? 1 : -1;
217                for (int i = 0; i < m; i++) {
218                    determinant *= lu[i][i];
219                }
220                return determinant;
221            }
222        }
223    
224        /** {@inheritDoc} */
225        public DecompositionSolver getSolver() {
226            return new Solver(lu, pivot, singular);
227        }
228    
229        /** Specialized solver. */
230        private static class Solver implements DecompositionSolver {
231        
232            /** Entries of LU decomposition. */
233            private final double lu[][];
234    
235            /** Pivot permutation associated with LU decomposition. */
236            private final int[] pivot;
237    
238            /** Singularity indicator. */
239            private final boolean singular;
240    
241            /**
242             * Build a solver from decomposed matrix.
243             * @param lu entries of LU decomposition
244             * @param pivot pivot permutation associated with LU decomposition
245             * @param singular singularity indicator
246             */
247            private Solver(final double[][] lu, final int[] pivot, final boolean singular) {
248                this.lu       = lu;
249                this.pivot    = pivot;
250                this.singular = singular;
251            }
252    
253            /** {@inheritDoc} */
254            public boolean isNonSingular() {
255                return !singular;
256            }
257    
258            /** {@inheritDoc} */
259            public double[] solve(double[] b)
260                throws IllegalArgumentException, InvalidMatrixException {
261    
262                final int m = pivot.length;
263                if (b.length != m) {
264                    throw MathRuntimeException.createIllegalArgumentException(
265                            "vector length mismatch: got {0} but expected {1}",
266                            b.length, m);
267                }
268                if (singular) {
269                    throw new SingularMatrixException();
270                }
271    
272                final double[] bp = new double[m];
273    
274                // Apply permutations to b
275                for (int row = 0; row < m; row++) {
276                    bp[row] = b[pivot[row]];
277                }
278    
279                // Solve LY = b
280                for (int col = 0; col < m; col++) {
281                    final double bpCol = bp[col];
282                    for (int i = col + 1; i < m; i++) {
283                        bp[i] -= bpCol * lu[i][col];
284                    }
285                }
286    
287                // Solve UX = Y
288                for (int col = m - 1; col >= 0; col--) {
289                    bp[col] /= lu[col][col];
290                    final double bpCol = bp[col];
291                    for (int i = 0; i < col; i++) {
292                        bp[i] -= bpCol * lu[i][col];
293                    }
294                }
295    
296                return bp;
297    
298            }
299    
300            /** {@inheritDoc} */
301            public RealVector solve(RealVector b)
302                throws IllegalArgumentException, InvalidMatrixException {
303                try {
304                    return solve((ArrayRealVector) b);
305                } catch (ClassCastException cce) {
306    
307                    final int m = pivot.length;
308                    if (b.getDimension() != m) {
309                        throw MathRuntimeException.createIllegalArgumentException(
310                                "vector length mismatch: got {0} but expected {1}",
311                                b.getDimension(), m);
312                    }
313                    if (singular) {
314                        throw new SingularMatrixException();
315                    }
316    
317                    final double[] bp = new double[m];
318    
319                    // Apply permutations to b
320                    for (int row = 0; row < m; row++) {
321                        bp[row] = b.getEntry(pivot[row]);
322                    }
323    
324                    // Solve LY = b
325                    for (int col = 0; col < m; col++) {
326                        final double bpCol = bp[col];
327                        for (int i = col + 1; i < m; i++) {
328                            bp[i] -= bpCol * lu[i][col];
329                        }
330                    }
331    
332                    // Solve UX = Y
333                    for (int col = m - 1; col >= 0; col--) {
334                        bp[col] /= lu[col][col];
335                        final double bpCol = bp[col];
336                        for (int i = 0; i < col; i++) {
337                            bp[i] -= bpCol * lu[i][col];
338                        }
339                    }
340    
341                    return new ArrayRealVector(bp, false);
342    
343                }
344            }
345    
346            /** Solve the linear equation A &times; X = B.
347             * <p>The A matrix is implicit here. It is </p>
348             * @param b right-hand side of the equation A &times; X = B
349             * @return a vector X such that A &times; X = B
350             * @exception IllegalArgumentException if matrices dimensions don't match
351             * @exception InvalidMatrixException if decomposed matrix is singular
352             */
353            public ArrayRealVector solve(ArrayRealVector b)
354                throws IllegalArgumentException, InvalidMatrixException {
355                return new ArrayRealVector(solve(b.getDataRef()), false);
356            }
357    
358            /** {@inheritDoc} */
359            public RealMatrix solve(RealMatrix b)
360                throws IllegalArgumentException, InvalidMatrixException {
361    
362                final int m = pivot.length;
363                if (b.getRowDimension() != m) {
364                    throw MathRuntimeException.createIllegalArgumentException(
365                            "dimensions mismatch: got {0}x{1} but expected {2}x{3}",
366                            b.getRowDimension(), b.getColumnDimension(), m, "n");
367                }
368                if (singular) {
369                    throw new SingularMatrixException();
370                }
371    
372                final int nColB = b.getColumnDimension();
373    
374                // Apply permutations to b
375                final double[][] bp = new double[m][nColB];
376                for (int row = 0; row < m; row++) {
377                    final double[] bpRow = bp[row];
378                    final int pRow = pivot[row];
379                    for (int col = 0; col < nColB; col++) {
380                        bpRow[col] = b.getEntry(pRow, col);
381                    }
382                }
383    
384                // Solve LY = b
385                for (int col = 0; col < m; col++) {
386                    final double[] bpCol = bp[col];
387                    for (int i = col + 1; i < m; i++) {
388                        final double[] bpI = bp[i];
389                        final double luICol = lu[i][col];
390                        for (int j = 0; j < nColB; j++) {
391                            bpI[j] -= bpCol[j] * luICol;
392                        }
393                    }
394                }
395    
396                // Solve UX = Y
397                for (int col = m - 1; col >= 0; col--) {
398                    final double[] bpCol = bp[col];
399                    final double luDiag = lu[col][col];
400                    for (int j = 0; j < nColB; j++) {
401                        bpCol[j] /= luDiag;
402                    }
403                    for (int i = 0; i < col; i++) {
404                        final double[] bpI = bp[i];
405                        final double luICol = lu[i][col];
406                        for (int j = 0; j < nColB; j++) {
407                            bpI[j] -= bpCol[j] * luICol;
408                        }
409                    }
410                }
411    
412                return new Array2DRowRealMatrix(bp, false);
413    
414            }
415    
416            /** {@inheritDoc} */
417            public RealMatrix getInverse() throws InvalidMatrixException {
418                return solve(MatrixUtils.createRealIdentityMatrix(pivot.length));
419            }
420    
421        }
422    
423    }