package org.jdmp.weka.clusterer;

import org.jdmp.core.dataset.ListDataSet;
import org.jdmp.weka.clusterer.WekaClusterer;
import org.junit.Assert;
import org.junit.Test;
import org.ujmp.core.Matrix;
import org.ujmp.core.calculation.Calculation;

/* loaded from: input_file:org/jdmp/weka/clusterer/TestWekaClusterer.class */
/**
 * Tests for {@link WekaClusterer} run against the data set returned by
 * {@code ListDataSet.Factory.IRIS()}.
 *
 * <p>Each test configures a clusterer for three clusters, trains it, predicts
 * on the same data set, and then checks the first three entries of the summed
 * prediction matrix against an expected value with a tolerance.
 */
public class TestWekaClusterer {
    /** Number of clusters requested from every clusterer under test. */
    private static final int CLUSTER_COUNT = 3;

    /** Value each checked entry of the summed prediction matrix should be close to. */
    private static final double EXPECTED_COLUMN_SUM = 50.0d;

    /** Maximum allowed absolute deviation from {@link #EXPECTED_COLUMN_SUM}. */
    private static final double TOLERANCE = 15.0d;

    /**
     * Runs the shared check using the {@code EM} clusterer type.
     *
     * @throws Exception if training or prediction fails
     */
    @Test
    public void testClusteringEM() throws Exception {
        ListDataSet iris = ListDataSet.Factory.IRIS();
        WekaClusterer em = new WekaClusterer(WekaClusterer.WekaClustererType.EM, false, new String[0]);
        trainPredictAndVerify(em, iris);
    }

    /**
     * Runs the shared check using the {@code SimpleKMeans} clusterer type.
     *
     * @throws Exception if training or prediction fails
     */
    @Test
    public void testClusteringKMeans() throws Exception {
        ListDataSet iris = ListDataSet.Factory.IRIS();
        WekaClusterer kMeans = new WekaClusterer(WekaClusterer.WekaClustererType.SimpleKMeans, false, new String[0]);
        trainPredictAndVerify(kMeans, iris);
    }

    /**
     * Sets the cluster count, trains and predicts on {@code dataSet}, then asserts that
     * entries (0, 0) through (0, 2) of the summed prediction matrix are each within
     * {@link #TOLERANCE} of {@link #EXPECTED_COLUMN_SUM}.
     *
     * @param clusterer the clusterer to exercise
     * @param dataSet   the data set used for both training and prediction
     * @throws Exception if training or prediction fails
     */
    private static void trainPredictAndVerify(WekaClusterer clusterer, ListDataSet dataSet) throws Exception {
        clusterer.setNumberOfClusters(CLUSTER_COUNT);
        clusterer.train(dataSet);
        clusterer.predict(dataSet);

        // NOTE(review): presumably sums over dimension 0, leaving one total per
        // predicted cluster column - confirm against the UJMP Matrix API.
        Matrix predicted = dataSet.getPredictedMatrix();
        Matrix totals = predicted.sum(Calculation.Ret.NEW, 0, true);

        for (long column = 0; column < CLUSTER_COUNT; column++) {
            double actual = totals.getAsDouble(new long[] {0L, column});
            Assert.assertEquals(EXPECTED_COLUMN_SUM, actual, TOLERANCE);
        }
    }
}
