package org.apache.commons.math3.ml.neuralnet.twod;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import org.apache.commons.math3.exception.NumberIsTooSmallException;
import org.apache.commons.math3.ml.neuralnet.FeatureInitializer;
import org.apache.commons.math3.ml.neuralnet.FeatureInitializerFactory;
import org.apache.commons.math3.ml.neuralnet.Network;
import org.apache.commons.math3.ml.neuralnet.Neuron;
import org.apache.commons.math3.ml.neuralnet.SquareNeighbourhood;
import org.junit.Assert;
import org.junit.Test;

/**
 * Unit tests for {@link NeuronSquareMesh2D}.
 *
 * <p>Covers constructor size validation, the neighbourhood links created for each
 * combination of row/column wrapping and {@link SquareNeighbourhood} type, concentric
 * neighbourhood queries, and a Java-serialization round trip.</p>
 *
 * <p>The expected neighbour lists below are consistent with neuron identifiers being
 * assigned in row-major order ({@code id = row * numCols + col}); e.g. in a 3x3 grid
 * neuron 4 is the centre cell.</p>
 */
public class NeuronSquareMesh2DTest {
    final FeatureInitializer init = FeatureInitializerFactory.uniform(0.0d, 2.0d);

    /** A mesh with a single row is rejected. */
    @Test(expected = NumberIsTooSmallException.class)
    public void testMinimalNetworkSize1() {
        new NeuronSquareMesh2D(1, false, 2, false, SquareNeighbourhood.VON_NEUMANN,
                               new FeatureInitializer[] { init });
    }

    /** A mesh with no columns is rejected. */
    @Test(expected = NumberIsTooSmallException.class)
    public void testMinimalNetworkSize2() {
        new NeuronSquareMesh2D(2, false, 0, false, SquareNeighbourhood.VON_NEUMANN,
                               new FeatureInitializer[] { init });
    }

    /** The feature-vector size equals the number of feature initializers supplied. */
    @Test
    public void testGetFeaturesSize() {
        final Network net = new NeuronSquareMesh2D(2, false, 2, false, SquareNeighbourhood.VON_NEUMANN,
                                                   new FeatureInitializer[] { init, init, init }).getNetwork();
        Assert.assertEquals(3L, net.getFeaturesSize());
    }

    /** 2x2 grid, no wrapping, Von Neumann: each neuron links to its two orthogonal neighbours. */
    @Test
    public void test2x2Network() {
        final Network net = new NeuronSquareMesh2D(2, false, 2, false, SquareNeighbourhood.VON_NEUMANN,
                                                   new FeatureInitializer[] { init }).getNetwork();
        for (long id : new long[] { 0, 3 }) {
            assertNeighbours(net, id, 1, 2);
        }
        for (long id : new long[] { 1, 2 }) {
            assertNeighbours(net, id, 0, 3);
        }
    }

    /** 2x2 grid, no wrapping, Moore: every neuron is linked to all three others. */
    @Test
    public void test2x2Network2() {
        final Network net = new NeuronSquareMesh2D(2, false, 2, false, SquareNeighbourhood.MOORE,
                                                   new FeatureInitializer[] { init }).getNetwork();
        assertAllOthersAreNeighbours(net, 4);
    }

    /** 2 rows x 3 wrapped columns, Von Neumann. */
    @Test
    public void test3x2CylinderNetwork() {
        final Network net = new NeuronSquareMesh2D(2, false, 3, true, SquareNeighbourhood.VON_NEUMANN,
                                                   new FeatureInitializer[] { init }).getNetwork();
        assertNeighbours(net, 0, 1, 2, 3);
        assertNeighbours(net, 1, 0, 2, 4);
        assertNeighbours(net, 2, 0, 1, 5);
        assertNeighbours(net, 3, 0, 4, 5);
        assertNeighbours(net, 4, 1, 3, 5);
        assertNeighbours(net, 5, 2, 3, 4);
    }

    /** 2 rows x 3 wrapped columns, Moore: every neuron is linked to all five others. */
    @Test
    public void test3x2CylinderNetwork2() {
        final Network net = new NeuronSquareMesh2D(2, false, 3, true, SquareNeighbourhood.MOORE,
                                                   new FeatureInitializer[] { init }).getNetwork();
        assertAllOthersAreNeighbours(net, 6);
    }

