1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.math.linear;
19
20 import java.util.Arrays;
21
22 import org.apache.commons.math.MathRuntimeException;
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41 public class QRDecompositionImpl implements QRDecomposition {
42
43
44
45
46
47
48
49 private double[][] qrt;
50
51
52 private double[] rDiag;
53
54
55 private RealMatrix cachedQ;
56
57
58 private RealMatrix cachedQT;
59
60
61 private RealMatrix cachedR;
62
63
64 private RealMatrix cachedH;
65
66
67
68
69
70 public QRDecompositionImpl(RealMatrix matrix) {
71
72 final int m = matrix.getRowDimension();
73 final int n = matrix.getColumnDimension();
74 qrt = matrix.transpose().getData();
75 rDiag = new double[Math.min(m, n)];
76 cachedQ = null;
77 cachedQT = null;
78 cachedR = null;
79 cachedH = null;
80
81
82
83
84
85
86 for (int minor = 0; minor < Math.min(m, n); minor++) {
87
88 final double[] qrtMinor = qrt[minor];
89
90
91
92
93
94
95
96
97 double xNormSqr = 0;
98 for (int row = minor; row < m; row++) {
99 final double c = qrtMinor[row];
100 xNormSqr += c * c;
101 }
102 final double a = (qrtMinor[minor] > 0) ? -Math.sqrt(xNormSqr) : Math.sqrt(xNormSqr);
103 rDiag[minor] = a;
104
105 if (a != 0.0) {
106
107
108
109
110
111
112
113
114
115 qrtMinor[minor] -= a;
116
117
118
119
120
121
122
123
124
125
126
127
128
129 for (int col = minor+1; col < n; col++) {
130 final double[] qrtCol = qrt[col];
131 double alpha = 0;
132 for (int row = minor; row < m; row++) {
133 alpha -= qrtCol[row] * qrtMinor[row];
134 }
135 alpha /= a * qrtMinor[minor];
136
137
138 for (int row = minor; row < m; row++) {
139 qrtCol[row] -= alpha * qrtMinor[row];
140 }
141 }
142 }
143 }
144 }
145
146
147 public RealMatrix getR() {
148
149 if (cachedR == null) {
150
151
152 final int n = qrt.length;
153 final int m = qrt[0].length;
154 cachedR = MatrixUtils.createRealMatrix(m, n);
155
156
157 for (int row = Math.min(m, n) - 1; row >= 0; row--) {
158 cachedR.setEntry(row, row, rDiag[row]);
159 for (int col = row + 1; col < n; col++) {
160 cachedR.setEntry(row, col, qrt[col][row]);
161 }
162 }
163
164 }
165
166
167 return cachedR;
168
169 }
170
171
172 public RealMatrix getQ() {
173 if (cachedQ == null) {
174 cachedQ = getQT().transpose();
175 }
176 return cachedQ;
177 }
178
179
180 public RealMatrix getQT() {
181
182 if (cachedQT == null) {
183
184
185 final int n = qrt.length;
186 final int m = qrt[0].length;
187 cachedQT = MatrixUtils.createRealMatrix(m, m);
188
189
190
191
192
193
194 for (int minor = m - 1; minor >= Math.min(m, n); minor--) {
195 cachedQT.setEntry(minor, minor, 1.0);
196 }
197
198 for (int minor = Math.min(m, n)-1; minor >= 0; minor--){
199 final double[] qrtMinor = qrt[minor];
200 cachedQT.setEntry(minor, minor, 1.0);
201 if (qrtMinor[minor] != 0.0) {
202 for (int col = minor; col < m; col++) {
203 double alpha = 0;
204 for (int row = minor; row < m; row++) {
205 alpha -= cachedQT.getEntry(col, row) * qrtMinor[row];
206 }
207 alpha /= rDiag[minor] * qrtMinor[minor];
208
209 for (int row = minor; row < m; row++) {
210 cachedQT.addToEntry(col, row, -alpha * qrtMinor[row]);
211 }
212 }
213 }
214 }
215
216 }
217
218
219 return cachedQT;
220
221 }
222
223
224 public RealMatrix getH() {
225
226 if (cachedH == null) {
227
228 final int n = qrt.length;
229 final int m = qrt[0].length;
230 cachedH = MatrixUtils.createRealMatrix(m, n);
231 for (int i = 0; i < m; ++i) {
232 for (int j = 0; j < Math.min(i + 1, n); ++j) {
233 cachedH.setEntry(i, j, qrt[j][i] / -rDiag[j]);
234 }
235 }
236
237 }
238
239
240 return cachedH;
241
242 }
243
244
245 public DecompositionSolver getSolver() {
246 return new Solver(qrt, rDiag);
247 }
248
249
250 private static class Solver implements DecompositionSolver {
251
252
253
254
255
256
257
258 private final double[][] qrt;
259
260
261 private final double[] rDiag;
262
263
264
265
266
267
268 private Solver(final double[][] qrt, final double[] rDiag) {
269 this.qrt = qrt;
270 this.rDiag = rDiag;
271 }
272
273
274 public boolean isNonSingular() {
275
276 for (double diag : rDiag) {
277 if (diag == 0) {
278 return false;
279 }
280 }
281 return true;
282
283 }
284
285
286 public double[] solve(double[] b)
287 throws IllegalArgumentException, InvalidMatrixException {
288
289 final int n = qrt.length;
290 final int m = qrt[0].length;
291 if (b.length != m) {
292 throw MathRuntimeException.createIllegalArgumentException(
293 "vector length mismatch: got {0} but expected {1}",
294 b.length, m);
295 }
296 if (!isNonSingular()) {
297 throw new SingularMatrixException();
298 }
299
300 final double[] x = new double[n];
301 final double[] y = b.clone();
302
303
304 for (int minor = 0; minor < Math.min(m, n); minor++) {
305
306 final double[] qrtMinor = qrt[minor];
307 double dotProduct = 0;
308 for (int row = minor; row < m; row++) {
309 dotProduct += y[row] * qrtMinor[row];
310 }
311 dotProduct /= rDiag[minor] * qrtMinor[minor];
312
313 for (int row = minor; row < m; row++) {
314 y[row] += dotProduct * qrtMinor[row];
315 }
316
317 }
318
319
320 for (int row = rDiag.length - 1; row >= 0; --row) {
321 y[row] /= rDiag[row];
322 final double yRow = y[row];
323 final double[] qrtRow = qrt[row];
324 x[row] = yRow;
325 for (int i = 0; i < row; i++) {
326 y[i] -= yRow * qrtRow[i];
327 }
328 }
329
330 return x;
331
332 }
333
334
335 public RealVector solve(RealVector b)
336 throws IllegalArgumentException, InvalidMatrixException {
337 try {
338 return solve((ArrayRealVector) b);
339 } catch (ClassCastException cce) {
340 return new ArrayRealVector(solve(b.getData()), false);
341 }
342 }
343
344
345
346
347
348
349
350
351 public ArrayRealVector solve(ArrayRealVector b)
352 throws IllegalArgumentException, InvalidMatrixException {
353 return new ArrayRealVector(solve(b.getDataRef()), false);
354 }
355
356
357 public RealMatrix solve(RealMatrix b)
358 throws IllegalArgumentException, InvalidMatrixException {
359
360 final int n = qrt.length;
361 final int m = qrt[0].length;
362 if (b.getRowDimension() != m) {
363 throw MathRuntimeException.createIllegalArgumentException(
364 "dimensions mismatch: got {0}x{1} but expected {2}x{3}",
365 b.getRowDimension(), b.getColumnDimension(), m, "n");
366 }
367 if (!isNonSingular()) {
368 throw new SingularMatrixException();
369 }
370
371 final int columns = b.getColumnDimension();
372 final int blockSize = BlockRealMatrix.BLOCK_SIZE;
373 final int cBlocks = (columns + blockSize - 1) / blockSize;
374 final double[][] xBlocks = BlockRealMatrix.createBlocksLayout(n, columns);
375 final double[][] y = new double[b.getRowDimension()][blockSize];
376 final double[] alpha = new double[blockSize];
377
378 for (int kBlock = 0; kBlock < cBlocks; ++kBlock) {
379 final int kStart = kBlock * blockSize;
380 final int kEnd = Math.min(kStart + blockSize, columns);
381 final int kWidth = kEnd - kStart;
382
383
384 b.copySubMatrix(0, m - 1, kStart, kEnd - 1, y);
385
386
387 for (int minor = 0; minor < Math.min(m, n); minor++) {
388 final double[] qrtMinor = qrt[minor];
389 final double factor = 1.0 / (rDiag[minor] * qrtMinor[minor]);
390
391 Arrays.fill(alpha, 0, kWidth, 0.0);
392 for (int row = minor; row < m; ++row) {
393 final double d = qrtMinor[row];
394 final double[] yRow = y[row];
395 for (int k = 0; k < kWidth; ++k) {
396 alpha[k] += d * yRow[k];
397 }
398 }
399 for (int k = 0; k < kWidth; ++k) {
400 alpha[k] *= factor;
401 }
402
403 for (int row = minor; row < m; ++row) {
404 final double d = qrtMinor[row];
405 final double[] yRow = y[row];
406 for (int k = 0; k < kWidth; ++k) {
407 yRow[k] += alpha[k] * d;
408 }
409 }
410
411 }
412
413
414 for (int j = rDiag.length - 1; j >= 0; --j) {
415 final int jBlock = j / blockSize;
416 final int jStart = jBlock * blockSize;
417 final double factor = 1.0 / rDiag[j];
418 final double[] yJ = y[j];
419 final double[] xBlock = xBlocks[jBlock * cBlocks + kBlock];
420 for (int k = 0, index = (j - jStart) * kWidth; k < kWidth; ++k, ++index) {
421 yJ[k] *= factor;
422 xBlock[index] = yJ[k];
423 }
424
425 final double[] qrtJ = qrt[j];
426 for (int i = 0; i < j; ++i) {
427 final double rIJ = qrtJ[i];
428 final double[] yI = y[i];
429 for (int k = 0; k < kWidth; ++k) {
430 yI[k] -= yJ[k] * rIJ;
431 }
432 }
433
434 }
435
436 }
437
438 return new BlockRealMatrix(n, columns, xBlocks, false);
439
440 }
441
442
443 public RealMatrix getInverse()
444 throws InvalidMatrixException {
445 return solve(MatrixUtils.createRealIdentityMatrix(rDiag.length));
446 }
447
448 }
449
450 }