View Javadoc

1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.math.analysis.interpolation;
18  
19  import java.io.Serializable;
20  import java.util.Arrays;
21  
22  import org.apache.commons.math.MathException;
23  import org.apache.commons.math.analysis.polynomials.PolynomialSplineFunction;
24  
25  /**
26   * Implements the <a href="http://en.wikipedia.org/wiki/Local_regression">
27   * Local Regression Algorithm</a> (also Loess, Lowess) for interpolation of
28   * real univariate functions.
29   * <p/>
30   * For reference, see
31   * <a href="http://www.math.tau.ac.il/~yekutiel/MA seminar/Cleveland 1979.pdf">
32   * William S. Cleveland - Robust Locally Weighted Regression and Smoothing
33   * Scatterplots</a>
34   * <p/>
35   * This class implements both the loess method and serves as an interpolation
36   * adapter to it, allowing to build a spline on the obtained loess fit.
37   *
38   * @version $Revision: 794709 $ $Date: 2009-07-16 11:09:02 -0400 (Thu, 16 Jul 2009) $
39   * @since 2.0
40   */
41  public class LoessInterpolator
42          implements UnivariateRealInterpolator, Serializable {
43  
44      /** serializable version identifier. */
45      private static final long serialVersionUID = 5204927143605193821L;
46  
47      /**
48       * Default value of the bandwidth parameter.
49       */
50      public static final double DEFAULT_BANDWIDTH = 0.3;
51      /**
52       * Default value of the number of robustness iterations.
53       */
54      public static final int DEFAULT_ROBUSTNESS_ITERS = 2;
55  
56      /**
57       * The bandwidth parameter: when computing the loess fit at
58       * a particular point, this fraction of source points closest
59       * to the current point is taken into account for computing
60       * a least-squares regression.
61       * <p/>
62       * A sensible value is usually 0.25 to 0.5.
63       */
64      private final double bandwidth;
65  
66      /**
67       * The number of robustness iterations parameter: this many
68       * robustness iterations are done.
69       * <p/>
70       * A sensible value is usually 0 (just the initial fit without any
71       * robustness iterations) to 4.
72       */
73      private final int robustnessIters;
74  
75      /**
76       * Constructs a new {@link LoessInterpolator}
77       * with a bandwidth of {@link #DEFAULT_BANDWIDTH} and
78       * {@link #DEFAULT_ROBUSTNESS_ITERS} robustness iterations.
79       * See {@link #LoessInterpolator(double, int)} for an explanation of
80       * the parameters.
81       */
82      public LoessInterpolator() {
83          this.bandwidth = DEFAULT_BANDWIDTH;
84          this.robustnessIters = DEFAULT_ROBUSTNESS_ITERS;
85      }
86  
87      /**
88       * Constructs a new {@link LoessInterpolator}
89       * with given bandwidth and number of robustness iterations.
90       *
91       * @param bandwidth  when computing the loess fit at
92       * a particular point, this fraction of source points closest
93       * to the current point is taken into account for computing
94       * a least-squares regression.</br>
95       * A sensible value is usually 0.25 to 0.5, the default value is
96       * {@link #DEFAULT_BANDWIDTH}.
97       * @param robustnessIters This many robustness iterations are done.</br>
98       * A sensible value is usually 0 (just the initial fit without any
99       * robustness iterations) to 4, the default value is
100      * {@link #DEFAULT_ROBUSTNESS_ITERS}.
101      * @throws MathException if bandwidth does not lie in the interval [0,1]
102      * or if robustnessIters is negative.
103      */
104     public LoessInterpolator(double bandwidth, int robustnessIters) throws MathException {
105         if (bandwidth < 0 || bandwidth > 1) {
106             throw new MathException("bandwidth must be in the interval [0,1], but got {0}",
107                                     bandwidth);
108         }
109         this.bandwidth = bandwidth;
110         if (robustnessIters < 0) {
111             throw new MathException("the number of robustness iterations must " +
112                                     "be non-negative, but got {0}",
113                                     robustnessIters);
114         }
115         this.robustnessIters = robustnessIters;
116     }
117 
118     /**
119      * Compute an interpolating function by performing a loess fit
120      * on the data at the original abscissae and then building a cubic spline
121      * with a
122      * {@link org.apache.commons.math.analysis.interpolation.SplineInterpolator}
123      * on the resulting fit.
124      *
125      * @param xval the arguments for the interpolation points
126      * @param yval the values for the interpolation points
127      * @return A cubic spline built upon a loess fit to the data at the original abscissae
128      * @throws MathException  if some of the following conditions are false:
129      * <ul>
130      * <li> Arguments and values are of the same size that is greater than zero</li>
131      * <li> The arguments are in a strictly increasing order</li>
132      * <li> All arguments and values are finite real numbers</li>
133      * </ul>
134      */
135     public final PolynomialSplineFunction interpolate(
136             final double[] xval, final double[] yval) throws MathException {
137         return new SplineInterpolator().interpolate(xval, smooth(xval, yval));
138     }
139 
140     /**
141      * Compute a loess fit on the data at the original abscissae.
142      *
143      * @param xval the arguments for the interpolation points
144      * @param yval the values for the interpolation points
145      * @return values of the loess fit at corresponding original abscissae
146      * @throws MathException if some of the following conditions are false:
147      * <ul>
148      * <li> Arguments and values are of the same size that is greater than zero</li>
149      * <li> The arguments are in a strictly increasing order</li>
150      * <li> All arguments and values are finite real numbers</li>
151      * </ul>
152      */
153     public final double[] smooth(final double[] xval, final double[] yval)
154             throws MathException {
155         if (xval.length != yval.length) {
156             throw new MathException(
157                     "Loess expects the abscissa and ordinate arrays " +
158                     "to be of the same size, " +
159                     "but got {0} abscisssae and {1} ordinatae",
160                     xval.length, yval.length);
161         }
162 
163         final int n = xval.length;
164 
165         if (n == 0) {
166             throw new MathException("Loess expects at least 1 point");
167         }
168 
169         checkAllFiniteReal(xval, true);
170         checkAllFiniteReal(yval, false);
171 
172         checkStrictlyIncreasing(xval);
173 
174         if (n == 1) {
175             return new double[]{yval[0]};
176         }
177 
178         if (n == 2) {
179             return new double[]{yval[0], yval[1]};
180         }
181 
182         int bandwidthInPoints = (int) (bandwidth * n);
183 
184         if (bandwidthInPoints < 2) {
185             throw new MathException(
186                     "the bandwidth must be large enough to " +
187                     "accomodate at least 2 points. There are {0} " +
188                     " data points, and bandwidth must be at least {1} " +
189                     " but it is only {2}",
190                     n, 2.0 / n, bandwidth);
191         }
192 
193         final double[] res = new double[n];
194 
195         final double[] residuals = new double[n];
196         final double[] sortedResiduals = new double[n];
197 
198         final double[] robustnessWeights = new double[n];
199 
200         // Do an initial fit and 'robustnessIters' robustness iterations.
201         // This is equivalent to doing 'robustnessIters+1' robustness iterations
202         // starting with all robustness weights set to 1.
203         Arrays.fill(robustnessWeights, 1);
204 
205         for (int iter = 0; iter <= robustnessIters; ++iter) {
206             final int[] bandwidthInterval = {0, bandwidthInPoints - 1};
207             // At each x, compute a local weighted linear regression
208             for (int i = 0; i < n; ++i) {
209                 final double x = xval[i];
210 
211                 // Find out the interval of source points on which
212                 // a regression is to be made.
213                 if (i > 0) {
214                     updateBandwidthInterval(xval, i, bandwidthInterval);
215                 }
216 
217                 final int ileft = bandwidthInterval[0];
218                 final int iright = bandwidthInterval[1];
219 
220                 // Compute the point of the bandwidth interval that is
221                 // farthest from x
222                 final int edge;
223                 if (xval[i] - xval[ileft] > xval[iright] - xval[i]) {
224                     edge = ileft;
225                 } else {
226                     edge = iright;
227                 }
228 
229                 // Compute a least-squares linear fit weighted by
230                 // the product of robustness weights and the tricube
231                 // weight function.
232                 // See http://en.wikipedia.org/wiki/Linear_regression
233                 // (section "Univariate linear case")
234                 // and http://en.wikipedia.org/wiki/Weighted_least_squares
235                 // (section "Weighted least squares")
236                 double sumWeights = 0;
237                 double sumX = 0, sumXSquared = 0, sumY = 0, sumXY = 0;
238                 double denom = Math.abs(1.0 / (xval[edge] - x));
239                 for (int k = ileft; k <= iright; ++k) {
240                     final double xk = xval[k];
241                     final double yk = yval[k];
242                     double dist;
243                     if (k < i) {
244                         dist = (x - xk);
245                     } else {
246                         dist = (xk - x);
247                     }
248                     final double w = tricube(dist * denom) * robustnessWeights[k];
249                     final double xkw = xk * w;
250                     sumWeights += w;
251                     sumX += xkw;
252                     sumXSquared += xk * xkw;
253                     sumY += yk * w;
254                     sumXY += yk * xkw;
255                 }
256 
257                 final double meanX = sumX / sumWeights;
258                 final double meanY = sumY / sumWeights;
259                 final double meanXY = sumXY / sumWeights;
260                 final double meanXSquared = sumXSquared / sumWeights;
261 
262                 final double beta;
263                 if (meanXSquared == meanX * meanX) {
264                     beta = 0;
265                 } else {
266                     beta = (meanXY - meanX * meanY) / (meanXSquared - meanX * meanX);
267                 }
268 
269                 final double alpha = meanY - beta * meanX;
270 
271                 res[i] = beta * x + alpha;
272                 residuals[i] = Math.abs(yval[i] - res[i]);
273             }
274 
275             // No need to recompute the robustness weights at the last
276             // iteration, they won't be needed anymore
277             if (iter == robustnessIters) {
278                 break;
279             }
280 
281             // Recompute the robustness weights.
282 
283             // Find the median residual.
284             // An arraycopy and a sort are completely tractable here, 
285             // because the preceding loop is a lot more expensive
286             System.arraycopy(residuals, 0, sortedResiduals, 0, n);
287             Arrays.sort(sortedResiduals);
288             final double medianResidual = sortedResiduals[n / 2];
289 
290             if (medianResidual == 0) {
291                 break;
292             }
293 
294             for (int i = 0; i < n; ++i) {
295                 final double arg = residuals[i] / (6 * medianResidual);
296                 robustnessWeights[i] = (arg >= 1) ? 0 : Math.pow(1 - arg * arg, 2);
297             }
298         }
299 
300         return res;
301     }
302 
303     /**
304      * Given an index interval into xval that embraces a certain number of
305      * points closest to xval[i-1], update the interval so that it embraces
306      * the same number of points closest to xval[i]
307      *
308      * @param xval arguments array
309      * @param i the index around which the new interval should be computed
310      * @param bandwidthInterval a two-element array {left, right} such that: <p/>
311      * <tt>(left==0 or xval[i] - xval[left-1] > xval[right] - xval[i])</tt>
312      * <p/> and also <p/>
313      * <tt>(right==xval.length-1 or xval[right+1] - xval[i] > xval[i] - xval[left])</tt>.
314      * The array will be updated.
315      */
316     private static void updateBandwidthInterval(final double[] xval, final int i,
317                                                 final int[] bandwidthInterval) {
318         final int left = bandwidthInterval[0];
319         final int right = bandwidthInterval[1];
320 
321         // The right edge should be adjusted if the next point to the right
322         // is closer to xval[i] than the leftmost point of the current interval
323         if (right < xval.length - 1 &&
324            xval[right+1] - xval[i] < xval[i] - xval[left]) {
325             bandwidthInterval[0]++;
326             bandwidthInterval[1]++;
327         }
328     }
329 
330     /**
331      * Compute the 
332      * <a href="http://en.wikipedia.org/wiki/Local_regression#Weight_function">tricube</a>
333      * weight function
334      *
335      * @param x the argument
336      * @return (1-|x|^3)^3
337      */
338     private static double tricube(final double x) {
339         final double tmp = 1 - x * x * x;
340         return tmp * tmp * tmp;
341     }
342 
343     /**
344      * Check that all elements of an array are finite real numbers.
345      *
346      * @param values the values array
347      * @param isAbscissae if true, elements are abscissae otherwise they are ordinatae
348      * @throws MathException if one of the values is not
349      *         a finite real number
350      */
351     private static void checkAllFiniteReal(final double[] values, final boolean isAbscissae)
352         throws MathException {
353         for (int i = 0; i < values.length; i++) {
354             final double x = values[i];
355             if (Double.isInfinite(x) || Double.isNaN(x)) {
356                 final String pattern = isAbscissae ?
357                         "all abscissae must be finite real numbers, but {0}-th is {1}" :
358                         "all ordinatae must be finite real numbers, but {0}-th is {1}";
359                 throw new MathException(pattern, i, x);
360             }
361         }
362     }
363 
364     /**
365      * Check that elements of the abscissae array are in a strictly
366      * increasing order.
367      *
368      * @param xval the abscissae array
369      * @throws MathException if the abscissae array
370      * is not in a strictly increasing order
371      */
372     private static void checkStrictlyIncreasing(final double[] xval)
373         throws MathException {
374         for (int i = 0; i < xval.length; ++i) {
375             if (i >= 1 && xval[i - 1] >= xval[i]) {
376                 throw new MathException(
377                         "the abscissae array must be sorted in a strictly " +
378                         "increasing order, but the {0}-th element is {1} " +
379                         "whereas {2}-th is {3}",
380                         i - 1, xval[i - 1], i, xval[i]);
381             }
382         }
383     }
384 }