    /** 3x3 torus (rows and columns wrapped), Von Neumann: four neighbours each. */
    @Test
    public void test3x3TorusNetwork() {
        final Network net = new NeuronSquareMesh2D(3, true, 3, true, SquareNeighbourhood.VON_NEUMANN,
                                                   new FeatureInitializer[] { init }).getNetwork();
        assertNeighbours(net, 0, 1, 2, 3, 6);
        assertNeighbours(net, 1, 0, 2, 4, 7);
        assertNeighbours(net, 2, 0, 1, 5, 8);
        assertNeighbours(net, 3, 0, 4, 5, 6);
        assertNeighbours(net, 4, 1, 3, 5, 7);
        assertNeighbours(net, 5, 2, 3, 4, 8);
        assertNeighbours(net, 6, 0, 3, 7, 8);
        assertNeighbours(net, 7, 1, 4, 6, 8);
        assertNeighbours(net, 8, 2, 5, 6, 7);
    }

    /** 3x3 torus, Moore: every neuron is linked to all eight others. */
    @Test
    public void test3x3TorusNetwork2() {
        final Network net = new NeuronSquareMesh2D(3, true, 3, true, SquareNeighbourhood.MOORE,
                                                   new FeatureInitializer[] { init }).getNetwork();
        assertAllOthersAreNeighbours(net, 9);
    }

    /** 3 non-wrapped rows x 3 wrapped columns, Moore. */
    @Test
    public void test3x3CylinderNetwork() {
        final Network net = new NeuronSquareMesh2D(3, false, 3, true, SquareNeighbourhood.MOORE,
                                                   new FeatureInitializer[] { init }).getNetwork();
        // First row: only the row below is adjacent.
        assertNeighbours(net, 0, 1, 2, 3, 4, 5);
        assertNeighbours(net, 1, 0, 2, 3, 4, 5);
        assertNeighbours(net, 2, 0, 1, 3, 4, 5);
        // Middle row: adjacent to everything else.
        assertNeighbours(net, 3, 0, 1, 2, 4, 5, 6, 7, 8);
        assertNeighbours(net, 4, 0, 1, 2, 3, 5, 6, 7, 8);
        assertNeighbours(net, 5, 0, 1, 2, 3, 4, 6, 7, 8);
        // Last row: only the row above is adjacent.
        assertNeighbours(net, 6, 3, 4, 5, 7, 8);
        assertNeighbours(net, 7, 3, 4, 5, 6, 8);
        assertNeighbours(net, 8, 3, 4, 5, 6, 7);
    }

    /**
     * 3x3 grid, Moore. Note: despite the method name, neither rows nor columns wrap here,
     * so corners have 3 neighbours, edges 5 and the centre 8.
     */
    @Test
    public void test3x3CylinderNetwork2() {
        final Network net = new NeuronSquareMesh2D(3, false, 3, false, SquareNeighbourhood.MOORE,
                                                   new FeatureInitializer[] { init }).getNetwork();
        assertNeighbours(net, 0, 1, 3, 4);
        assertNeighbours(net, 1, 0, 2, 3, 4, 5);
        assertNeighbours(net, 2, 1, 4, 5);
        assertNeighbours(net, 3, 0, 1, 4, 6, 7);
        assertNeighbours(net, 4, 0, 1, 2, 3, 5, 6, 7, 8);
        assertNeighbours(net, 5, 1, 2, 4, 7, 8);
        assertNeighbours(net, 6, 3, 4, 7);
        assertNeighbours(net, 7, 3, 4, 5, 6, 8);
        assertNeighbours(net, 8, 4, 5, 7);
    }

    /**
     * 5x5 torus, Von Neumann: the neighbours of the first ring around neuron 12, excluding
     * the centre and the ring itself, form the expected second ring.
     */
    @Test
    public void testConcentricNeighbourhood() {
        final Network net = new NeuronSquareMesh2D(5, true, 5, true, SquareNeighbourhood.VON_NEUMANN,
                                                   new FeatureInitializer[] { init }).getNetwork();
        final Collection<Neuron> firstRing = assertNeighbours(net, 12, 7, 11, 13, 17);

        final Collection<Neuron> exclude = new HashSet<Neuron>();
        exclude.add(net.getNeuron(12));
        exclude.addAll(firstRing);

        final Collection<Neuron> secondRing = net.getNeighbours(firstRing, exclude);
        assertContainsExactly("second ring around 12", net, secondRing,
                              6, 8, 16, 18, 2, 10, 14, 22);
    }

