1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.math.linear;
19
20 import org.apache.commons.math.MathRuntimeException;
21
22
23
24
25
26
27
28
29
30
31
32
33
34 public class CholeskyDecompositionImpl implements CholeskyDecomposition {
35
36
37
38 public static final double DEFAULT_RELATIVE_SYMMETRY_THRESHOLD = 1.0e-15;
39
40
41
42 public static final double DEFAULT_ABSOLUTE_POSITIVITY_THRESHOLD = 1.0e-10;
43
44
45 private double[][] lTData;
46
47
48 private RealMatrix cachedL;
49
50
51 private RealMatrix cachedLT;
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71 public CholeskyDecompositionImpl(final RealMatrix matrix)
72 throws NonSquareMatrixException,
73 NotSymmetricMatrixException, NotPositiveDefiniteMatrixException {
74 this(matrix, DEFAULT_RELATIVE_SYMMETRY_THRESHOLD,
75 DEFAULT_ABSOLUTE_POSITIVITY_THRESHOLD);
76 }
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93 public CholeskyDecompositionImpl(final RealMatrix matrix,
94 final double relativeSymmetryThreshold,
95 final double absolutePositivityThreshold)
96 throws NonSquareMatrixException,
97 NotSymmetricMatrixException, NotPositiveDefiniteMatrixException {
98
99 if (!matrix.isSquare()) {
100 throw new NonSquareMatrixException(matrix.getRowDimension(),
101 matrix.getColumnDimension());
102 }
103
104 final int order = matrix.getRowDimension();
105 lTData = matrix.getData();
106 cachedL = null;
107 cachedLT = null;
108
109
110 for (int i = 0; i < order; ++i) {
111
112 final double[] lI = lTData[i];
113
114
115 for (int j = i + 1; j < order; ++j) {
116 final double[] lJ = lTData[j];
117 final double lIJ = lI[j];
118 final double lJI = lJ[i];
119 final double maxDelta =
120 relativeSymmetryThreshold * Math.max(Math.abs(lIJ), Math.abs(lJI));
121 if (Math.abs(lIJ - lJI) > maxDelta) {
122 throw new NotSymmetricMatrixException();
123 }
124 lJ[i] = 0;
125 }
126 }
127
128
129 for (int i = 0; i < order; ++i) {
130
131 final double[] ltI = lTData[i];
132
133
134 if (ltI[i] < absolutePositivityThreshold) {
135 throw new NotPositiveDefiniteMatrixException();
136 }
137
138 ltI[i] = Math.sqrt(ltI[i]);
139 final double inverse = 1.0 / ltI[i];
140
141 for (int q = order - 1; q > i; --q) {
142 ltI[q] *= inverse;
143 final double[] ltQ = lTData[q];
144 for (int p = q; p < order; ++p) {
145 ltQ[p] -= ltI[q] * ltI[p];
146 }
147 }
148
149 }
150
151 }
152
153
154 public RealMatrix getL() {
155 if (cachedL == null) {
156 cachedL = getLT().transpose();
157 }
158 return cachedL;
159 }
160
161
162 public RealMatrix getLT() {
163
164 if (cachedLT == null) {
165 cachedLT = MatrixUtils.createRealMatrix(lTData);
166 }
167
168
169 return cachedLT;
170
171 }
172
173
174 public double getDeterminant() {
175 double determinant = 1.0;
176 for (int i = 0; i < lTData.length; ++i) {
177 double lTii = lTData[i][i];
178 determinant *= lTii * lTii;
179 }
180 return determinant;
181 }
182
183
184 public DecompositionSolver getSolver() {
185 return new Solver(lTData);
186 }
187
188
189 private static class Solver implements DecompositionSolver {
190
191
192 private final double[][] lTData;
193
194
195
196
197
198 private Solver(final double[][] lTData) {
199 this.lTData = lTData;
200 }
201
202
203 public boolean isNonSingular() {
204
205 return true;
206 }
207
208
209 public double[] solve(double[] b)
210 throws IllegalArgumentException, InvalidMatrixException {
211
212 final int m = lTData.length;
213 if (b.length != m) {
214 throw MathRuntimeException.createIllegalArgumentException(
215 "vector length mismatch: got {0} but expected {1}",
216 b.length, m);
217 }
218
219 final double[] x = b.clone();
220
221
222 for (int j = 0; j < m; j++) {
223 final double[] lJ = lTData[j];
224 x[j] /= lJ[j];
225 final double xJ = x[j];
226 for (int i = j + 1; i < m; i++) {
227 x[i] -= xJ * lJ[i];
228 }
229 }
230
231
232 for (int j = m - 1; j >= 0; j--) {
233 x[j] /= lTData[j][j];
234 final double xJ = x[j];
235 for (int i = 0; i < j; i++) {
236 x[i] -= xJ * lTData[i][j];
237 }
238 }
239
240 return x;
241
242 }
243
244
245 public RealVector solve(RealVector b)
246 throws IllegalArgumentException, InvalidMatrixException {
247 try {
248 return solve((ArrayRealVector) b);
249 } catch (ClassCastException cce) {
250
251 final int m = lTData.length;
252 if (b.getDimension() != m) {
253 throw MathRuntimeException.createIllegalArgumentException(
254 "vector length mismatch: got {0} but expected {1}",
255 b.getDimension(), m);
256 }
257
258 final double[] x = b.getData();
259
260
261 for (int j = 0; j < m; j++) {
262 final double[] lJ = lTData[j];
263 x[j] /= lJ[j];
264 final double xJ = x[j];
265 for (int i = j + 1; i < m; i++) {
266 x[i] -= xJ * lJ[i];
267 }
268 }
269
270
271 for (int j = m - 1; j >= 0; j--) {
272 x[j] /= lTData[j][j];
273 final double xJ = x[j];
274 for (int i = 0; i < j; i++) {
275 x[i] -= xJ * lTData[i][j];
276 }
277 }
278
279 return new ArrayRealVector(x, false);
280
281 }
282 }
283
284
285
286
287
288
289
290
291 public ArrayRealVector solve(ArrayRealVector b)
292 throws IllegalArgumentException, InvalidMatrixException {
293 return new ArrayRealVector(solve(b.getDataRef()), false);
294 }
295
296
297 public RealMatrix solve(RealMatrix b)
298 throws IllegalArgumentException, InvalidMatrixException {
299
300 final int m = lTData.length;
301 if (b.getRowDimension() != m) {
302 throw MathRuntimeException.createIllegalArgumentException(
303 "dimensions mismatch: got {0}x{1} but expected {2}x{3}",
304 b.getRowDimension(), b.getColumnDimension(), m, "n");
305 }
306
307 final int nColB = b.getColumnDimension();
308 double[][] x = b.getData();
309
310
311 for (int j = 0; j < m; j++) {
312 final double[] lJ = lTData[j];
313 final double lJJ = lJ[j];
314 final double[] xJ = x[j];
315 for (int k = 0; k < nColB; ++k) {
316 xJ[k] /= lJJ;
317 }
318 for (int i = j + 1; i < m; i++) {
319 final double[] xI = x[i];
320 final double lJI = lJ[i];
321 for (int k = 0; k < nColB; ++k) {
322 xI[k] -= xJ[k] * lJI;
323 }
324 }
325 }
326
327
328 for (int j = m - 1; j >= 0; j--) {
329 final double lJJ = lTData[j][j];
330 final double[] xJ = x[j];
331 for (int k = 0; k < nColB; ++k) {
332 xJ[k] /= lJJ;
333 }
334 for (int i = 0; i < j; i++) {
335 final double[] xI = x[i];
336 final double lIJ = lTData[i][j];
337 for (int k = 0; k < nColB; ++k) {
338 xI[k] -= xJ[k] * lIJ;
339 }
340 }
341 }
342
343 return new Array2DRowRealMatrix(x, false);
344
345 }
346
347
348 public RealMatrix getInverse() throws InvalidMatrixException {
349 return solve(MatrixUtils.createRealIdentityMatrix(lTData.length));
350 }
351
352 }
353
354 }