001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    
018    package org.apache.commons.math.linear;
019    
020    import java.lang.reflect.Array;
021    
022    import org.apache.commons.math.Field;
023    import org.apache.commons.math.FieldElement;
024    import org.apache.commons.math.MathRuntimeException;
025    
026    /**
027     * Calculates the LUP-decomposition of a square matrix.
028     * <p>The LUP-decomposition of a matrix A consists of three matrices
029     * L, U and P that satisfy: PA = LU, L is lower triangular, and U is
030     * upper triangular and P is a permutation matrix. All matrices are
031     * m&times;m.</p>
032     * <p>Since {@link FieldElement field elements} do not provide an ordering
033     * operator, the permutation matrix is computed here only in order to avoid
034     * a zero pivot element, no attempt is done to get the largest pivot element.</p>
035     *
036     * @param <T> the type of the field elements
037     * @version $Revision: 783702 $ $Date: 2009-06-11 04:54:02 -0400 (Thu, 11 Jun 2009) $
038     * @since 2.0
039     */
040    public class FieldLUDecompositionImpl<T extends FieldElement<T>> implements FieldLUDecomposition<T> {
041    
042        /** Field to which the elements belong. */
043        private final Field<T> field;
044    
045        /** Entries of LU decomposition. */
046        private T lu[][];
047    
048        /** Pivot permutation associated with LU decomposition */
049        private int[] pivot;
050    
051        /** Parity of the permutation associated with the LU decomposition */
052        private boolean even;
053    
054        /** Singularity indicator. */
055        private boolean singular;
056    
057        /** Cached value of L. */
058        private FieldMatrix<T> cachedL;
059    
060        /** Cached value of U. */
061        private FieldMatrix<T> cachedU;
062    
063        /** Cached value of P. */
064        private FieldMatrix<T> cachedP;
065    
066        /**
067         * Calculates the LU-decomposition of the given matrix. 
068         * @param matrix The matrix to decompose.
069         * @exception NonSquareMatrixException if matrix is not square
070         */
071        public FieldLUDecompositionImpl(FieldMatrix<T> matrix)
072            throws NonSquareMatrixException {
073    
074            if (!matrix.isSquare()) {
075                throw new NonSquareMatrixException(matrix.getRowDimension(), matrix.getColumnDimension());
076            }
077    
078            final int m = matrix.getColumnDimension();
079            field = matrix.getField();
080            lu = matrix.getData();
081            pivot = new int[m];
082            cachedL = null;
083            cachedU = null;
084            cachedP = null;
085    
086            // Initialize permutation array and parity
087            for (int row = 0; row < m; row++) {
088                pivot[row] = row;
089            }
090            even     = true;
091            singular = false;
092    
093            // Loop over columns
094            for (int col = 0; col < m; col++) {
095    
096                T sum = field.getZero();
097    
098                // upper
099                for (int row = 0; row < col; row++) {
100                    final T[] luRow = lu[row];
101                    sum = luRow[col];
102                    for (int i = 0; i < row; i++) {
103                        sum = sum.subtract(luRow[i].multiply(lu[i][col]));
104                    }
105                    luRow[col] = sum;
106                }
107    
108                // lower
109                int nonZero = col; // permutation row
110                for (int row = col; row < m; row++) {
111                    final T[] luRow = lu[row];
112                    sum = luRow[col];
113                    for (int i = 0; i < col; i++) {
114                        sum = sum.subtract(luRow[i].multiply(lu[i][col]));
115                    }
116                    luRow[col] = sum;
117    
118                    if (lu[nonZero][col].equals(field.getZero())) {
119                        // try to select a better permutation choice
120                        ++nonZero;
121                    }
122                }
123    
124                // Singularity check
125                if (nonZero >= m) {
126                    singular = true;
127                    return;
128                }
129    
130                // Pivot if necessary
131                if (nonZero != col) {
132                    T tmp = field.getZero();
133                    for (int i = 0; i < m; i++) {
134                        tmp = lu[nonZero][i];
135                        lu[nonZero][i] = lu[col][i];
136                        lu[col][i] = tmp;
137                    }
138                    int temp = pivot[nonZero];
139                    pivot[nonZero] = pivot[col];
140                    pivot[col] = temp;
141                    even = !even;
142                }
143    
144                // Divide the lower elements by the "winning" diagonal elt.
145                final T luDiag = lu[col][col];
146                for (int row = col + 1; row < m; row++) {
147                    final T[] luRow = lu[row];
148                    luRow[col] = luRow[col].divide(luDiag);
149                }
150            }
151    
152        }
153    
154        /** {@inheritDoc} */
155        public FieldMatrix<T> getL() {
156            if ((cachedL == null) && !singular) {
157                final int m = pivot.length;
158                cachedL = new Array2DRowFieldMatrix<T>(field, m, m);
159                for (int i = 0; i < m; ++i) {
160                    final T[] luI = lu[i];
161                    for (int j = 0; j < i; ++j) {
162                        cachedL.setEntry(i, j, luI[j]);
163                    }
164                    cachedL.setEntry(i, i, field.getOne());
165                }
166            }
167            return cachedL;
168        }
169    
170        /** {@inheritDoc} */
171        public FieldMatrix<T> getU() {
172            if ((cachedU == null) && !singular) {
173                final int m = pivot.length;
174                cachedU = new Array2DRowFieldMatrix<T>(field, m, m);
175                for (int i = 0; i < m; ++i) {
176                    final T[] luI = lu[i];
177                    for (int j = i; j < m; ++j) {
178                        cachedU.setEntry(i, j, luI[j]);
179                    }
180                }
181            }
182            return cachedU;
183        }
184    
185        /** {@inheritDoc} */
186        public FieldMatrix<T> getP() {
187            if ((cachedP == null) && !singular) {
188                final int m = pivot.length;
189                cachedP = new Array2DRowFieldMatrix<T>(field, m, m);
190                for (int i = 0; i < m; ++i) {
191                    cachedP.setEntry(i, pivot[i], field.getOne());
192                }
193            }
194            return cachedP;
195        }
196    
197        /** {@inheritDoc} */
198        public int[] getPivot() {
199            return pivot.clone();
200        }
201    
202        /** {@inheritDoc} */
203        public T getDeterminant() {
204            if (singular) {
205                return field.getZero();
206            } else {
207                final int m = pivot.length;
208                T determinant = even ? field.getOne() : field.getZero().subtract(field.getOne());
209                for (int i = 0; i < m; i++) {
210                    determinant = determinant.multiply(lu[i][i]);
211                }
212                return determinant;
213            }
214        }
215    
216        /** {@inheritDoc} */
217        public FieldDecompositionSolver<T> getSolver() {
218            return new Solver<T>(field, lu, pivot, singular);
219        }
220    
221        /** Specialized solver. */
222        private static class Solver<T extends FieldElement<T>> implements FieldDecompositionSolver<T> {
223    
224            /** Serializable version identifier. */
225            private static final long serialVersionUID = -6353105415121373022L;
226    
227            /** Field to which the elements belong. */
228            private final Field<T> field;
229    
230            /** Entries of LU decomposition. */
231            private final T lu[][];
232    
233            /** Pivot permutation associated with LU decomposition. */
234            private final int[] pivot;
235    
236            /** Singularity indicator. */
237            private final boolean singular;
238    
239            /**
240             * Build a solver from decomposed matrix.
241             * @param field field to which the matrix elements belong
242             * @param lu entries of LU decomposition
243             * @param pivot pivot permutation associated with LU decomposition
244             * @param singular singularity indicator
245             */
246            private Solver(final Field<T> field, final T[][] lu,
247                           final int[] pivot, final boolean singular) {
248                this.field    = field;
249                this.lu       = lu;
250                this.pivot    = pivot;
251                this.singular = singular;
252            }
253    
254            /** {@inheritDoc} */
255            public boolean isNonSingular() {
256                return !singular;
257            }
258    
259            /** {@inheritDoc} */
260            @SuppressWarnings("unchecked")
261            public T[] solve(T[] b)
262                throws IllegalArgumentException, InvalidMatrixException {
263    
264                final int m = pivot.length;
265                if (b.length != m) {
266                    throw MathRuntimeException.createIllegalArgumentException(
267                            "vector length mismatch: got {0} but expected {1}",
268                            b.length, m);
269                }
270                if (singular) {
271                    throw new SingularMatrixException();
272                }
273    
274                final T[] bp = (T[]) Array.newInstance(field.getZero().getClass(), m);
275    
276                // Apply permutations to b
277                for (int row = 0; row < m; row++) {
278                    bp[row] = b[pivot[row]];
279                }
280    
281                // Solve LY = b
282                for (int col = 0; col < m; col++) {
283                    final T bpCol = bp[col];
284                    for (int i = col + 1; i < m; i++) {
285                        bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
286                    }
287                }
288    
289                // Solve UX = Y
290                for (int col = m - 1; col >= 0; col--) {
291                    bp[col] = bp[col].divide(lu[col][col]);
292                    final T bpCol = bp[col];
293                    for (int i = 0; i < col; i++) {
294                        bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
295                    }
296                }
297    
298                return bp;
299    
300            }
301    
302            /** {@inheritDoc} */
303            @SuppressWarnings("unchecked")
304            public FieldVector<T> solve(FieldVector<T> b)
305                throws IllegalArgumentException, InvalidMatrixException {
306                try {
307                    return solve((ArrayFieldVector<T>) b);
308                } catch (ClassCastException cce) {
309    
310                    final int m = pivot.length;
311                    if (b.getDimension() != m) {
312                        throw MathRuntimeException.createIllegalArgumentException(
313                                "vector length mismatch: got {0} but expected {1}",
314                                b.getDimension(), m);
315                    }
316                    if (singular) {
317                        throw new SingularMatrixException();
318                    }
319    
320                    final T[] bp = (T[]) Array.newInstance(field.getZero().getClass(), m);
321    
322                    // Apply permutations to b
323                    for (int row = 0; row < m; row++) {
324                        bp[row] = b.getEntry(pivot[row]);
325                    }
326    
327                    // Solve LY = b
328                    for (int col = 0; col < m; col++) {
329                        final T bpCol = bp[col];
330                        for (int i = col + 1; i < m; i++) {
331                            bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
332                        }
333                    }
334    
335                    // Solve UX = Y
336                    for (int col = m - 1; col >= 0; col--) {
337                        bp[col] = bp[col].divide(lu[col][col]);
338                        final T bpCol = bp[col];
339                        for (int i = 0; i < col; i++) {
340                            bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
341                        }
342                    }
343    
344                    return new ArrayFieldVector<T>(bp, false);
345    
346                }
347            }
348    
349            /** Solve the linear equation A &times; X = B.
350             * <p>The A matrix is implicit here. It is </p>
351             * @param b right-hand side of the equation A &times; X = B
352             * @return a vector X such that A &times; X = B
353             * @exception IllegalArgumentException if matrices dimensions don't match
354             * @exception InvalidMatrixException if decomposed matrix is singular
355             */
356            public ArrayFieldVector<T> solve(ArrayFieldVector<T> b)
357                throws IllegalArgumentException, InvalidMatrixException {
358                return new ArrayFieldVector<T>(solve(b.getDataRef()), false);
359            }
360    
361            /** {@inheritDoc} */
362            @SuppressWarnings("unchecked")
363            public FieldMatrix<T> solve(FieldMatrix<T> b)
364                throws IllegalArgumentException, InvalidMatrixException {
365    
366                final int m = pivot.length;
367                if (b.getRowDimension() != m) {
368                    throw MathRuntimeException.createIllegalArgumentException(
369                            "dimensions mismatch: got {0}x{1} but expected {2}x{3}",
370                            b.getRowDimension(), b.getColumnDimension(), m, "n");
371                }
372                if (singular) {
373                    throw new SingularMatrixException();
374                }
375    
376                final int nColB = b.getColumnDimension();
377    
378                // Apply permutations to b
379                final T[][] bp = (T[][]) Array.newInstance(field.getZero().getClass(), new int[] { m, nColB });
380                for (int row = 0; row < m; row++) {
381                    final T[] bpRow = bp[row];
382                    final int pRow = pivot[row];
383                    for (int col = 0; col < nColB; col++) {
384                        bpRow[col] = b.getEntry(pRow, col);
385                    }
386                }
387    
388                // Solve LY = b
389                for (int col = 0; col < m; col++) {
390                    final T[] bpCol = bp[col];
391                    for (int i = col + 1; i < m; i++) {
392                        final T[] bpI = bp[i];
393                        final T luICol = lu[i][col];
394                        for (int j = 0; j < nColB; j++) {
395                            bpI[j] = bpI[j].subtract(bpCol[j].multiply(luICol));
396                        }
397                    }
398                }
399    
400                // Solve UX = Y
401                for (int col = m - 1; col >= 0; col--) {
402                    final T[] bpCol = bp[col];
403                    final T luDiag = lu[col][col];
404                    for (int j = 0; j < nColB; j++) {
405                        bpCol[j] = bpCol[j].divide(luDiag);
406                    }
407                    for (int i = 0; i < col; i++) {
408                        final T[] bpI = bp[i];
409                        final T luICol = lu[i][col];
410                        for (int j = 0; j < nColB; j++) {
411                            bpI[j] = bpI[j].subtract(bpCol[j].multiply(luICol));
412                        }
413                    }
414                }
415    
416                return new Array2DRowFieldMatrix<T>(bp, false);
417    
418            }
419    
420            /** {@inheritDoc} */
421            public FieldMatrix<T> getInverse() throws InvalidMatrixException {
422                final int m = pivot.length;
423                final T one = field.getOne();
424                FieldMatrix<T> identity = new Array2DRowFieldMatrix<T>(field, m, m);
425                for (int i = 0; i < m; ++i) {
426                    identity.setEntry(i, i, one);
427                }
428                return solve(identity);
429            }
430    
431        }
432    
433    }