001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    
018    package org.apache.commons.math.stat.clustering;
019    
020    import static org.junit.Assert.assertEquals;
021    import static org.junit.Assert.assertTrue;
022    
023    import java.util.Arrays;
024    import java.util.List;
025    import java.util.Random;
026    
027    import org.junit.Test;
028    
029    public class KMeansPlusPlusClustererTest {
030    
031        @Test
032        public void dimension2() {
033            KMeansPlusPlusClusterer<EuclideanIntegerPoint> transformer =
034                new KMeansPlusPlusClusterer<EuclideanIntegerPoint>(new Random(1746432956321l));
035            EuclideanIntegerPoint[] points = new EuclideanIntegerPoint[] {
036    
037                    // first expected cluster
038                    new EuclideanIntegerPoint(new int[] { -15,  3 }),
039                    new EuclideanIntegerPoint(new int[] { -15,  4 }),
040                    new EuclideanIntegerPoint(new int[] { -15,  5 }),
041                    new EuclideanIntegerPoint(new int[] { -14,  3 }),
042                    new EuclideanIntegerPoint(new int[] { -14,  5 }),
043                    new EuclideanIntegerPoint(new int[] { -13,  3 }),
044                    new EuclideanIntegerPoint(new int[] { -13,  4 }),
045                    new EuclideanIntegerPoint(new int[] { -13,  5 }),
046    
047                    // second expected cluster
048                    new EuclideanIntegerPoint(new int[] { -1,  0 }),
049                    new EuclideanIntegerPoint(new int[] { -1, -1 }),
050                    new EuclideanIntegerPoint(new int[] {  0, -1 }),
051                    new EuclideanIntegerPoint(new int[] {  1, -1 }),
052                    new EuclideanIntegerPoint(new int[] {  1, -2 }),
053    
054                    // third expected cluster
055                    new EuclideanIntegerPoint(new int[] { 13,  3 }),
056                    new EuclideanIntegerPoint(new int[] { 13,  4 }),
057                    new EuclideanIntegerPoint(new int[] { 14,  4 }),
058                    new EuclideanIntegerPoint(new int[] { 14,  7 }),
059                    new EuclideanIntegerPoint(new int[] { 16,  5 }),
060                    new EuclideanIntegerPoint(new int[] { 16,  6 }),
061                    new EuclideanIntegerPoint(new int[] { 17,  4 }),
062                    new EuclideanIntegerPoint(new int[] { 17,  7 })
063    
064            };
065            List<Cluster<EuclideanIntegerPoint>> clusters =
066                transformer.cluster(Arrays.asList(points), 3, 10);
067    
068            assertEquals(3, clusters.size());
069            boolean cluster1Found = false;
070            boolean cluster2Found = false;
071            boolean cluster3Found = false;
072            for (Cluster<EuclideanIntegerPoint> cluster : clusters) {
073                int[] center = cluster.getCenter().getPoint();
074                if (center[0] < 0) {
075                    cluster1Found = true;
076                    assertEquals(8, cluster.getPoints().size());
077                    assertEquals(-14, center[0]);
078                    assertEquals( 4, center[1]);
079                } else if (center[1] < 0) {
080                    cluster2Found = true;
081                    assertEquals(5, cluster.getPoints().size());
082                    assertEquals( 0, center[0]);
083                    assertEquals(-1, center[1]);
084                } else {
085                    cluster3Found = true;
086                    assertEquals(8, cluster.getPoints().size());
087                    assertEquals(15, center[0]);
088                    assertEquals(5, center[1]);
089                }
090            }
091            assertTrue(cluster1Found);
092            assertTrue(cluster2Found);
093            assertTrue(cluster3Found);
094    
095        }
096    
097    }