1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.math.optimization.direct;
19
20 import java.util.Arrays;
21 import java.util.Comparator;
22
23 import org.apache.commons.math.FunctionEvaluationException;
24 import org.apache.commons.math.MathRuntimeException;
25 import org.apache.commons.math.MaxEvaluationsExceededException;
26 import org.apache.commons.math.MaxIterationsExceededException;
27 import org.apache.commons.math.analysis.MultivariateRealFunction;
28 import org.apache.commons.math.optimization.GoalType;
29 import org.apache.commons.math.optimization.MultivariateRealOptimizer;
30 import org.apache.commons.math.optimization.OptimizationException;
31 import org.apache.commons.math.optimization.RealConvergenceChecker;
32 import org.apache.commons.math.optimization.RealPointValuePair;
33 import org.apache.commons.math.optimization.SimpleScalarValueChecker;
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88 public abstract class DirectSearchOptimizer implements MultivariateRealOptimizer {
89
90
91 protected RealPointValuePair[] simplex;
92
93
94 private MultivariateRealFunction f;
95
96
97 private RealConvergenceChecker checker;
98
99
100 private int maxIterations;
101
102
103 private int iterations;
104
105
106 private int maxEvaluations;
107
108
109 private int evaluations;
110
111
112 private double[][] startConfiguration;
113
114
115
116 protected DirectSearchOptimizer() {
117 setConvergenceChecker(new SimpleScalarValueChecker());
118 setMaxIterations(Integer.MAX_VALUE);
119 setMaxEvaluations(Integer.MAX_VALUE);
120 }
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138 public void setStartConfiguration(final double[] steps)
139 throws IllegalArgumentException {
140
141
142 final int n = steps.length;
143 startConfiguration = new double[n][n];
144 for (int i = 0; i < n; ++i) {
145 final double[] vertexI = startConfiguration[i];
146 for (int j = 0; j < i + 1; ++j) {
147 if (steps[j] == 0.0) {
148 throw MathRuntimeException.createIllegalArgumentException(
149 "equals vertices {0} and {1} in simplex configuration",
150 j, j + 1);
151 }
152 System.arraycopy(steps, 0, vertexI, 0, j + 1);
153 }
154 }
155 }
156
157
158
159
160
161
162
163
164
165
166 public void setStartConfiguration(final double[][] referenceSimplex)
167 throws IllegalArgumentException {
168
169
170
171 final int n = referenceSimplex.length - 1;
172 if (n < 0) {
173 throw MathRuntimeException.createIllegalArgumentException(
174 "simplex must contain at least one point");
175 }
176 startConfiguration = new double[n][n];
177 final double[] ref0 = referenceSimplex[0];
178
179
180 for (int i = 0; i < n + 1; ++i) {
181
182 final double[] refI = referenceSimplex[i];
183
184
185 if (refI.length != n) {
186 throw MathRuntimeException.createIllegalArgumentException(
187 "dimension mismatch {0} != {1}",
188 refI.length, n);
189 }
190 for (int j = 0; j < i; ++j) {
191 final double[] refJ = referenceSimplex[j];
192 boolean allEquals = true;
193 for (int k = 0; k < n; ++k) {
194 if (refI[k] != refJ[k]) {
195 allEquals = false;
196 break;
197 }
198 }
199 if (allEquals) {
200 throw MathRuntimeException.createIllegalArgumentException(
201 "equals vertices {0} and {1} in simplex configuration",
202 i, j);
203 }
204 }
205
206
207 if (i > 0) {
208 final double[] confI = startConfiguration[i - 1];
209 for (int k = 0; k < n; ++k) {
210 confI[k] = refI[k] - ref0[k];
211 }
212 }
213
214 }
215
216 }
217
218
219 public void setMaxIterations(int maxIterations) {
220 this.maxIterations = maxIterations;
221 }
222
223
224 public int getMaxIterations() {
225 return maxIterations;
226 }
227
228
229 public void setMaxEvaluations(int maxEvaluations) {
230 this.maxEvaluations = maxEvaluations;
231 }
232
233
234 public int getMaxEvaluations() {
235 return maxEvaluations;
236 }
237
238
239 public int getIterations() {
240 return iterations;
241 }
242
243
244 public int getEvaluations() {
245 return evaluations;
246 }
247
248
249 public void setConvergenceChecker(RealConvergenceChecker checker) {
250 this.checker = checker;
251 }
252
253
254 public RealConvergenceChecker getConvergenceChecker() {
255 return checker;
256 }
257
258
259 public RealPointValuePair optimize(final MultivariateRealFunction f,
260 final GoalType goalType,
261 final double[] startPoint)
262 throws FunctionEvaluationException, OptimizationException,
263 IllegalArgumentException {
264
265 if (startConfiguration == null) {
266
267
268 final double[] unit = new double[startPoint.length];
269 Arrays.fill(unit, 1.0);
270 setStartConfiguration(unit);
271 }
272
273 this.f = f;
274 final Comparator<RealPointValuePair> comparator =
275 new Comparator<RealPointValuePair>() {
276 public int compare(final RealPointValuePair o1,
277 final RealPointValuePair o2) {
278 final double v1 = o1.getValue();
279 final double v2 = o2.getValue();
280 return (goalType == GoalType.MINIMIZE) ?
281 Double.compare(v1, v2) : Double.compare(v2, v1);
282 }
283 };
284
285
286 iterations = 0;
287 evaluations = 0;
288 buildSimplex(startPoint);
289 evaluateSimplex(comparator);
290
291 RealPointValuePair[] previous = new RealPointValuePair[simplex.length];
292 while (true) {
293
294 if (iterations > 0) {
295 boolean converged = true;
296 for (int i = 0; i < simplex.length; ++i) {
297 converged &= checker.converged(iterations, previous[i], simplex[i]);
298 }
299 if (converged) {
300
301 return simplex[0];
302 }
303 }
304
305
306 System.arraycopy(simplex, 0, previous, 0, simplex.length);
307 iterateSimplex(comparator);
308
309 }
310
311 }
312
313
314
315
316
317 protected void incrementIterationsCounter()
318 throws OptimizationException {
319 if (++iterations > maxIterations) {
320 throw new OptimizationException(new MaxIterationsExceededException(maxIterations));
321 }
322 }
323
324
325
326
327
328
329
330
331 protected abstract void iterateSimplex(final Comparator<RealPointValuePair> comparator)
332 throws FunctionEvaluationException, OptimizationException, IllegalArgumentException;
333
334
335
336
337
338
339
340
341
342
343 protected double evaluate(final double[] x)
344 throws FunctionEvaluationException, IllegalArgumentException {
345 if (++evaluations > maxEvaluations) {
346 throw new FunctionEvaluationException(new MaxEvaluationsExceededException(maxEvaluations),
347 x);
348 }
349 return f.value(x);
350 }
351
352
353
354
355
356
357 private void buildSimplex(final double[] startPoint)
358 throws IllegalArgumentException {
359
360 final int n = startPoint.length;
361 if (n != startConfiguration.length) {
362 throw MathRuntimeException.createIllegalArgumentException(
363 "dimension mismatch {0} != {1}",
364 n, startConfiguration.length);
365 }
366
367
368 simplex = new RealPointValuePair[n + 1];
369 simplex[0] = new RealPointValuePair(startPoint, Double.NaN);
370
371
372 for (int i = 0; i < n; ++i) {
373 final double[] confI = startConfiguration[i];
374 final double[] vertexI = new double[n];
375 for (int k = 0; k < n; ++k) {
376 vertexI[k] = startPoint[k] + confI[k];
377 }
378 simplex[i + 1] = new RealPointValuePair(vertexI, Double.NaN);
379 }
380
381 }
382
383
384
385
386
387
388 protected void evaluateSimplex(final Comparator<RealPointValuePair> comparator)
389 throws FunctionEvaluationException, OptimizationException {
390
391
392 for (int i = 0; i < simplex.length; ++i) {
393 final RealPointValuePair vertex = simplex[i];
394 final double[] point = vertex.getPointRef();
395 if (Double.isNaN(vertex.getValue())) {
396 simplex[i] = new RealPointValuePair(point, evaluate(point), false);
397 }
398 }
399
400
401 Arrays.sort(simplex, comparator);
402
403 }
404
405
406
407
408
409 protected void replaceWorstPoint(RealPointValuePair pointValuePair,
410 final Comparator<RealPointValuePair> comparator) {
411 int n = simplex.length - 1;
412 for (int i = 0; i < n; ++i) {
413 if (comparator.compare(simplex[i], pointValuePair) > 0) {
414 RealPointValuePair tmp = simplex[i];
415 simplex[i] = pointValuePair;
416 pointValuePair = tmp;
417 }
418 }
419 simplex[n] = pointValuePair;
420 }
421
422 }