    /**
     * 5x5 torus, Moore: the second ring around neuron 8 (excluding the centre and first
     * ring) covers the remaining 16 neurons.
     */
    @Test
    public void testConcentricNeighbourhood2() {
        final Network net = new NeuronSquareMesh2D(5, true, 5, true, SquareNeighbourhood.MOORE,
                                                   new FeatureInitializer[] { init }).getNetwork();
        final Collection<Neuron> firstRing = assertNeighbours(net, 8, 2, 3, 4, 7, 9, 12, 13, 14);

        final Collection<Neuron> exclude = new HashSet<Neuron>();
        exclude.add(net.getNeuron(8));
        exclude.addAll(firstRing);

        final Collection<Neuron> secondRing = net.getNeighbours(firstRing, exclude);
        assertContainsExactly("second ring around 8", net, secondRing,
                              1, 6, 11, 16, 17, 18, 19, 15, 10, 5, 0, 20, 24, 23, 22, 21);
    }

    /**
     * A serialized-then-deserialized mesh has, for every neuron, identical features and
     * the same set of neighbour identifiers.
     */
    @Test
    public void testSerialize() throws IOException, ClassNotFoundException {
        final NeuronSquareMesh2D out = new NeuronSquareMesh2D(4, false, 3, true, SquareNeighbourhood.VON_NEUMANN,
                                                              new FeatureInitializer[] { init });

        final ByteArrayOutputStream bos = new ByteArrayOutputStream();
        final ObjectOutputStream oos = new ObjectOutputStream(bos);
        try {
            oos.writeObject(out);
            // Make sure every serialized byte has reached bos before it is read back.
            oos.flush();
        } finally {
            oos.close();
        }

        final ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(bos.toByteArray()));
        final NeuronSquareMesh2D in;
        try {
            in = (NeuronSquareMesh2D) ois.readObject();
        } finally {
            ois.close();
        }

        final Network outNet = out.getNetwork();
        final Network inNet = in.getNetwork();
        for (Neuron nOut : outNet) {
            final Neuron nIn = inNet.getNeuron(nOut.getIdentifier());

            // Same features.
            final double[] outF = nOut.getFeatures();
            final double[] inF = nIn.getFeatures();
            Assert.assertEquals(outF.length, inF.length);
            for (int i = 0; i < outF.length; i++) {
                Assert.assertEquals(outF[i], inF[i], 0.0d);
            }

            // Same neighbours (compared by identifier, since instances differ).
            final Collection<Neuron> outNeighbours = outNet.getNeighbours(nOut);
            final Collection<Neuron> inNeighbours = inNet.getNeighbours(nIn);
            Assert.assertEquals(outNeighbours.size(), inNeighbours.size());
            for (Neuron n : outNeighbours) {
                Assert.assertTrue(inNeighbours.contains(inNet.getNeuron(n.getIdentifier())));
            }
        }
    }

    /**
     * Asserts that the neighbours of neuron {@code id} are exactly {@code expectedIds}.
     *
     * @param net network under test
     * @param id identifier of the neuron whose neighbours are checked
     * @param expectedIds identifiers of the expected neighbours (must be distinct)
     * @return the neighbours returned by the network, for further checks
     */
    private static Collection<Neuron> assertNeighbours(Network net, long id, long... expectedIds) {
        final Collection<Neuron> neighbours = net.getNeighbours(net.getNeuron(id));
        assertContainsExactly("id=" + id, net, neighbours, expectedIds);
        return neighbours;
    }

    /**
     * Asserts that {@code actual} contains the neurons with the given identifiers and
     * nothing else; the size check relies on {@code expectedIds} having no duplicates.
     *
     * @param message context prefix for failure messages
     * @param net network used to resolve identifiers to neurons
     * @param actual collection under test
     * @param expectedIds identifiers that must make up {@code actual}
     */
    private static void assertContainsExactly(String message, Network net,
                                              Collection<Neuron> actual, long... expectedIds) {
        for (long nId : expectedIds) {
            Assert.assertTrue(message + " nId=" + nId, actual.contains(net.getNeuron(nId)));
        }
        Assert.assertEquals(message, expectedIds.length, actual.size());
    }

    /**
     * Asserts that, in a network with identifiers {@code 0 .. numNeurons-1}, every neuron
     * is linked to every other neuron and to nothing else (in particular not to itself).
     *
     * @param net network under test
     * @param numNeurons number of neurons in {@code net}
     */
    private static void assertAllOthersAreNeighbours(Network net, long numNeurons) {
        for (long id = 0; id < numNeurons; id++) {
            final Collection<Neuron> neighbours = net.getNeighbours(net.getNeuron(id));
            for (long nId = 0; nId < numNeurons; nId++) {
                if (id != nId) {
                    Assert.assertTrue("id=" + id + " nId=" + nId, neighbours.contains(net.getNeuron(nId)));
                }
            }
            // All others are present, so a size of numNeurons - 1 rules out a self-link or extras.
            Assert.assertEquals("id=" + id, numNeurons - 1, neighbours.size());
        }
    }
}
