1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.math.optimization;
19  
20  import static org.junit.Assert.assertEquals;
21  import static org.junit.Assert.assertTrue;
22  
23  import java.awt.geom.Point2D;
24  import java.util.ArrayList;
25  
26  import org.apache.commons.math.FunctionEvaluationException;
27  import org.apache.commons.math.analysis.DifferentiableMultivariateRealFunction;
28  import org.apache.commons.math.analysis.MultivariateRealFunction;
29  import org.apache.commons.math.analysis.MultivariateVectorialFunction;
30  import org.apache.commons.math.analysis.solvers.BrentSolver;
31  import org.apache.commons.math.optimization.general.ConjugateGradientFormula;
32  import org.apache.commons.math.optimization.general.NonLinearConjugateGradientOptimizer;
33  import org.apache.commons.math.random.GaussianRandomGenerator;
34  import org.apache.commons.math.random.JDKRandomGenerator;
35  import org.apache.commons.math.random.RandomVectorGenerator;
36  import org.apache.commons.math.random.UncorrelatedRandomVectorGenerator;
37  import org.junit.Test;
38  
39  public class MultiStartDifferentiableMultivariateRealOptimizerTest {
40  
41      @Test
42      public void testCircleFitting() throws FunctionEvaluationException, OptimizationException {
43          Circle circle = new Circle();
44          circle.addPoint( 30.0,  68.0);
45          circle.addPoint( 50.0,  -6.0);
46          circle.addPoint(110.0, -20.0);
47          circle.addPoint( 35.0,  15.0);
48          circle.addPoint( 45.0,  97.0);
49          NonLinearConjugateGradientOptimizer underlying =
50              new NonLinearConjugateGradientOptimizer(ConjugateGradientFormula.POLAK_RIBIERE);
51          JDKRandomGenerator g = new JDKRandomGenerator();
52          g.setSeed(753289573253l);
53          RandomVectorGenerator generator =
54              new UncorrelatedRandomVectorGenerator(new double[] { 50.0, 50.0 }, new double[] { 10.0, 10.0 },
55                                                    new GaussianRandomGenerator(g));
56          MultiStartDifferentiableMultivariateRealOptimizer optimizer =
57              new MultiStartDifferentiableMultivariateRealOptimizer(underlying, 10, generator);
58          optimizer.setMaxIterations(100);
59          assertEquals(100, optimizer.getMaxIterations());
60          optimizer.setMaxEvaluations(100);
61          assertEquals(100, optimizer.getMaxEvaluations());
62          optimizer.setConvergenceChecker(new SimpleScalarValueChecker(1.0e-10, 1.0e-10));
63          BrentSolver solver = new BrentSolver();
64          solver.setAbsoluteAccuracy(1.0e-13);
65          solver.setRelativeAccuracy(1.0e-15);
66          RealPointValuePair optimum =
67              optimizer.optimize(circle, GoalType.MINIMIZE, new double[] { 98.680, 47.345 });
68          RealPointValuePair[] optima = optimizer.getOptima();
69          for (RealPointValuePair o : optima) {
70              Point2D.Double center = new Point2D.Double(o.getPointRef()[0], o.getPointRef()[1]);
71              assertEquals(69.960161753, circle.getRadius(center), 1.0e-8);
72              assertEquals(96.075902096, center.x, 1.0e-8);
73              assertEquals(48.135167894, center.y, 1.0e-8);
74          }
75          assertTrue(optimizer.getGradientEvaluations() > 650);
76          assertTrue(optimizer.getGradientEvaluations() < 700);
77          assertTrue(optimizer.getEvaluations() > 70);
78          assertTrue(optimizer.getEvaluations() < 90);
79          assertTrue(optimizer.getIterations() > 70);
80          assertTrue(optimizer.getIterations() < 90);
81          assertEquals(3.1267527, optimum.getValue(), 1.0e-8);
82      }
83  
84      private static class Circle implements DifferentiableMultivariateRealFunction {
85  
86          private ArrayList<Point2D.Double> points;
87  
88          public Circle() {
89              points  = new ArrayList<Point2D.Double>();
90          }
91  
92          public void addPoint(double px, double py) {
93              points.add(new Point2D.Double(px, py));
94          }
95  
96          public double getRadius(Point2D.Double center) {
97              double r = 0;
98              for (Point2D.Double point : points) {
99                  r += point.distance(center);
100             }
101             return r / points.size();
102         }
103 
104         private double[] gradient(double[] point) {
105 
106             // optimal radius
107             Point2D.Double center = new Point2D.Double(point[0], point[1]);
108             double radius = getRadius(center);
109 
110             // gradient of the sum of squared residuals
111             double dJdX = 0;
112             double dJdY = 0;
113             for (Point2D.Double pk : points) {
114                 double dk = pk.distance(center);
115                 dJdX += (center.x - pk.x) * (dk - radius) / dk;
116                 dJdY += (center.y - pk.y) * (dk - radius) / dk;
117             }
118             dJdX *= 2;
119             dJdY *= 2;
120 
121             return new double[] { dJdX, dJdY };
122 
123         }
124 
125         public double value(double[] variables)
126         throws IllegalArgumentException, FunctionEvaluationException {
127 
128             Point2D.Double center = new Point2D.Double(variables[0], variables[1]);
129             double radius = getRadius(center);
130 
131             double sum = 0;
132             for (Point2D.Double point : points) {
133                 double di = point.distance(center) - radius;
134                 sum += di * di;
135             }
136 
137             return sum;
138 
139         }
140 
141         public MultivariateVectorialFunction gradient() {
142             return new MultivariateVectorialFunction() {
143                 public double[] value(double[] point) {
144                     return gradient(point);
145                 }
146             };
147         }
148 
149         public MultivariateRealFunction partialDerivative(final int k) {
150             return new MultivariateRealFunction() {
151                 public double value(double[] point) {
152                     return gradient(point)[k];
153                 }
154             };
155         }
156 
157     }
158 
159 }