001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    
018    package org.apache.commons.math.linear;
019    
020    import java.lang.reflect.Array;
021    
022    import org.apache.commons.math.Field;
023    import org.apache.commons.math.FieldElement;
024    import org.apache.commons.math.MathRuntimeException;
025    import org.apache.commons.math.exception.util.LocalizedFormats;
026    
027    /**
028     * Calculates the LUP-decomposition of a square matrix.
029     * <p>The LUP-decomposition of a matrix A consists of three matrices
030     * L, U and P that satisfy: PA = LU, L is lower triangular, and U is
031     * upper triangular and P is a permutation matrix. All matrices are
032     * m&times;m.</p>
033     * <p>Since {@link FieldElement field elements} do not provide an ordering
034     * operator, the permutation matrix is computed here only in order to avoid
035     * a zero pivot element, no attempt is done to get the largest pivot element.</p>
036     *
037     * @param <T> the type of the field elements
038     * @version $Revision: 983921 $ $Date: 2010-08-10 12:46:06 +0200 (mar. 10 ao??t 2010) $
039     * @since 2.0
040     */
041    public class FieldLUDecompositionImpl<T extends FieldElement<T>> implements FieldLUDecomposition<T> {
042    
043        /** Field to which the elements belong. */
044        private final Field<T> field;
045    
046        /** Entries of LU decomposition. */
047        private T lu[][];
048    
049        /** Pivot permutation associated with LU decomposition */
050        private int[] pivot;
051    
052        /** Parity of the permutation associated with the LU decomposition */
053        private boolean even;
054    
055        /** Singularity indicator. */
056        private boolean singular;
057    
058        /** Cached value of L. */
059        private FieldMatrix<T> cachedL;
060    
061        /** Cached value of U. */
062        private FieldMatrix<T> cachedU;
063    
064        /** Cached value of P. */
065        private FieldMatrix<T> cachedP;
066    
067        /**
068         * Calculates the LU-decomposition of the given matrix.
069         * @param matrix The matrix to decompose.
070         * @exception NonSquareMatrixException if matrix is not square
071         */
072        public FieldLUDecompositionImpl(FieldMatrix<T> matrix)
073            throws NonSquareMatrixException {
074    
075            if (!matrix.isSquare()) {
076                throw new NonSquareMatrixException(matrix.getRowDimension(), matrix.getColumnDimension());
077            }
078    
079            final int m = matrix.getColumnDimension();
080            field = matrix.getField();
081            lu = matrix.getData();
082            pivot = new int[m];
083            cachedL = null;
084            cachedU = null;
085            cachedP = null;
086    
087            // Initialize permutation array and parity
088            for (int row = 0; row < m; row++) {
089                pivot[row] = row;
090            }
091            even     = true;
092            singular = false;
093    
094            // Loop over columns
095            for (int col = 0; col < m; col++) {
096    
097                T sum = field.getZero();
098    
099                // upper
100                for (int row = 0; row < col; row++) {
101                    final T[] luRow = lu[row];
102                    sum = luRow[col];
103                    for (int i = 0; i < row; i++) {
104                        sum = sum.subtract(luRow[i].multiply(lu[i][col]));
105                    }
106                    luRow[col] = sum;
107                }
108    
109                // lower
110                int nonZero = col; // permutation row
111                for (int row = col; row < m; row++) {
112                    final T[] luRow = lu[row];
113                    sum = luRow[col];
114                    for (int i = 0; i < col; i++) {
115                        sum = sum.subtract(luRow[i].multiply(lu[i][col]));
116                    }
117                    luRow[col] = sum;
118    
119                    if (lu[nonZero][col].equals(field.getZero())) {
120                        // try to select a better permutation choice
121                        ++nonZero;
122                    }
123                }
124    
125                // Singularity check
126                if (nonZero >= m) {
127                    singular = true;
128                    return;
129                }
130    
131                // Pivot if necessary
132                if (nonZero != col) {
133                    T tmp = field.getZero();
134                    for (int i = 0; i < m; i++) {
135                        tmp = lu[nonZero][i];
136                        lu[nonZero][i] = lu[col][i];
137                        lu[col][i] = tmp;
138                    }
139                    int temp = pivot[nonZero];
140                    pivot[nonZero] = pivot[col];
141                    pivot[col] = temp;
142                    even = !even;
143                }
144    
145                // Divide the lower elements by the "winning" diagonal elt.
146                final T luDiag = lu[col][col];
147                for (int row = col + 1; row < m; row++) {
148                    final T[] luRow = lu[row];
149                    luRow[col] = luRow[col].divide(luDiag);
150                }
151            }
152    
153        }
154    
155        /** {@inheritDoc} */
156        public FieldMatrix<T> getL() {
157            if ((cachedL == null) && !singular) {
158                final int m = pivot.length;
159                cachedL = new Array2DRowFieldMatrix<T>(field, m, m);
160                for (int i = 0; i < m; ++i) {
161                    final T[] luI = lu[i];
162                    for (int j = 0; j < i; ++j) {
163                        cachedL.setEntry(i, j, luI[j]);
164                    }
165                    cachedL.setEntry(i, i, field.getOne());
166                }
167            }
168            return cachedL;
169        }
170    
171        /** {@inheritDoc} */
172        public FieldMatrix<T> getU() {
173            if ((cachedU == null) && !singular) {
174                final int m = pivot.length;
175                cachedU = new Array2DRowFieldMatrix<T>(field, m, m);
176                for (int i = 0; i < m; ++i) {
177                    final T[] luI = lu[i];
178                    for (int j = i; j < m; ++j) {
179                        cachedU.setEntry(i, j, luI[j]);
180                    }
181                }
182            }
183            return cachedU;
184        }
185    
186        /** {@inheritDoc} */
187        public FieldMatrix<T> getP() {
188            if ((cachedP == null) && !singular) {
189                final int m = pivot.length;
190                cachedP = new Array2DRowFieldMatrix<T>(field, m, m);
191                for (int i = 0; i < m; ++i) {
192                    cachedP.setEntry(i, pivot[i], field.getOne());
193                }
194            }
195            return cachedP;
196        }
197    
198        /** {@inheritDoc} */
199        public int[] getPivot() {
200            return pivot.clone();
201        }
202    
203        /** {@inheritDoc} */
204        public T getDeterminant() {
205            if (singular) {
206                return field.getZero();
207            } else {
208                final int m = pivot.length;
209                T determinant = even ? field.getOne() : field.getZero().subtract(field.getOne());
210                for (int i = 0; i < m; i++) {
211                    determinant = determinant.multiply(lu[i][i]);
212                }
213                return determinant;
214            }
215        }
216    
217        /** {@inheritDoc} */
218        public FieldDecompositionSolver<T> getSolver() {
219            return new Solver<T>(field, lu, pivot, singular);
220        }
221    
222        /** Specialized solver. */
223        private static class Solver<T extends FieldElement<T>> implements FieldDecompositionSolver<T> {
224    
225            /** Serializable version identifier. */
226            private static final long serialVersionUID = -6353105415121373022L;
227    
228            /** Field to which the elements belong. */
229            private final Field<T> field;
230    
231            /** Entries of LU decomposition. */
232            private final T lu[][];
233    
234            /** Pivot permutation associated with LU decomposition. */
235            private final int[] pivot;
236    
237            /** Singularity indicator. */
238            private final boolean singular;
239    
240            /**
241             * Build a solver from decomposed matrix.
242             * @param field field to which the matrix elements belong
243             * @param lu entries of LU decomposition
244             * @param pivot pivot permutation associated with LU decomposition
245             * @param singular singularity indicator
246             */
247            private Solver(final Field<T> field, final T[][] lu,
248                           final int[] pivot, final boolean singular) {
249                this.field    = field;
250                this.lu       = lu;
251                this.pivot    = pivot;
252                this.singular = singular;
253            }
254    
255            /** {@inheritDoc} */
256            public boolean isNonSingular() {
257                return !singular;
258            }
259    
260            /** {@inheritDoc} */
261            public T[] solve(T[] b)
262                throws IllegalArgumentException, InvalidMatrixException {
263    
264                final int m = pivot.length;
265                if (b.length != m) {
266                    throw MathRuntimeException.createIllegalArgumentException(
267                            LocalizedFormats.VECTOR_LENGTH_MISMATCH,
268                            b.length, m);
269                }
270                if (singular) {
271                    throw new SingularMatrixException();
272                }
273    
274                @SuppressWarnings("unchecked") // field is of type T
275                final T[] bp = (T[]) Array.newInstance(field.getZero().getClass(), m);
276    
277                // Apply permutations to b
278                for (int row = 0; row < m; row++) {
279                    bp[row] = b[pivot[row]];
280                }
281    
282                // Solve LY = b
283                for (int col = 0; col < m; col++) {
284                    final T bpCol = bp[col];
285                    for (int i = col + 1; i < m; i++) {
286                        bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
287                    }
288                }
289    
290                // Solve UX = Y
291                for (int col = m - 1; col >= 0; col--) {
292                    bp[col] = bp[col].divide(lu[col][col]);
293                    final T bpCol = bp[col];
294                    for (int i = 0; i < col; i++) {
295                        bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
296                    }
297                }
298    
299                return bp;
300    
301            }
302    
303            /** {@inheritDoc} */
304            public FieldVector<T> solve(FieldVector<T> b)
305                throws IllegalArgumentException, InvalidMatrixException {
306                try {
307                    return solve((ArrayFieldVector<T>) b);
308                } catch (ClassCastException cce) {
309    
310                    final int m = pivot.length;
311                    if (b.getDimension() != m) {
312                        throw MathRuntimeException.createIllegalArgumentException(
313                                LocalizedFormats.VECTOR_LENGTH_MISMATCH,
314                                b.getDimension(), m);
315                    }
316                    if (singular) {
317                        throw new SingularMatrixException();
318                    }
319    
320                    @SuppressWarnings("unchecked") // field is of type T
321                    final T[] bp = (T[]) Array.newInstance(field.getZero().getClass(), m);
322    
323                    // Apply permutations to b
324                    for (int row = 0; row < m; row++) {
325                        bp[row] = b.getEntry(pivot[row]);
326                    }
327    
328                    // Solve LY = b
329                    for (int col = 0; col < m; col++) {
330                        final T bpCol = bp[col];
331                        for (int i = col + 1; i < m; i++) {
332                            bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
333                        }
334                    }
335    
336                    // Solve UX = Y
337                    for (int col = m - 1; col >= 0; col--) {
338                        bp[col] = bp[col].divide(lu[col][col]);
339                        final T bpCol = bp[col];
340                        for (int i = 0; i < col; i++) {
341                            bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
342                        }
343                    }
344    
345                    return new ArrayFieldVector<T>(bp, false);
346    
347                }
348            }
349    
350            /** Solve the linear equation A &times; X = B.
351             * <p>The A matrix is implicit here. It is </p>
352             * @param b right-hand side of the equation A &times; X = B
353             * @return a vector X such that A &times; X = B
354             * @exception IllegalArgumentException if matrices dimensions don't match
355             * @exception InvalidMatrixException if decomposed matrix is singular
356             */
357            public ArrayFieldVector<T> solve(ArrayFieldVector<T> b)
358                throws IllegalArgumentException, InvalidMatrixException {
359                return new ArrayFieldVector<T>(solve(b.getDataRef()), false);
360            }
361    
362            /** {@inheritDoc} */
363            public FieldMatrix<T> solve(FieldMatrix<T> b)
364                throws IllegalArgumentException, InvalidMatrixException {
365    
366                final int m = pivot.length;
367                if (b.getRowDimension() != m) {
368                    throw MathRuntimeException.createIllegalArgumentException(
369                            LocalizedFormats.DIMENSIONS_MISMATCH_2x2,
370                            b.getRowDimension(), b.getColumnDimension(), m, "n");
371                }
372                if (singular) {
373                    throw new SingularMatrixException();
374                }
375    
376                final int nColB = b.getColumnDimension();
377    
378                // Apply permutations to b
379                @SuppressWarnings("unchecked") // field is of type T
380                final T[][] bp = (T[][]) Array.newInstance(field.getZero().getClass(), new int[] { m, nColB });
381                for (int row = 0; row < m; row++) {
382                    final T[] bpRow = bp[row];
383                    final int pRow = pivot[row];
384                    for (int col = 0; col < nColB; col++) {
385                        bpRow[col] = b.getEntry(pRow, col);
386                    }
387                }
388    
389                // Solve LY = b
390                for (int col = 0; col < m; col++) {
391                    final T[] bpCol = bp[col];
392                    for (int i = col + 1; i < m; i++) {
393                        final T[] bpI = bp[i];
394                        final T luICol = lu[i][col];
395                        for (int j = 0; j < nColB; j++) {
396                            bpI[j] = bpI[j].subtract(bpCol[j].multiply(luICol));
397                        }
398                    }
399                }
400    
401                // Solve UX = Y
402                for (int col = m - 1; col >= 0; col--) {
403                    final T[] bpCol = bp[col];
404                    final T luDiag = lu[col][col];
405                    for (int j = 0; j < nColB; j++) {
406                        bpCol[j] = bpCol[j].divide(luDiag);
407                    }
408                    for (int i = 0; i < col; i++) {
409                        final T[] bpI = bp[i];
410                        final T luICol = lu[i][col];
411                        for (int j = 0; j < nColB; j++) {
412                            bpI[j] = bpI[j].subtract(bpCol[j].multiply(luICol));
413                        }
414                    }
415                }
416    
417                return new Array2DRowFieldMatrix<T>(bp, false);
418    
419            }
420    
421            /** {@inheritDoc} */
422            public FieldMatrix<T> getInverse() throws InvalidMatrixException {
423                final int m = pivot.length;
424                final T one = field.getOne();
425                FieldMatrix<T> identity = new Array2DRowFieldMatrix<T>(field, m, m);
426                for (int i = 0; i < m; ++i) {
427                    identity.setEntry(i, i, one);
428                }
429                return solve(identity);
430            }
431    
432        }
433    
434    }