diff --git a/LICENSE.txt b/LICENSE.txt
index e7d66177b..5fba36d0f 100644
--- a/LICENSE.txt
+++ b/LICENSE.txt
@@ -434,3 +434,8 @@ ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
===============================================================================
+The initial commit of package "org.apache.commons.math3.ml.neuralnet" is
+an adapted version of code developed in the context of the Data Processing
+and Analysis Consortium (DPAC) of the "Gaia" project of the European Space
+Agency (ESA).
+===============================================================================
diff --git a/findbugs-exclude-filter.xml b/findbugs-exclude-filter.xml
index 4768bed66..a38621a55 100644
--- a/findbugs-exclude-filter.xml
+++ b/findbugs-exclude-filter.xml
@@ -363,4 +363,11 @@
+
+
+
+
+
+
+
diff --git a/src/changes/changes.xml b/src/changes/changes.xml
index 5391a82c1..8322c0d01 100644
--- a/src/changes/changes.xml
+++ b/src/changes/changes.xml
@@ -51,6 +51,10 @@ If the output is not quite correct, check for invisible trailing spaces!
+
+ Utilities for creating artificial neural networks (package "o.a.c.m.ml.neuralnet").
+ Implementation of Kohonen's Self-Organizing Feature Map (SOFM).
+
The cutOff mechanism of the "SimplexSolver" in package o.a.c.math3.optim.linear
could lead to invalid solutions. The mechanism has been improved in a way that
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/FeatureInitializer.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/FeatureInitializer.java
new file mode 100644
index 000000000..06b6fbd67
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/FeatureInitializer.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet;
+
+/**
+ * Defines how to assign the first value of a neuron's feature.
+ *
+ * @version $Id$
+ */
+public interface FeatureInitializer {
+ /**
+ * Selects the initial value.
+ *
+ * @return the initial value.
+ */
+ double value();
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/FeatureInitializerFactory.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/FeatureInitializerFactory.java
new file mode 100644
index 000000000..fc0e9a4d7
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/FeatureInitializerFactory.java
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet;
+
+import org.apache.commons.math3.distribution.RealDistribution;
+import org.apache.commons.math3.distribution.UniformRealDistribution;
+import org.apache.commons.math3.analysis.UnivariateFunction;
+import org.apache.commons.math3.analysis.function.Constant;
+
+/**
+ * Creates functions that will select the initial values of a neuron's
+ * features.
+ *
+ * @version $Id$
+ */
+public class FeatureInitializerFactory {
+ /** Class contains only static methods. */
+ private FeatureInitializerFactory() {}
+
+ /**
+ * Uniform sampling of the given range.
+ *
+ * @param min Lower bound of the range.
+ * @param max Upper bound of the range.
+ * @return an initializer such that the features will be initialized with
+ * values within the given range.
+ * @throws org.apache.commons.math3.exception.NumberIsTooLargeException
+ * if {@code min >= max}.
+ */
+ public static FeatureInitializer uniform(final double min,
+ final double max) {
+ return randomize(new UniformRealDistribution(min, max),
+ function(new Constant(0), 0, 0));
+ }
+
+ /**
+ * Creates an initializer from a univariate function {@code f(x)}.
+ * The argument {@code x} is set to {@code init} at the first call
+ * and will be incremented at each call.
+ *
+ * @param f Function.
+ * @param init Initial value.
+ * @param inc Increment
+ * @return the initializer.
+ */
+ public static FeatureInitializer function(final UnivariateFunction f,
+ final double init,
+ final double inc) {
+ return new FeatureInitializer() {
+ /** Argument. */
+ private double arg = init;
+
+ /** {@inheritDoc} */
+ public double value() {
+ final double result = f.value(arg);
+ arg += inc;
+ return result;
+ }
+ };
+ }
+
+ /**
+ * Adds some amount of random data to the given initializer.
+ *
+ * @param random Random variable distribution.
+ * @param orig Original initializer.
+ * @return an initializer whose {@link FeatureInitializer#value() value}
+ * method will return {@code orig.value() + random.sample()}.
+ */
+ public static FeatureInitializer randomize(final RealDistribution random,
+ final FeatureInitializer orig) {
+ return new FeatureInitializer() {
+ /** {@inheritDoc} */
+ public double value() {
+ return orig.value() + random.sample();
+ }
+ };
+ }
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/MapUtils.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/MapUtils.java
new file mode 100644
index 000000000..b01ef63cd
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/MapUtils.java
@@ -0,0 +1,247 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet;
+
+import java.util.HashMap;
+import java.util.Collection;
+import org.apache.commons.math3.ml.distance.DistanceMeasure;
+import org.apache.commons.math3.ml.neuralnet.twod.NeuronSquareMesh2D;
+import org.apache.commons.math3.exception.NoDataException;
+import org.apache.commons.math3.util.Pair;
+
+/**
+ * Utilities for network maps.
+ *
+ * @version $Id$
+ */
+public class MapUtils {
+ /**
+ * Class contains only static methods.
+ */
+ private MapUtils() {}
+
+    /**
+     * Finds the neuron that best matches the given features.
+     *
+     * @param features Data.
+     * @param neurons List of neurons to scan. If the list is empty
+     * {@code null} will be returned.
+     * @param distance Distance function. The neuron's features are
+     * passed as the first argument to {@link DistanceMeasure#compute(double[],double[])}.
+     * @return the neuron whose features are closest to the given data.
+     * @throws org.apache.commons.math3.exception.DimensionMismatchException
+     * if the size of the input is not compatible with the neurons features
+     * size.
+     */
+    public static Neuron findBest(double[] features,
+                                  Iterable<Neuron> neurons,
+                                  DistanceMeasure distance) {
+        Neuron best = null;
+        double min = Double.POSITIVE_INFINITY;
+        // Linear scan; ties are resolved in favour of the first neuron seen.
+        for (final Neuron n : neurons) {
+            final double d = distance.compute(n.getFeatures(), features);
+            if (d < min) {
+                min = d;
+                best = n;
+            }
+        }
+
+        return best;
+    }
+
+    /**
+     * Finds the two neurons that best match the given features.
+     *
+     * @param features Data.
+     * @param neurons List of neurons to scan. If the list is empty
+     * {@code null} will be returned.
+     * @param distance Distance function. The neuron's features are
+     * passed as the first argument to {@link DistanceMeasure#compute(double[],double[])}.
+     * @return the two neurons whose features are closest to the given data.
+     * @throws org.apache.commons.math3.exception.DimensionMismatchException
+     * if the size of the input is not compatible with the neurons features
+     * size.
+     */
+    public static Pair<Neuron, Neuron> findBestAndSecondBest(double[] features,
+                                                             Iterable<Neuron> neurons,
+                                                             DistanceMeasure distance) {
+        Neuron[] best = { null, null };
+        double[] min = { Double.POSITIVE_INFINITY,
+                         Double.POSITIVE_INFINITY };
+        for (final Neuron n : neurons) {
+            final double d = distance.compute(n.getFeatures(), features);
+            if (d < min[0]) {
+                // Replace second best with old best.
+                min[1] = min[0];
+                best[1] = best[0];
+
+                // Store current as new best.
+                min[0] = d;
+                best[0] = n;
+            } else if (d < min[1]) {
+                // Replace old second best with current.
+                min[1] = d;
+                best[1] = n;
+            }
+        }
+
+        return new Pair<Neuron, Neuron>(best[0], best[1]);
+    }
+
+    /**
+     * Computes the U-matrix (unified distance matrix) of a two-dimensional
+     * map: each entry is the average distance from a neuron to its
+     * neighbours.
+     *
+     * @param map Network.
+     * @param distance Function to use for computing the average
+     * distance from a neuron to its neighbours.
+     * @return the matrix of average distances.
+     */
+    public static double[][] computeU(NeuronSquareMesh2D map,
+                                      DistanceMeasure distance) {
+        final int numRows = map.getNumberOfRows();
+        final int numCols = map.getNumberOfColumns();
+        final double[][] uMatrix = new double[numRows][numCols];
+
+        final Network net = map.getNetwork();
+
+        for (int i = 0; i < numRows; i++) {
+            for (int j = 0; j < numCols; j++) {
+                final Neuron neuron = map.getNeuron(i, j);
+                final Collection<Neuron> neighbours = net.getNeighbours(neuron);
+                final double[] features = neuron.getFeatures();
+
+                double d = 0;
+                int count = 0;
+                for (Neuron n : neighbours) {
+                    ++count;
+                    d += distance.compute(features, n.getFeatures());
+                }
+
+                // NOTE(review): assumes every neuron has at least one
+                // neighbour; an isolated neuron would yield NaN here.
+                uMatrix[i][j] = d / count;
+            }
+        }
+
+        return uMatrix;
+    }
+
+    /**
+     * Computes the "hit" histogram of a two-dimensional map.
+     * A neuron's "hit" count is the number of input vectors for which it
+     * is the best matching unit.
+     *
+     * @param data Feature vectors.
+     * @param map Network.
+     * @param distance Function to use for determining the best matching unit.
+     * @return the number of hits for each neuron in the map.
+     */
+    public static int[][] computeHitHistogram(Iterable<double[]> data,
+                                              NeuronSquareMesh2D map,
+                                              DistanceMeasure distance) {
+        final HashMap<Neuron, Integer> hit = new HashMap<Neuron, Integer>();
+        final Network net = map.getNetwork();
+
+        for (double[] f : data) {
+            final Neuron best = findBest(f, net, distance);
+            final Integer count = hit.get(best);
+            if (count == null) {
+                hit.put(best, 1);
+            } else {
+                hit.put(best, count + 1);
+            }
+        }
+
+        // Copy the histogram data into a 2D map.
+        final int numRows = map.getNumberOfRows();
+        final int numCols = map.getNumberOfColumns();
+        final int[][] histo = new int[numRows][numCols];
+
+        for (int i = 0; i < numRows; i++) {
+            for (int j = 0; j < numCols; j++) {
+                final Neuron neuron = map.getNeuron(i, j);
+                final Integer count = hit.get(neuron);
+                if (count == null) {
+                    histo[i][j] = 0;
+                } else {
+                    histo[i][j] = count;
+                }
+            }
+        }
+
+        return histo;
+    }
+
+    /**
+     * Computes the quantization error.
+     * The quantization error is the average distance between a feature vector
+     * and its "best matching unit" (closest neuron).
+     *
+     * @param data Feature vectors.
+     * @param neurons List of neurons to scan.
+     * @param distance Distance function.
+     * @return the error.
+     * @throws NoDataException if {@code data} is empty.
+     */
+    public static double computeQuantizationError(Iterable<double[]> data,
+                                                  Iterable<Neuron> neurons,
+                                                  DistanceMeasure distance) {
+        double d = 0;
+        int count = 0;
+        for (double[] f : data) {
+            ++count;
+            d += distance.compute(f, findBest(f, neurons, distance).getFeatures());
+        }
+
+        if (count == 0) {
+            throw new NoDataException();
+        }
+
+        return d / count;
+    }
+
+    /**
+     * Computes the topographic error.
+     * The topographic error is the proportion of data for which first and
+     * second best matching units are not adjacent in the map.
+     *
+     * @param data Feature vectors.
+     * @param net Network.
+     * @param distance Distance function.
+     * @return the error.
+     * @throws NoDataException if {@code data} is empty.
+     */
+    public static double computeTopographicError(Iterable<double[]> data,
+                                                 Network net,
+                                                 DistanceMeasure distance) {
+        int notAdjacentCount = 0;
+        int count = 0;
+        for (double[] f : data) {
+            ++count;
+            final Pair<Neuron, Neuron> p = findBestAndSecondBest(f, net, distance);
+            if (!net.getNeighbours(p.getFirst()).contains(p.getSecond())) {
+                // Increment count if first and second best matching units
+                // are not neighbours.
+                ++notAdjacentCount;
+            }
+        }
+
+        if (count == 0) {
+            throw new NoDataException();
+        }
+
+        return ((double) notAdjacentCount) / count;
+    }
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/Network.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/Network.java
new file mode 100644
index 000000000..f4cf32586
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/Network.java
@@ -0,0 +1,476 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet;
+
+import java.io.Serializable;
+import java.io.ObjectInputStream;
+import java.util.NoSuchElementException;
+import java.util.List;
+import java.util.ArrayList;
+import java.util.Set;
+import java.util.HashSet;
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.Comparator;
+import java.util.Collections;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicLong;
+import org.apache.commons.math3.exception.DimensionMismatchException;
+import org.apache.commons.math3.exception.MathIllegalStateException;
+
+/**
+ * Neural network, composed of {@link Neuron} instances and the links
+ * between them.
+ *
+ * Although updating a neuron's state is thread-safe, modifying the
+ * network's topology (adding or removing links) is not.
+ *
+ * @version $Id$
+ */
+public class Network
+    implements Iterable<Neuron>,
+               Serializable {
+    /** Serializable. */
+    private static final long serialVersionUID = 20130207L;
+    /** Neurons, keyed by their unique identifier. */
+    private final ConcurrentHashMap<Long, Neuron> neuronMap
+        = new ConcurrentHashMap<Long, Neuron>();
+    /** Next available neuron identifier. */
+    private final AtomicLong nextId;
+    /** Neuron's features set size. */
+    private final int featureSize;
+    /** Links: for each neuron id, the set of ids it links to. */
+    private final ConcurrentHashMap<Long, Set<Long>> linkMap
+        = new ConcurrentHashMap<Long, Set<Long>>();
+
+    /**
+     * Comparator that prescribes an order of the neurons according
+     * to the increasing order of their identifier.
+     */
+    public static class NeuronIdentifierComparator
+        implements Comparator<Neuron>,
+                   Serializable {
+        /** Version identifier. */
+        private static final long serialVersionUID = 20130207L;
+
+        /** {@inheritDoc} */
+        @Override
+        public int compare(Neuron a,
+                           Neuron b) {
+            final long aId = a.getIdentifier();
+            final long bId = b.getIdentifier();
+            return aId < bId ? -1 :
+                aId > bId ? 1 : 0;
+        }
+    }
+
+    /**
+     * Constructor with restricted access, solely used for deserialization.
+     *
+     * @param nextId Next available identifier.
+     * @param featureSize Number of features.
+     * @param neuronList Neurons.
+     * @param neighbourIdList Links associated to each of the neurons in
+     * {@code neuronList}.
+     * @throws MathIllegalStateException if an inconsistency is detected
+     * (which probably means that the serialized form has been corrupted).
+     */
+    Network(long nextId,
+            int featureSize,
+            Neuron[] neuronList,
+            long[][] neighbourIdList) {
+        final int numNeurons = neuronList.length;
+        if (numNeurons != neighbourIdList.length) {
+            throw new MathIllegalStateException();
+        }
+
+        // First pass: register all neurons so that links can be validated.
+        for (int i = 0; i < numNeurons; i++) {
+            final Neuron n = neuronList[i];
+            final long id = n.getIdentifier();
+            if (id >= nextId) {
+                throw new MathIllegalStateException();
+            }
+            neuronMap.put(id, n);
+            linkMap.put(id, new HashSet<Long>());
+        }
+
+        // Second pass: restore the links, checking that every target exists.
+        for (int i = 0; i < numNeurons; i++) {
+            final long aId = neuronList[i].getIdentifier();
+            final Set<Long> aLinks = linkMap.get(aId);
+            for (long bId : neighbourIdList[i]) {
+                if (neuronMap.get(bId) == null) {
+                    throw new MathIllegalStateException();
+                }
+                addLinkToLinkSet(aLinks, bId);
+            }
+        }
+
+        this.nextId = new AtomicLong(nextId);
+        this.featureSize = featureSize;
+    }
+
+ /**
+ * @param initialIdentifier Identifier for the first neuron that
+ * will be added to this network.
+ * @param featureSize Size of the neuron's features.
+ */
+ public Network(long initialIdentifier,
+ int featureSize) {
+ nextId = new AtomicLong(initialIdentifier);
+ this.featureSize = featureSize;
+ }
+
+    /**
+     * {@inheritDoc}
+     */
+    @Override
+    public Iterator<Neuron> iterator() {
+        return neuronMap.values().iterator();
+    }
+
+    /**
+     * Creates a list of the neurons, sorted in a custom order.
+     *
+     * @param comparator {@link Comparator} used for sorting the neurons.
+     * @return a list of neurons, sorted in the order prescribed by the
+     * given {@code comparator}.
+     * @see NeuronIdentifierComparator
+     */
+    public Collection<Neuron> getNeurons(Comparator<Neuron> comparator) {
+        final List<Neuron> neurons = new ArrayList<Neuron>();
+        neurons.addAll(neuronMap.values());
+
+        Collections.sort(neurons, comparator);
+
+        return neurons;
+    }
+
+    /**
+     * Creates a neuron and assigns it a unique identifier.
+     *
+     * @param features Initial values for the neuron's features.
+     * @return the neuron's identifier.
+     * @throws DimensionMismatchException if the length of {@code features}
+     * is different from the expected size (as set by the
+     * {@link #Network(long,int) constructor}).
+     */
+    public long createNeuron(double[] features) {
+        if (features.length != featureSize) {
+            throw new DimensionMismatchException(features.length, featureSize);
+        }
+
+        final long id = createNextId();
+        neuronMap.put(id, new Neuron(id, features));
+        linkMap.put(id, new HashSet<Long>());
+        return id;
+    }
+
+    /**
+     * Deletes a neuron.
+     * Links from all neighbours to the removed neuron will also be
+     * {@link #deleteLink(Neuron,Neuron) deleted}.
+     *
+     * @param neuron Neuron to be removed from this network.
+     * @throws NoSuchElementException if {@code neuron} does not belong to
+     * this network.
+     */
+    public void deleteNeuron(Neuron neuron) {
+        final Collection<Neuron> neighbours = getNeighbours(neuron);
+
+        // Delete the links from each neighbour to the removed neuron.
+        for (Neuron n : neighbours) {
+            deleteLink(n, neuron);
+        }
+
+        // Remove neuron.
+        neuronMap.remove(neuron.getIdentifier());
+    }
+
+ /**
+ * Gets the size of the neurons' features set.
+ *
+ * @return the size of the features set.
+ */
+ public int getFeaturesSize() {
+ return featureSize;
+ }
+
+ /**
+ * Adds a link from neuron {@code a} to neuron {@code b}.
+ * Note: the link is not bi-directional; if a bi-directional link is
+ * required, an additional call must be made with {@code a} and
+ * {@code b} exchanged in the argument list.
+ *
+ * @param a Neuron.
+ * @param b Neuron.
+ * @throws NoSuchElementException if the neurons do not exist in the
+ * network.
+ */
+ public void addLink(Neuron a,
+ Neuron b) {
+ final long aId = a.getIdentifier();
+ final long bId = b.getIdentifier();
+
+ // Check that the neurons belong to this network.
+ if (a != getNeuron(aId)) {
+ throw new NoSuchElementException(Long.toString(aId));
+ }
+ if (b != getNeuron(bId)) {
+ throw new NoSuchElementException(Long.toString(bId));
+ }
+
+ // Add link from "a" to "b".
+ addLinkToLinkSet(linkMap.get(aId), bId);
+ }
+
+    /**
+     * Adds a link to neuron {@code id} in given {@code linkSet}.
+     * Note: no check verifies that the identifier indeed belongs
+     * to this network.
+     *
+     * @param linkSet Set of the neuron's outgoing links.
+     * @param id Neuron identifier.
+     */
+    private void addLinkToLinkSet(Set<Long> linkSet,
+                                  long id) {
+        linkSet.add(id);
+    }
+
+ /**
+ * Deletes the link between neurons {@code a} and {@code b}.
+ *
+ * @param a Neuron.
+ * @param b Neuron.
+ * @throws NoSuchElementException if the neurons do not exist in the
+ * network.
+ */
+ public void deleteLink(Neuron a,
+ Neuron b) {
+ final long aId = a.getIdentifier();
+ final long bId = b.getIdentifier();
+
+ // Check that the neurons belong to this network.
+ if (a != getNeuron(aId)) {
+ throw new NoSuchElementException(Long.toString(aId));
+ }
+ if (b != getNeuron(bId)) {
+ throw new NoSuchElementException(Long.toString(bId));
+ }
+
+ // Delete link from "a" to "b".
+ deleteLinkFromLinkSet(linkMap.get(aId), bId);
+ }
+
+    /**
+     * Deletes a link to neuron {@code id} in given {@code linkSet}.
+     * Note: no check verifies that the identifier indeed belongs
+     * to this network.
+     *
+     * @param linkSet Set of the neuron's outgoing links.
+     * @param id Neuron identifier.
+     */
+    private void deleteLinkFromLinkSet(Set<Long> linkSet,
+                                       long id) {
+        linkSet.remove(id);
+    }
+
+ /**
+ * Retrieves the neuron with the given (unique) {@code id}.
+ *
+ * @param id Identifier.
+ * @return the neuron associated with the given {@code id}.
+ * @throws NoSuchElementException if the neuron does not exist in the
+ * network.
+ */
+ public Neuron getNeuron(long id) {
+ final Neuron n = neuronMap.get(id);
+ if (n == null) {
+ throw new NoSuchElementException(Long.toString(id));
+ }
+ return n;
+ }
+
+    /**
+     * Retrieves the neurons in the neighbourhood of any neuron in the
+     * {@code neurons} list.
+     * @param neurons Neurons for which to retrieve the neighbours.
+     * @return the list of neighbours.
+     * @see #getNeighbours(Iterable,Iterable)
+     */
+    public Collection<Neuron> getNeighbours(Iterable<Neuron> neurons) {
+        return getNeighbours(neurons, null);
+    }
+
+    /**
+     * Retrieves the neurons in the neighbourhood of any neuron in the
+     * {@code neurons} list.
+     * The {@code exclude} list allows to retrieve the "concentric"
+     * neighbourhoods by removing the neurons that belong to the inner
+     * "circles".
+     *
+     * @param neurons Neurons for which to retrieve the neighbours.
+     * @param exclude Neurons to exclude from the returned list.
+     * Can be {@code null}.
+     * @return the list of neighbours.
+     */
+    public Collection<Neuron> getNeighbours(Iterable<Neuron> neurons,
+                                            Iterable<Neuron> exclude) {
+        final Set<Long> idList = new HashSet<Long>();
+
+        for (Neuron n : neurons) {
+            idList.addAll(linkMap.get(n.getIdentifier()));
+        }
+        if (exclude != null) {
+            for (Neuron n : exclude) {
+                idList.remove(n.getIdentifier());
+            }
+        }
+
+        final List<Neuron> neuronList = new ArrayList<Neuron>();
+        for (Long id : idList) {
+            neuronList.add(getNeuron(id));
+        }
+
+        return neuronList;
+    }
+
+    /**
+     * Retrieves the neighbours of the given neuron.
+     *
+     * @param neuron Neuron for which to retrieve the neighbours.
+     * @return the list of neighbours.
+     * @see #getNeighbours(Neuron,Iterable)
+     */
+    public Collection<Neuron> getNeighbours(Neuron neuron) {
+        return getNeighbours(neuron, null);
+    }
+
+    /**
+     * Retrieves the neighbours of the given neuron.
+     *
+     * @param neuron Neuron for which to retrieve the neighbours.
+     * @param exclude Neurons to exclude from the returned list.
+     * Can be {@code null}.
+     * @return the list of neighbours.
+     */
+    public Collection<Neuron> getNeighbours(Neuron neuron,
+                                            Iterable<Neuron> exclude) {
+        // Defensive copy: the set stored in "linkMap" must not be modified
+        // when applying the "exclude" filter, otherwise the network's
+        // topology would be silently corrupted.
+        final Set<Long> idList
+            = new HashSet<Long>(linkMap.get(neuron.getIdentifier()));
+        if (exclude != null) {
+            for (Neuron n : exclude) {
+                idList.remove(n.getIdentifier());
+            }
+        }
+
+        final List<Neuron> neuronList = new ArrayList<Neuron>();
+        for (Long id : idList) {
+            neuronList.add(getNeuron(id));
+        }
+
+        return neuronList;
+    }
+
+ /**
+ * Creates a neuron identifier.
+ *
+ * @return a value that will serve as a unique identifier.
+ */
+ private Long createNextId() {
+ return nextId.getAndIncrement();
+ }
+
+ /**
+ * Prevents proxy bypass.
+ *
+ * @param in Input stream.
+ */
+ private void readObject(ObjectInputStream in) {
+ throw new IllegalStateException();
+ }
+
+    /**
+     * Custom serialization.
+     *
+     * @return the proxy instance that will be actually serialized.
+     */
+    private Object writeReplace() {
+        final Neuron[] neuronList = neuronMap.values().toArray(new Neuron[0]);
+        final long[][] neighbourIdList = new long[neuronList.length][];
+
+        // Flatten the link map into parallel arrays of neighbour ids.
+        for (int i = 0; i < neuronList.length; i++) {
+            final Collection<Neuron> neighbours = getNeighbours(neuronList[i]);
+            final long[] neighboursId = new long[neighbours.size()];
+            int count = 0;
+            for (Neuron n : neighbours) {
+                neighboursId[count] = n.getIdentifier();
+                ++count;
+            }
+            neighbourIdList[i] = neighboursId;
+        }
+
+        return new SerializationProxy(nextId.get(),
+                                      featureSize,
+                                      neuronList,
+                                      neighbourIdList);
+    }
+
+ /**
+ * Serialization.
+ */
+ private static class SerializationProxy implements Serializable {
+ /** Serializable. */
+ private static final long serialVersionUID = 20130207L;
+ /** Next identifier. */
+ private final long nextId;
+ /** Number of features. */
+ private final int featureSize;
+ /** Neurons. */
+ private final Neuron[] neuronList;
+ /** Links. */
+ private final long[][] neighbourIdList;
+
+ /**
+ * @param nextId Next available identifier.
+ * @param featureSize Number of features.
+ * @param neuronList Neurons.
+ * @param neighbourIdList Links associated to each of the neurons in
+ * {@code neuronList}.
+ */
+ SerializationProxy(long nextId,
+ int featureSize,
+ Neuron[] neuronList,
+ long[][] neighbourIdList) {
+ this.nextId = nextId;
+ this.featureSize = featureSize;
+ this.neuronList = neuronList;
+ this.neighbourIdList = neighbourIdList;
+ }
+
+ /**
+ * Custom serialization.
+ *
+ * @return the {@link Network} for which this instance is the proxy.
+ */
+ private Object readResolve() {
+ return new Network(nextId,
+ featureSize,
+ neuronList,
+ neighbourIdList);
+ }
+ }
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/Neuron.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/Neuron.java
new file mode 100644
index 000000000..a64221a0f
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/Neuron.java
@@ -0,0 +1,215 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet;
+
+import java.io.Serializable;
+import java.io.ObjectInputStream;
+import java.util.concurrent.atomic.AtomicReference;
+import org.apache.commons.math3.util.Precision;
+import org.apache.commons.math3.exception.DimensionMismatchException;
+
+
+/**
+ * Describes a neuron element of a neural network.
+ *
+ * This class aims to be thread-safe.
+ *
+ * @version $Id$
+ */
+public class Neuron implements Serializable {
+    /** Serializable. */
+    private static final long serialVersionUID = 20130207L;
+    /** Identifier. */
+    private final long identifier;
+    /** Length of the feature set. */
+    private final int size;
+    /** Neuron data (atomically swappable for thread-safe updates). */
+    private final AtomicReference<double[]> features;
+
+    /**
+     * Creates a neuron.
+     * The size of the feature set is fixed to the length of the given
+     * argument.
+     *
+     * Constructor is package-private: Neurons must be
+     * {@link Network#createNeuron(double[]) created} by the network
+     * instance to which they will belong.
+     *
+     * @param identifier Identifier (assigned by the {@link Network}).
+     * @param features Initial values of the feature set.
+     */
+    Neuron(long identifier,
+           double[] features) {
+        this.identifier = identifier;
+        this.size = features.length;
+        // Defensive copy: the caller's array must not alias internal state.
+        this.features = new AtomicReference<double[]>(features.clone());
+    }
+
+ /**
+ * Gets the neuron's identifier.
+ *
+ * @return the identifier.
+ */
+ public long getIdentifier() {
+ return identifier;
+ }
+
+ /**
+ * Gets the length of the feature set.
+ *
+ * @return the number of features.
+ */
+ public int getSize() {
+ return size;
+ }
+
+ /**
+ * Gets the neuron's features.
+ *
+ * @return a copy of the neuron's features.
+ */
+ public double[] getFeatures() {
+ return features.get().clone();
+ }
+
+ /**
+ * Tries to atomically update the neuron's features.
+ * Update will be performed only if the expected values match the
+ * current values.
+ * In effect, when concurrent threads call this method, the state
+ * could be modified by one, so that it does not correspond to the
+ * the state assumed by another.
+ * Typically, a caller {@link #getFeatures() retrieves the current state},
+ * and uses it to compute the new state.
+ * During this computation, another thread might have done the same
+ * thing, and updated the state: If the current thread were to proceed
+ * with its own update, it would overwrite the new state (which might
+ * already have been used by yet other threads).
+ * To prevent this, the method does not perform the update when a
+ * concurrent modification has been detected, and returns {@code false}.
+ * When this happens, the caller should fetch the new current state,
+ * redo its computation, and call this method again.
+ *
+ * @param expect Current values of the features, as assumed by the caller.
+ * Update will never succeed if the contents of this array does not match
+ * the values returned by {@link #getFeatures()}.
+ * @param update Features's new values.
+ * @return {@code true} if the update was successful, {@code false}
+ * otherwise.
+ * @throws DimensionMismatchException if the length of {@code update} is
+ * not the same as specified in the {@link #Neuron(long,double[])
+ * constructor}.
+ */
+ public boolean compareAndSetFeatures(double[] expect,
+ double[] update) {
+ if (update.length != size) {
+ throw new DimensionMismatchException(update.length, size);
+ }
+
+ // Get the internal reference. Note that this must not be a copy;
+ // otherwise the "compareAndSet" below will always fail.
+ final double[] current = features.get();
+ if (!containSameValues(current, expect)) {
+ // Some other thread already modified the state.
+ return false;
+ }
+
+ if (features.compareAndSet(current, update.clone())) {
+ // The current thread could atomically update the state.
+ return true;
+ } else {
+ // Some other thread came first.
+ return false;
+ }
+ }
+
+ /**
+ * Checks whether the contents of both arrays is the same.
+ *
+ * @param current Current values.
+ * @param expect Expected values.
+ * @throws DimensionMismatchException if the length of {@code expected}
+ * is not the same as specified in the {@link #Neuron(long,double[])
+ * constructor}.
+ * @return {@code true} if the arrays contain the same values.
+ */
+ private boolean containSameValues(double[] current,
+ double[] expect) {
+ if (expect.length != size) {
+ throw new DimensionMismatchException(expect.length, size);
+ }
+
+ for (int i = 0; i < size; i++) {
+ if (!Precision.equals(current[i], expect[i])) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ /**
+ * Prevents proxy bypass.
+ *
+ * @param in Input stream.
+ */
+ private void readObject(ObjectInputStream in) {
+ throw new IllegalStateException();
+ }
+
+ /**
+ * Custom serialization.
+ *
+ * @return the proxy instance that will be actually serialized.
+ */
+ private Object writeReplace() {
+ return new SerializationProxy(identifier,
+ features.get());
+ }
+
+ /**
+ * Serialization.
+ */
+ private static class SerializationProxy implements Serializable {
+ /** Serializable. */
+ private static final long serialVersionUID = 20130207L;
+ /** Features. */
+ private final double[] features;
+ /** Identifier. */
+ private final long identifier;
+
+ /**
+ * @param identifier Identifier.
+ * @param features Features.
+ */
+ SerializationProxy(long identifier,
+ double[] features) {
+ this.identifier = identifier;
+ this.features = features;
+ }
+
+ /**
+ * Custom serialization.
+ *
+ * @return the {@link Neuron} for which this instance is the proxy.
+ */
+ private Object readResolve() {
+ return new Neuron(identifier,
+ features);
+ }
+ }
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/SquareNeighbourhood.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/SquareNeighbourhood.java
new file mode 100644
index 000000000..aa5b00ebc
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/SquareNeighbourhood.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet;
+
+/**
+ * Defines neighbourhood types.
+ *
+ * @version $Id$
+ */
+public enum SquareNeighbourhood {
+    /**
+     * Von Neumann neighbourhood: in two dimensions, each (internal)
+     * neuron has four neighbours.
+     */
+    VON_NEUMANN,
+    /**
+     * Moore neighbourhood: in two dimensions, each (internal)
+     * neuron has eight neighbours.
+     */
+    MOORE,
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/UpdateAction.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/UpdateAction.java
new file mode 100644
index 000000000..eca9f28a9
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/UpdateAction.java
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet;
+
+/**
+ * Describes how to update the network in response to a training
+ * sample.
+ *
+ * @version $Id$
+ */
+public interface UpdateAction {
+    /**
+     * Updates the network in response to the sample {@code features}.
+     * Implementations define the whole learning step: which neuron(s)
+     * to modify and how their features are changed.
+     *
+     * @param net Network.
+     * @param features Training data.
+     */
+    void update(Network net, double[] features);
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/oned/NeuronString.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/oned/NeuronString.java
new file mode 100644
index 000000000..f904c523d
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/oned/NeuronString.java
@@ -0,0 +1,236 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet.oned;
+
+import java.io.Serializable;
+import java.io.ObjectInputStream;
+import org.apache.commons.math3.ml.neuralnet.Network;
+import org.apache.commons.math3.ml.neuralnet.FeatureInitializer;
+import org.apache.commons.math3.exception.NumberIsTooSmallException;
+import org.apache.commons.math3.exception.OutOfRangeException;
+
+/**
+ * Neural network with the topology of a one-dimensional line.
+ * Each neuron defines one point on the line.
+ *
+ * @version $Id$
+ */
+public class NeuronString implements Serializable {
+    /** Underlying network. */
+    private final Network network;
+    /** Number of neurons. */
+    private final int size;
+    /** Whether the line is wrapped into a circle. */
+    private final boolean wrap;
+
+    /**
+     * Mapping of the 1D coordinate to the neuron identifiers
+     * (attributed by the {@link #network} instance).
+     */
+    private final long[] identifiers;
+
+    /**
+     * Constructor with restricted access, solely used for deserialization.
+     *
+     * @param wrap Whether to wrap the dimension (i.e the first and last
+     * neurons will be linked together).
+     * @param featuresList Arrays that will initialize the features sets of
+     * the network's neurons.
+     * @throws NumberIsTooSmallException if {@code featuresList.length < 2}.
+     */
+    NeuronString(boolean wrap,
+                 double[][] featuresList) {
+        size = featuresList.length;
+
+        if (size < 2) {
+            throw new NumberIsTooSmallException(size, 2, true);
+        }
+
+        this.wrap = wrap;
+
+        final int fLen = featuresList[0].length;
+        network = new Network(0, fLen);
+        identifiers = new long[size];
+
+        // Add neurons.
+        for (int i = 0; i < size; i++) {
+            identifiers[i] = network.createNeuron(featuresList[i]);
+        }
+
+        // Add links.
+        createLinks();
+    }
+
+    /**
+     * Creates a one-dimensional network:
+     * Each neuron not located on the border of the mesh has two
+     * neurons linked to it.
+     *
+     * The links are bi-directional.
+     * Neurons created successively are neighbours (i.e. there are
+     * links between them).
+     *
+     * The topology of the network can also be a circle (if the
+     * dimension is wrapped).
+     *
+     * @param num Number of neurons.
+     * @param wrap Whether to wrap the dimension (i.e the first and last
+     * neurons will be linked together).
+     * @param featureInit Arrays that will initialize the features sets of
+     * the network's neurons.
+     * @throws NumberIsTooSmallException if {@code num < 2}.
+     */
+    public NeuronString(int num,
+                        boolean wrap,
+                        FeatureInitializer[] featureInit) {
+        if (num < 2) {
+            throw new NumberIsTooSmallException(num, 2, true);
+        }
+
+        size = num;
+        this.wrap = wrap;
+        identifiers = new long[num];
+
+        final int fLen = featureInit.length;
+        network = new Network(0, fLen);
+
+        // Add neurons, drawing each feature value from its initializer.
+        for (int i = 0; i < num; i++) {
+            final double[] features = new double[fLen];
+            for (int fIndex = 0; fIndex < fLen; fIndex++) {
+                features[fIndex] = featureInit[fIndex].value();
+            }
+            identifiers[i] = network.createNeuron(features);
+        }
+
+        // Add links.
+        createLinks();
+    }
+
+    /**
+     * Retrieves the underlying network.
+     * A reference is returned (enabling, for example, the network to be
+     * trained).
+     * This also implies that calling methods that modify the {@link Network}
+     * topology may cause this class to become inconsistent.
+     *
+     * @return the network.
+     */
+    public Network getNetwork() {
+        return network;
+    }
+
+    /**
+     * Gets the number of neurons.
+     *
+     * @return the number of neurons.
+     */
+    public int getSize() {
+        return size;
+    }
+
+    /**
+     * Retrieves the features set from the neuron at location
+     * {@code i} in the map.
+     *
+     * @param i Neuron index.
+     * @return the features of the neuron at index {@code i}.
+     * @throws OutOfRangeException if {@code i} is out of range.
+     */
+    public double[] getFeatures(int i) {
+        if (i < 0 ||
+            i >= size) {
+            throw new OutOfRangeException(i, 0, size - 1);
+        }
+
+        return network.getNeuron(identifiers[i]).getFeatures();
+    }
+
+    /**
+     * Creates the neighbour relationships between neurons.
+     * Links are registered in both directions; when {@code wrap} is
+     * {@code true} the two ends of the line are also linked together.
+     */
+    private void createLinks() {
+        // Use the stored identifiers instead of assuming that the network
+        // attributed the ids 0 ... size-1: this stays correct even if the
+        // id scheme of the Network class changes.
+        for (int i = 0; i < size - 1; i++) {
+            network.addLink(network.getNeuron(identifiers[i]),
+                            network.getNeuron(identifiers[i + 1]));
+        }
+        for (int i = size - 1; i > 0; i--) {
+            network.addLink(network.getNeuron(identifiers[i]),
+                            network.getNeuron(identifiers[i - 1]));
+        }
+        if (wrap) {
+            network.addLink(network.getNeuron(identifiers[0]),
+                            network.getNeuron(identifiers[size - 1]));
+            network.addLink(network.getNeuron(identifiers[size - 1]),
+                            network.getNeuron(identifiers[0]));
+        }
+    }
+
+    /**
+     * Prevents proxy bypass: instances must be serialized through the
+     * proxy returned by {@link #writeReplace()}.
+     *
+     * @param in Input stream.
+     */
+    private void readObject(ObjectInputStream in) {
+        throw new IllegalStateException();
+    }
+
+    /**
+     * Custom serialization.
+     *
+     * @return the proxy instance that will be actually serialized.
+     */
+    private Object writeReplace() {
+        final double[][] featuresList = new double[size][];
+        for (int i = 0; i < size; i++) {
+            featuresList[i] = getFeatures(i);
+        }
+
+        return new SerializationProxy(wrap,
+                                      featuresList);
+    }
+
+    /**
+     * Serialization proxy.
+     */
+    private static class SerializationProxy implements Serializable {
+        /** Serializable. */
+        private static final long serialVersionUID = 20130226L;
+        /** Wrap. */
+        private final boolean wrap;
+        /** Neurons' features. */
+        private final double[][] featuresList;
+
+        /**
+         * @param wrap Whether the dimension is wrapped.
+         * @param featuresList List of neurons features.
+         */
+        SerializationProxy(boolean wrap,
+                           double[][] featuresList) {
+            this.wrap = wrap;
+            this.featuresList = featuresList;
+        }
+
+        /**
+         * Custom serialization.
+         *
+         * @return the {@link NeuronString} for which this instance is
+         * the proxy.
+         */
+        private Object readResolve() {
+            return new NeuronString(wrap,
+                                    featuresList);
+        }
+    }
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/oned/package-info.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/oned/package-info.java
new file mode 100644
index 000000000..0b47fae99
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/oned/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * One-dimensional neural networks.
+ */
+
+package org.apache.commons.math3.ml.neuralnet.oned;
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/package-info.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/package-info.java
new file mode 100644
index 000000000..d8e907e77
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Neural networks.
+ */
+
+package org.apache.commons.math3.ml.neuralnet;
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/KohonenTrainingTask.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/KohonenTrainingTask.java
new file mode 100644
index 000000000..bfacbc63a
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/KohonenTrainingTask.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet.sofm;
+
+import java.util.Iterator;
+import org.apache.commons.math3.ml.neuralnet.Network;
+
+/**
+ * Trainer for Kohonen's Self-Organizing Map.
+ *
+ * @version $Id$
+ */
+public class KohonenTrainingTask implements Runnable {
+ /** SOFM to be trained. */
+ private final Network net;
+ /** Training data. */
+ private final Iterator featuresIterator;
+ /** Update procedure. */
+ private final KohonenUpdateAction updateAction;
+
+ /**
+ * Creates a (sequential) trainer for the given network.
+ *
+ * @param net Network to be trained with the SOFM algorithm.
+ * @param featuresIterator Training data iterator.
+ * @param updateAction SOFM update procedure.
+ */
+ public KohonenTrainingTask(Network net,
+ Iterator featuresIterator,
+ KohonenUpdateAction updateAction) {
+ this.net = net;
+ this.featuresIterator = featuresIterator;
+ this.updateAction = updateAction;
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ public void run() {
+ while (featuresIterator.hasNext()) {
+ updateAction.update(net, featuresIterator.next());
+ }
+ }
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/KohonenUpdateAction.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/KohonenUpdateAction.java
new file mode 100644
index 000000000..3831e8a85
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/KohonenUpdateAction.java
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet.sofm;
+
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.concurrent.atomic.AtomicLong;
+import org.apache.commons.math3.ml.neuralnet.Network;
+import org.apache.commons.math3.ml.neuralnet.MapUtils;
+import org.apache.commons.math3.ml.neuralnet.Neuron;
+import org.apache.commons.math3.ml.neuralnet.UpdateAction;
+import org.apache.commons.math3.ml.distance.DistanceMeasure;
+import org.apache.commons.math3.linear.ArrayRealVector;
+import org.apache.commons.math3.analysis.function.Gaussian;
+
+/**
+ * Update formula for
+ * Kohonen's Self-Organizing Map.
+ *
+ * The {@link #update(Network,double[]) update} method modifies the
+ * features {@code w} of the "winning" neuron and its neighbours
+ * according to the following rule:
+ *
+ *  w(new) = w(old) + alpha * e^(-d / sigma) * (sample - w(old))
+ *
+ * where
+ *
+ *  - alpha is the current learning rate,
+ *  - sigma is the current neighbourhood size, and
+ *  - {@code d} is the number of links to traverse in order to reach
+ *    the neuron from the winning neuron.
+ *
+ * This class is thread-safe as long as the arguments passed to the
+ * {@link #KohonenUpdateAction(DistanceMeasure,LearningFactorFunction,
+ * NeighbourhoodSizeFunction) constructor} are instances of thread-safe
+ * classes.
+ *
+ * Each call to the {@link #update(Network,double[]) update} method
+ * will increment the internal counter used to compute the current
+ * values for
+ *
+ *  - the learning rate, and
+ *  - the neighbourhood size.
+ *
+ * Consequently, the function instances that compute those values (passed
+ * to the constructor of this class) must take into account whether this
+ * class's instance will be shared by multiple threads, as this will impact
+ * the training process.
+ *
+ * @version $Id$
+ */
+public class KohonenUpdateAction implements UpdateAction {
+    /** Distance function. */
+    private final DistanceMeasure distance;
+    /** Learning factor update function. */
+    private final LearningFactorFunction learningFactor;
+    /** Neighbourhood size update function. */
+    private final NeighbourhoodSizeFunction neighbourhoodSize;
+    /** Number of calls to {@link #update(Network,double[])}. */
+    private final AtomicLong numberOfCalls = new AtomicLong(-1);
+
+    /**
+     * @param distance Distance function.
+     * @param learningFactor Learning factor update function.
+     * @param neighbourhoodSize Neighbourhood size update function.
+     */
+    public KohonenUpdateAction(DistanceMeasure distance,
+                               LearningFactorFunction learningFactor,
+                               NeighbourhoodSizeFunction neighbourhoodSize) {
+        this.distance = distance;
+        this.learningFactor = learningFactor;
+        this.neighbourhoodSize = neighbourhoodSize;
+    }
+
+    /**
+     * {@inheritDoc}
+     */
+    @Override
+    public void update(Network net,
+                       double[] features) {
+        final long numCalls = numberOfCalls.incrementAndGet();
+        final double currentLearning = learningFactor.value(numCalls);
+        final Neuron best = findAndUpdateBestNeuron(net,
+                                                    features,
+                                                    currentLearning);
+
+        final int currentNeighbourhood = neighbourhoodSize.value(numCalls);
+        if (currentNeighbourhood > 0) {
+            // The farther away the neighbour is from the winning neuron, the
+            // smaller the learning rate will become.  The decay function is
+            // created inside the guard so that a zero neighbourhood size
+            // cannot lead to a division by zero.
+            final Gaussian neighbourhoodDecay
+                = new Gaussian(currentLearning,
+                               0,
+                               1d / currentNeighbourhood);
+
+            // Initial set of neurons only contains the winning neuron.
+            Collection<Neuron> neighbours = new HashSet<Neuron>();
+            neighbours.add(best);
+            // Winning neuron must be excluded from the neighbours.
+            final HashSet<Neuron> exclude = new HashSet<Neuron>();
+            exclude.add(best);
+
+            int radius = 1;
+            do {
+                // Retrieve immediate neighbours of the current set of neurons.
+                neighbours = net.getNeighbours(neighbours, exclude);
+
+                // Update all the neighbours.
+                for (Neuron n : neighbours) {
+                    updateNeighbouringNeuron(n, features, neighbourhoodDecay.value(radius));
+                }
+
+                // Add the neighbours to the exclude list so that they will
+                // not be updated more than once per training step.
+                exclude.addAll(neighbours);
+                ++radius;
+            } while (radius <= currentNeighbourhood);
+        }
+    }
+
+    /**
+     * Retrieves the number of calls to the {@link #update(Network,double[]) update}
+     * method.
+     *
+     * @return the current number of calls.
+     */
+    public long getNumberOfCalls() {
+        return numberOfCalls.get();
+    }
+
+    /**
+     * Atomically updates the given neuron.
+     *
+     * @param n Neuron to be updated.
+     * @param features Training data.
+     * @param learningRate Learning factor.
+     */
+    private void updateNeighbouringNeuron(Neuron n,
+                                          double[] features,
+                                          double learningRate) {
+        // Retry loop: "compareAndSetFeatures" fails when another thread
+        // modified the features since they were read.
+        while (true) {
+            final double[] expect = n.getFeatures();
+            final double[] update = computeFeatures(expect,
+                                                    features,
+                                                    learningRate);
+            if (n.compareAndSetFeatures(expect, update)) {
+                break;
+            }
+        }
+    }
+
+    /**
+     * Searches for the neuron whose features are closest to the given
+     * sample, and atomically updates its features.
+     *
+     * @param net Network.
+     * @param features Sample data.
+     * @param learningRate Current learning factor.
+     * @return the winning neuron.
+     */
+    private Neuron findAndUpdateBestNeuron(Network net,
+                                           double[] features,
+                                           double learningRate) {
+        while (true) {
+            final Neuron best = MapUtils.findBest(features, net, distance);
+
+            final double[] expect = best.getFeatures();
+            final double[] update = computeFeatures(expect,
+                                                    features,
+                                                    learningRate);
+            if (best.compareAndSetFeatures(expect, update)) {
+                return best;
+            }
+
+            // If another thread modified the state of the winning neuron,
+            // it may not be the best match anymore for the given training
+            // sample: Hence, the winner search is performed again.
+        }
+    }
+
+    /**
+     * Computes the new value of the features set.
+     *
+     * @param current Current values of the features.
+     * @param sample Training data.
+     * @param learningRate Learning factor.
+     * @return the new values for the features.
+     */
+    private double[] computeFeatures(double[] current,
+                                     double[] sample,
+                                     double learningRate) {
+        // Wrap the input arrays without copying ("false"): they are not
+        // modified, and the result is a freshly allocated array.
+        final ArrayRealVector c = new ArrayRealVector(current, false);
+        final ArrayRealVector s = new ArrayRealVector(sample, false);
+        // c + learningRate * (s - c)
+        return s.subtract(c).mapMultiplyToSelf(learningRate).add(c).toArray();
+    }
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/LearningFactorFunction.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/LearningFactorFunction.java
new file mode 100644
index 000000000..341e2ce22
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/LearningFactorFunction.java
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet.sofm;
+
+/**
+ * Provides the learning rate as a function of the number of calls
+ * already performed during the learning task.
+ * Standard implementations can be obtained from
+ * {@link LearningFactorFunctionFactory}.
+ *
+ * @version $Id$
+ */
+public interface LearningFactorFunction {
+    /**
+     * Computes the learning rate at the current call.
+     *
+     * @param numCall Current step of the training task.
+     * @return the value of the function at {@code numCall}.
+     */
+    double value(long numCall);
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/LearningFactorFunctionFactory.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/LearningFactorFunctionFactory.java
new file mode 100644
index 000000000..98c24facd
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/LearningFactorFunctionFactory.java
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet.sofm;
+
+import org.apache.commons.math3.ml.neuralnet.sofm.util.ExponentialDecayFunction;
+import org.apache.commons.math3.ml.neuralnet.sofm.util.QuasiSigmoidDecayFunction;
+import org.apache.commons.math3.exception.OutOfRangeException;
+
+/**
+ * Factory for creating instances of {@link LearningFactorFunction}.
+ *
+ * @version $Id$
+ */
+public class LearningFactorFunctionFactory {
+    /** Class contains only static methods. */
+    private LearningFactorFunctionFactory() {}
+
+    /**
+     * Creates an exponential decay {@link LearningFactorFunction function}.
+     * It will compute {@code a e^(-x / b)},
+     * where {@code x} is the (integer) independent variable and
+     *
+     *  a = initValue
+     *  b = -numCall / ln(valueAtNumCall / initValue)
+     *
+     * @param initValue Initial value, i.e.
+     * {@link LearningFactorFunction#value(long) value(0)}.
+     * @param valueAtNumCall Value of the function at {@code numCall}.
+     * @param numCall Argument for which the function returns
+     * {@code valueAtNumCall}.
+     * @return the learning factor function.
+     * @throws org.apache.commons.math3.exception.OutOfRangeException
+     * if {@code initValue <= 0} or {@code initValue > 1}.
+     * @throws org.apache.commons.math3.exception.NotStrictlyPositiveException
+     * if {@code valueAtNumCall <= 0}.
+     * @throws org.apache.commons.math3.exception.NumberIsTooLargeException
+     * if {@code valueAtNumCall >= initValue}.
+     * @throws org.apache.commons.math3.exception.NotStrictlyPositiveException
+     * if {@code numCall <= 0}.
+     */
+    public static LearningFactorFunction exponentialDecay(final double initValue,
+                                                          final double valueAtNumCall,
+                                                          final long numCall) {
+        if (initValue <= 0 ||
+            initValue > 1) {
+            throw new OutOfRangeException(initValue, 0, 1);
+        }
+
+        return new LearningFactorFunction() {
+            /** DecayFunction. */
+            private final ExponentialDecayFunction decay
+                = new ExponentialDecayFunction(initValue, valueAtNumCall, numCall);
+
+            /** {@inheritDoc} */
+            @Override
+            public double value(long n) {
+                return decay.value(n);
+            }
+        };
+    }
+
+    /**
+     * Creates a sigmoid-like {@link LearningFactorFunction function}.
+     * The function {@code f} will have the following properties:
+     *
+     *  - {@code f(0) = initValue}
+     *  - {@code numCall} is the inflexion point
+     *  - {@code slope = f'(numCall)}
+     *
+     * @param initValue Initial value, i.e.
+     * {@link LearningFactorFunction#value(long) value(0)}.
+     * @param slope Value of the function derivative at {@code numCall}.
+     * @param numCall Inflexion point.
+     * @return the learning factor function.
+     * @throws org.apache.commons.math3.exception.OutOfRangeException
+     * if {@code initValue <= 0} or {@code initValue > 1}.
+     * @throws org.apache.commons.math3.exception.NumberIsTooLargeException
+     * if {@code slope >= 0}.
+     * @throws org.apache.commons.math3.exception.NotStrictlyPositiveException
+     * if {@code numCall <= 0}.
+     */
+    public static LearningFactorFunction quasiSigmoidDecay(final double initValue,
+                                                           final double slope,
+                                                           final long numCall) {
+        if (initValue <= 0 ||
+            initValue > 1) {
+            throw new OutOfRangeException(initValue, 0, 1);
+        }
+
+        return new LearningFactorFunction() {
+            /** DecayFunction. */
+            private final QuasiSigmoidDecayFunction decay
+                = new QuasiSigmoidDecayFunction(initValue, slope, numCall);
+
+            /** {@inheritDoc} */
+            @Override
+            public double value(long n) {
+                return decay.value(n);
+            }
+        };
+    }
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/NeighbourhoodSizeFunction.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/NeighbourhoodSizeFunction.java
new file mode 100644
index 000000000..84362a4af
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/NeighbourhoodSizeFunction.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet.sofm;
+
+/**
+ * Provides the network neighbourhood's size as a function of the
+ * number of calls already performed during the learning task.
+ * The "neighbourhood" is the set of neurons that can be reached
+ * by traversing at most the number of links returned by this
+ * function.
+ * Standard implementations can be obtained from
+ * {@link NeighbourhoodSizeFunctionFactory}.
+ *
+ * @version $Id$
+ */
+public interface NeighbourhoodSizeFunction {
+    /**
+     * Computes the neighbourhood size at the current call.
+     *
+     * @param numCall Current step of the training task.
+     * @return the value of the function at {@code numCall}.
+     */
+    int value(long numCall);
+}
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/NeighbourhoodSizeFunctionFactory.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/NeighbourhoodSizeFunctionFactory.java
new file mode 100644
index 000000000..cf185c454
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/NeighbourhoodSizeFunctionFactory.java
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet.sofm;
+
+import org.apache.commons.math3.ml.neuralnet.sofm.util.ExponentialDecayFunction;
+import org.apache.commons.math3.ml.neuralnet.sofm.util.QuasiSigmoidDecayFunction;
+import org.apache.commons.math3.util.FastMath;
+
+/**
+ * Factory for creating instances of {@link NeighbourhoodSizeFunction}.
+ *
+ * @version $Id$
+ */
+public class NeighbourhoodSizeFunctionFactory {
+ /** Class contains only static methods. */
+ private NeighbourhoodSizeFunctionFactory() {}
+
+ /**
+ * Creates an exponential decay {@link NeighbourhoodSizeFunction function}.
+ * It will compute <code>a e<sup>-x / b</sup></code>,
+ * where {@code x} is the (integer) independent variable and
+ * <ul>
+ *  <li>{@code a = initValue}</li>
+ *  <li>{@code b = -numCall / ln(valueAtNumCall / initValue)}</li>
+ * </ul>
+ *
+ * @param initValue Initial value, i.e.
+ * {@link NeighbourhoodSizeFunction#value(long) value(0)}.
+ * @param valueAtNumCall Value of the function at {@code numCall}.
+ * @param numCall Argument for which the function returns
+ * {@code valueAtNumCall}.
+ * @return the neighbourhood size function.
+ * @throws org.apache.commons.math3.exception.NotStrictlyPositiveException
+ * if {@code initValue <= 0}.
+ * @throws org.apache.commons.math3.exception.NotStrictlyPositiveException
+ * if {@code valueAtNumCall <= 0}.
+ * @throws org.apache.commons.math3.exception.NumberIsTooLargeException
+ * if {@code valueAtNumCall >= initValue}.
+ * @throws org.apache.commons.math3.exception.NotStrictlyPositiveException
+ * if {@code numCall <= 0}.
+ */
+ public static NeighbourhoodSizeFunction exponentialDecay(final double initValue,
+ final double valueAtNumCall,
+ final long numCall) {
+ return new NeighbourhoodSizeFunction() {
+ /** DecayFunction. */
+ private final ExponentialDecayFunction decay
+ = new ExponentialDecayFunction(initValue, valueAtNumCall, numCall);
+
+ /** {@inheritDoc} */
+ @Override
+ public int value(long n) {
+ return (int) FastMath.rint(decay.value(n));
+ }
+ };
+ }
+
+ /**
+ * Creates a sigmoid-like {@link NeighbourhoodSizeFunction function}.
+ * The function {@code f} will have the following properties:
+ * <ul>
+ *  <li>{@code f(0) = initValue}</li>
+ *  <li>{@code numCall} is the inflexion point</li>
+ *  <li>{@code slope = f'(numCall)}</li>
+ * </ul>
+ *
+ * @param initValue Initial value, i.e.
+ * {@link NeighbourhoodSizeFunction#value(long) value(0)}.
+ * @param slope Value of the function derivative at {@code numCall}.
+ * @param numCall Inflexion point.
+ * @return the neighbourhood size function.
+ * @throws org.apache.commons.math3.exception.NotStrictlyPositiveException
+ * if {@code initValue <= 0}.
+ * @throws org.apache.commons.math3.exception.NumberIsTooLargeException
+ * if {@code slope >= 0}.
+ * @throws org.apache.commons.math3.exception.NotStrictlyPositiveException
+ * if {@code numCall <= 1} (the created function delegates to
+ * {@code QuasiSigmoidDecayFunction}, whose inflexion point must be larger than 1).
+ */
+ public static NeighbourhoodSizeFunction quasiSigmoidDecay(final double initValue,
+ final double slope,
+ final long numCall) {
+ return new NeighbourhoodSizeFunction() {
+ /** DecayFunction. */
+ private final QuasiSigmoidDecayFunction decay
+ = new QuasiSigmoidDecayFunction(initValue, slope, numCall);
+
+ /** {@inheritDoc} */
+ @Override
+ public int value(long n) {
+ return (int) FastMath.rint(decay.value(n));
+ }
+ };
+ }
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/package-info.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/package-info.java
new file mode 100644
index 000000000..60c3c61a2
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Self Organizing Feature Map.
+ */
+
+package org.apache.commons.math3.ml.neuralnet.sofm;
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/util/ExponentialDecayFunction.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/util/ExponentialDecayFunction.java
new file mode 100644
index 000000000..ee33528c7
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/util/ExponentialDecayFunction.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet.sofm.util;
+
+import org.apache.commons.math3.exception.NotStrictlyPositiveException;
+import org.apache.commons.math3.exception.NumberIsTooLargeException;
+import org.apache.commons.math3.util.FastMath;
+
+/**
+ * Exponential decay function: <code>a e<sup>-x / b</sup></code>,
+ * where {@code x} is the (integer) independent variable.
+ *
+ * Class is immutable.
+ *
+ * @version $Id$
+ */
+public class ExponentialDecayFunction {
+ /** Factor {@code a}. */
+ private final double a;
+ /** Factor {@code 1 / b}. */
+ private final double oneOverB;
+
+ /**
+ * Creates an instance. It will be such that
+ * <ul>
+ *  <li>{@code a = initValue}</li>
+ *  <li>{@code b = -numCall / ln(valueAtNumCall / initValue)}</li>
+ * </ul>
+ *
+ * @param initValue Initial value, i.e. {@link #value(long) value(0)}.
+ * @param valueAtNumCall Value of the function at {@code numCall}.
+ * @param numCall Argument for which the function returns
+ * {@code valueAtNumCall}.
+ * @throws NotStrictlyPositiveException if {@code initValue <= 0}.
+ * @throws NotStrictlyPositiveException if {@code valueAtNumCall <= 0}.
+ * @throws NumberIsTooLargeException if {@code valueAtNumCall >= initValue}.
+ * @throws NotStrictlyPositiveException if {@code numCall <= 0}.
+ */
+ public ExponentialDecayFunction(double initValue,
+ double valueAtNumCall,
+ long numCall) {
+ if (initValue <= 0) {
+ throw new NotStrictlyPositiveException(initValue);
+ }
+ if (valueAtNumCall <= 0) {
+ throw new NotStrictlyPositiveException(valueAtNumCall);
+ }
+ if (valueAtNumCall >= initValue) {
+ throw new NumberIsTooLargeException(valueAtNumCall, initValue, false);
+ }
+ if (numCall <= 0) {
+ throw new NotStrictlyPositiveException(numCall);
+ }
+
+ a = initValue;
+ oneOverB = -FastMath.log(valueAtNumCall / initValue) / numCall;
+ }
+
+ /**
+ * Computes <code>a e<sup>-numCall / b</sup></code>.
+ *
+ * @param numCall Current step of the training task.
+ * @return the value of the function at {@code numCall}.
+ */
+ public double value(long numCall) {
+ return a * FastMath.exp(-numCall * oneOverB);
+ }
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/util/QuasiSigmoidDecayFunction.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/util/QuasiSigmoidDecayFunction.java
new file mode 100644
index 000000000..a6d322c1d
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/util/QuasiSigmoidDecayFunction.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet.sofm.util;
+
+import org.apache.commons.math3.exception.NotStrictlyPositiveException;
+import org.apache.commons.math3.exception.NumberIsTooLargeException;
+import org.apache.commons.math3.analysis.function.Logistic;
+
+/**
+ * Decay function whose shape is similar to a sigmoid.
+ *
+ * Class is immutable.
+ *
+ * @version $Id$
+ */
+public class QuasiSigmoidDecayFunction {
+ /** Sigmoid. */
+ private final Logistic sigmoid;
+ /** See {@link #value(long)}. */
+ private final double scale;
+
+ /**
+ * Creates an instance.
+ * The function {@code f} will have the following properties:
+ * <ul>
+ *  <li>{@code f(0) = initValue}</li>
+ *  <li>{@code numCall} is the inflexion point</li>
+ *  <li>{@code slope = f'(numCall)}</li>
+ * </ul>
+ *
+ * @param initValue Initial value, i.e. {@link #value(long) value(0)}.
+ * @param slope Value of the function derivative at {@code numCall}.
+ * @param numCall Inflexion point.
+ * @throws NotStrictlyPositiveException if {@code initValue <= 0}.
+ * @throws NumberIsTooLargeException if {@code slope >= 0}.
+ * @throws NotStrictlyPositiveException if {@code numCall <= 1} (the
+ * inflexion point must lie strictly after the first call).
+ */
+ public QuasiSigmoidDecayFunction(double initValue,
+ double slope,
+ long numCall) {
+ if (initValue <= 0) {
+ throw new NotStrictlyPositiveException(initValue);
+ }
+ if (slope >= 0) {
+ throw new NumberIsTooLargeException(slope, 0, false);
+ }
+ if (numCall <= 1) {
+ throw new NotStrictlyPositiveException(numCall);
+ }
+
+ final double k = initValue;
+ final double m = numCall;
+ final double b = 4 * slope / initValue;
+ final double q = 1;
+ final double a = 0;
+ final double n = 1;
+ sigmoid = new Logistic(k, m, b, q, a, n);
+
+ final double y0 = sigmoid.value(0);
+ scale = k / y0;
+ }
+
+ /**
+ * Computes the value of the learning factor.
+ *
+ * @param numCall Current step of the training task.
+ * @return the value of the function at {@code numCall}.
+ */
+ public double value(long numCall) {
+ return scale * sigmoid.value(numCall);
+ }
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/util/package-info.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/util/package-info.java
new file mode 100644
index 000000000..5078ed292
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/util/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Miscellaneous utilities.
+ */
+
+package org.apache.commons.math3.ml.neuralnet.sofm.util;
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/NeuronSquareMesh2D.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/NeuronSquareMesh2D.java
new file mode 100644
index 000000000..c86356bad
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/NeuronSquareMesh2D.java
@@ -0,0 +1,433 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet.twod;
+
+import java.util.List;
+import java.util.ArrayList;
+import java.io.Serializable;
+import java.io.ObjectInputStream;
+import org.apache.commons.math3.ml.neuralnet.Neuron;
+import org.apache.commons.math3.ml.neuralnet.Network;
+import org.apache.commons.math3.ml.neuralnet.FeatureInitializer;
+import org.apache.commons.math3.ml.neuralnet.SquareNeighbourhood;
+import org.apache.commons.math3.exception.NumberIsTooSmallException;
+import org.apache.commons.math3.exception.OutOfRangeException;
+import org.apache.commons.math3.exception.MathInternalError;
+
+/**
+ * Neural network with the topology of a two-dimensional surface.
+ * Each neuron defines one surface element.
+ *
+ * This network is primarily intended to represent a
+ * <a href="http://en.wikipedia.org/wiki/Self-organizing_map">
+ * Self Organizing Feature Map</a>.
+ *
+ * @see org.apache.commons.math3.ml.neuralnet.sofm
+ * @version $Id$
+ */
+public class NeuronSquareMesh2D implements Serializable {
+ /** Underlying network. */
+ private final Network network;
+ /** Number of rows. */
+ private final int numberOfRows;
+ /** Number of columns. */
+ private final int numberOfColumns;
+ /** Wrap. */
+ private final boolean wrapRows;
+ /** Wrap. */
+ private final boolean wrapColumns;
+ /** Neighbourhood type. */
+ private final SquareNeighbourhood neighbourhood;
+ /**
+ * Mapping of the 2D coordinates (in the rectangular mesh) to
+ * the neuron identifiers (attributed by the {@link #network}
+ * instance).
+ */
+ private final long[][] identifiers;
+
+ /**
+ * Constructor with restricted access, solely used for deserialization.
+ *
+ * @param wrapRowDim Whether to wrap the first dimension (i.e the first
+ * and last neurons will be linked together).
+ * @param wrapColDim Whether to wrap the second dimension (i.e the first
+ * and last neurons will be linked together).
+ * @param neighbourhoodType Neighbourhood type.
+ * @param featuresList Arrays that will initialize the features sets of
+ * the network's neurons.
+ * @throws NumberIsTooSmallException if {@code numRows < 2} or
+ * {@code numCols < 2}.
+ */
+ NeuronSquareMesh2D(boolean wrapRowDim,
+ boolean wrapColDim,
+ SquareNeighbourhood neighbourhoodType,
+ double[][][] featuresList) {
+ numberOfRows = featuresList.length;
+ numberOfColumns = featuresList[0].length;
+
+ if (numberOfRows < 2) {
+ throw new NumberIsTooSmallException(numberOfRows, 2, true);
+ }
+ if (numberOfColumns < 2) {
+ throw new NumberIsTooSmallException(numberOfColumns, 2, true);
+ }
+
+ wrapRows = wrapRowDim;
+ wrapColumns = wrapColDim;
+ neighbourhood = neighbourhoodType;
+
+ final int fLen = featuresList[0][0].length;
+ network = new Network(0, fLen);
+ identifiers = new long[numberOfRows][numberOfColumns];
+
+ // Add neurons.
+ for (int i = 0; i < numberOfRows; i++) {
+ for (int j = 0; j < numberOfColumns; j++) {
+ identifiers[i][j] = network.createNeuron(featuresList[i][j]);
+ }
+ }
+
+ // Add links.
+ createLinks();
+ }
+
+ /**
+ * Creates a two-dimensional network composed of square cells:
+ * Each neuron not located on the border of the mesh has four
+ * neurons linked to it.
+ *
+ * The links are bi-directional.
+ *
+ * The topology of the network can also be a cylinder (if one
+ * of the dimensions is wrapped) or a torus (if both dimensions
+ * are wrapped).
+ *
+ * @param numRows Number of neurons in the first dimension.
+ * @param wrapRowDim Whether to wrap the first dimension (i.e the first
+ * and last neurons will be linked together).
+ * @param numCols Number of neurons in the second dimension.
+ * @param wrapColDim Whether to wrap the second dimension (i.e the first
+ * and last neurons will be linked together).
+ * @param neighbourhoodType Neighbourhood type.
+ * @param featureInit Array of functions that will initialize the
+ * corresponding element of the features set of each newly created
+ * neuron. In particular, the size of this array defines the size of
+ * feature set.
+ * @throws NumberIsTooSmallException if {@code numRows < 2} or
+ * {@code numCols < 2}.
+ */
+ public NeuronSquareMesh2D(int numRows,
+ boolean wrapRowDim,
+ int numCols,
+ boolean wrapColDim,
+ SquareNeighbourhood neighbourhoodType,
+ FeatureInitializer[] featureInit) {
+ if (numRows < 2) {
+ throw new NumberIsTooSmallException(numRows, 2, true);
+ }
+ if (numCols < 2) {
+ throw new NumberIsTooSmallException(numCols, 2, true);
+ }
+
+ numberOfRows = numRows;
+ wrapRows = wrapRowDim;
+ numberOfColumns = numCols;
+ wrapColumns = wrapColDim;
+ neighbourhood = neighbourhoodType;
+ identifiers = new long[numberOfRows][numberOfColumns];
+
+ final int fLen = featureInit.length;
+ network = new Network(0, fLen);
+
+ // Add neurons.
+ for (int i = 0; i < numRows; i++) {
+ for (int j = 0; j < numCols; j++) {
+ final double[] features = new double[fLen];
+ for (int fIndex = 0; fIndex < fLen; fIndex++) {
+ features[fIndex] = featureInit[fIndex].value();
+ }
+ identifiers[i][j] = network.createNeuron(features);
+ }
+ }
+
+ // Add links.
+ createLinks();
+ }
+
+ /**
+ * Retrieves the underlying network.
+ * A reference is returned (enabling, for example, the network to be
+ * trained).
+ * This also implies that calling methods that modify the {@link Network}
+ * topology may cause this class to become inconsistent.
+ *
+ * @return the network.
+ */
+ public Network getNetwork() {
+ return network;
+ }
+
+ /**
+ * Gets the number of neurons in each row of this map.
+ *
+ * @return the number of rows.
+ */
+ public int getNumberOfRows() {
+ return numberOfRows;
+ }
+
+ /**
+ * Gets the number of neurons in each column of this map.
+ *
+ * @return the number of columns.
+ */
+ public int getNumberOfColumns() {
+ return numberOfColumns;
+ }
+
+ /**
+ * Retrieves the neuron at location {@code (i, j)} in the map.
+ *
+ * @param i Row index.
+ * @param j Column index.
+ * @return the neuron at {@code (i, j)}.
+ * @throws OutOfRangeException if {@code i} or {@code j} is
+ * out of range.
+ */
+ public Neuron getNeuron(int i,
+ int j) {
+ if (i < 0 ||
+ i >= numberOfRows) {
+ throw new OutOfRangeException(i, 0, numberOfRows - 1);
+ }
+ if (j < 0 ||
+ j >= numberOfColumns) {
+ throw new OutOfRangeException(j, 0, numberOfColumns - 1);
+ }
+
+ return network.getNeuron(identifiers[i][j]);
+ }
+
+ /**
+ * Creates the neighbour relationships between neurons.
+ */
+ private void createLinks() {
+ // "linkEnd" will store the identifiers of the "neighbours".
+ final List<Long> linkEnd = new ArrayList<Long>();
+ final int iLast = numberOfRows - 1;
+ final int jLast = numberOfColumns - 1;
+ for (int i = 0; i < numberOfRows; i++) {
+ for (int j = 0; j < numberOfColumns; j++) {
+ linkEnd.clear();
+
+ switch (neighbourhood) {
+
+ case MOORE:
+ // Add links to "diagonal" neighbours.
+ if (i > 0) {
+ if (j > 0) {
+ linkEnd.add(identifiers[i - 1][j - 1]);
+ }
+ if (j < jLast) {
+ linkEnd.add(identifiers[i - 1][j + 1]);
+ }
+ }
+ if (i < iLast) {
+ if (j > 0) {
+ linkEnd.add(identifiers[i + 1][j - 1]);
+ }
+ if (j < jLast) {
+ linkEnd.add(identifiers[i + 1][j + 1]);
+ }
+ }
+ if (wrapRows) {
+ if (i == 0) {
+ if (j > 0) {
+ linkEnd.add(identifiers[iLast][j - 1]);
+ }
+ if (j < jLast) {
+ linkEnd.add(identifiers[iLast][j + 1]);
+ }
+ } else if (i == iLast) {
+ if (j > 0) {
+ linkEnd.add(identifiers[0][j - 1]);
+ }
+ if (j < jLast) {
+ linkEnd.add(identifiers[0][j + 1]);
+ }
+ }
+ }
+ if (wrapColumns) {
+ if (j == 0) {
+ if (i > 0) {
+ linkEnd.add(identifiers[i - 1][jLast]);
+ }
+ if (i < iLast) {
+ linkEnd.add(identifiers[i + 1][jLast]);
+ }
+ } else if (j == jLast) {
+ if (i > 0) {
+ linkEnd.add(identifiers[i - 1][0]);
+ }
+ if (i < iLast) {
+ linkEnd.add(identifiers[i + 1][0]);
+ }
+ }
+ }
+ if (wrapRows &&
+ wrapColumns) {
+ if (i == 0 &&
+ j == 0) {
+ linkEnd.add(identifiers[iLast][jLast]);
+ } else if (i == 0 &&
+ j == jLast) {
+ linkEnd.add(identifiers[iLast][0]);
+ } else if (i == iLast &&
+ j == 0) {
+ linkEnd.add(identifiers[0][jLast]);
+ } else if (i == iLast &&
+ j == jLast) {
+ linkEnd.add(identifiers[0][0]);
+ }
+ }
+
+ // Case falls through since the "Moore" neighbourhood
+ // also contains the neurons that belong to the "Von
+ // Neumann" neighbourhood.
+
+ // fallthru (CheckStyle)
+ case VON_NEUMANN:
+ // Links to preceding and following "row".
+ if (i > 0) {
+ linkEnd.add(identifiers[i - 1][j]);
+ }
+ if (i < iLast) {
+ linkEnd.add(identifiers[i + 1][j]);
+ }
+ if (wrapRows) {
+ if (i == 0) {
+ linkEnd.add(identifiers[iLast][j]);
+ } else if (i == iLast) {
+ linkEnd.add(identifiers[0][j]);
+ }
+ }
+
+ // Links to preceding and following "column".
+ if (j > 0) {
+ linkEnd.add(identifiers[i][j - 1]);
+ }
+ if (j < jLast) {
+ linkEnd.add(identifiers[i][j + 1]);
+ }
+ if (wrapColumns) {
+ if (j == 0) {
+ linkEnd.add(identifiers[i][jLast]);
+ } else if (j == jLast) {
+ linkEnd.add(identifiers[i][0]);
+ }
+ }
+ break;
+
+ default:
+ throw new MathInternalError(); // Cannot happen.
+ }
+
+ final Neuron aNeuron = network.getNeuron(identifiers[i][j]);
+ for (long b : linkEnd) {
+ final Neuron bNeuron = network.getNeuron(b);
+ // Link to all neighbours.
+ // The reverse links will be added as the loop proceeds.
+ network.addLink(aNeuron, bNeuron);
+ }
+ }
+ }
+ }
+
+ /**
+ * Prevents proxy bypass.
+ *
+ * @param in Input stream.
+ */
+ private void readObject(ObjectInputStream in) {
+ throw new IllegalStateException();
+ }
+
+ /**
+ * Custom serialization.
+ *
+ * @return the proxy instance that will be actually serialized.
+ */
+ private Object writeReplace() {
+ final double[][][] featuresList = new double[numberOfRows][numberOfColumns][];
+ for (int i = 0; i < numberOfRows; i++) {
+ for (int j = 0; j < numberOfColumns; j++) {
+ featuresList[i][j] = getNeuron(i, j).getFeatures();
+ }
+ }
+
+ return new SerializationProxy(wrapRows,
+ wrapColumns,
+ neighbourhood,
+ featuresList);
+ }
+
+ /**
+ * Serialization.
+ */
+ private static class SerializationProxy implements Serializable {
+ /** Serializable. */
+ private static final long serialVersionUID = 20130226L;
+ /** Wrap. */
+ private final boolean wrapRows;
+ /** Wrap. */
+ private final boolean wrapColumns;
+ /** Neighbourhood type. */
+ private final SquareNeighbourhood neighbourhood;
+ /** Neurons' features. */
+ private final double[][][] featuresList;
+
+ /**
+ * @param wrapRows Whether the row dimension is wrapped.
+ * @param wrapColumns Whether the column dimension is wrapped.
+ * @param neighbourhood Neighbourhood type.
+ * @param featuresList List of neurons features, stored by row and
+ * then by column of the mesh.
+ */
+ SerializationProxy(boolean wrapRows,
+ boolean wrapColumns,
+ SquareNeighbourhood neighbourhood,
+ double[][][] featuresList) {
+ this.wrapRows = wrapRows;
+ this.wrapColumns = wrapColumns;
+ this.neighbourhood = neighbourhood;
+ this.featuresList = featuresList;
+ }
+
+ /**
+ * Custom serialization.
+ *
+ * @return the {@link Neuron} for which this instance is the proxy.
+ */
+ private Object readResolve() {
+ return new NeuronSquareMesh2D(wrapRows,
+ wrapColumns,
+ neighbourhood,
+ featuresList);
+ }
+ }
+}
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/package-info.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/package-info.java
new file mode 100644
index 000000000..41535e8c6
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Two-dimensional neural networks.
+ */
+
+package org.apache.commons.math3.ml.neuralnet.twod;
diff --git a/src/test/java/org/apache/commons/math3/ml/neuralnet/MapUtilsTest.java b/src/test/java/org/apache/commons/math3/ml/neuralnet/MapUtilsTest.java
new file mode 100644
index 000000000..72bf09cd0
--- /dev/null
+++ b/src/test/java/org/apache/commons/math3/ml/neuralnet/MapUtilsTest.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet;
+
+import java.util.Set;
+import java.util.HashSet;
+import org.apache.commons.math3.ml.distance.DistanceMeasure;
+import org.apache.commons.math3.ml.distance.EuclideanDistance;
+import org.apache.commons.math3.ml.neuralnet.oned.NeuronString;
+import org.junit.Test;
+import org.junit.Assert;
+
+/**
+ * Tests for {@link MapUtils} class.
+ */
+ public class MapUtilsTest {
+ /*
+ * Test assumes that the network is
+ *
+ * 0-----1-----2
+ */
+ @Test
+ public void testFindClosestNeuron() {
+ final FeatureInitializer init
+ = new OffsetFeatureInitializer(FeatureInitializerFactory.uniform(-0.1, 0.1));
+ final FeatureInitializer[] initArray = { init };
+
+ final Network net = new NeuronString(3, false, initArray).getNetwork();
+ final DistanceMeasure dist = new EuclideanDistance();
+
+ final Set<Neuron> allBest = new HashSet<Neuron>();
+ final Set<Neuron> best = new HashSet<Neuron>();
+ double[][] features;
+
+ // The following tests ensure that
+ // 1. the same neuron is always selected when the input feature is
+ // in the range of the initializer,
+ // 2. different network's neurons have been selected by input features
+ // that belong to different ranges.
+
+ best.clear();
+ features = new double[][] {
+ { -1 },
+ { 0.4 },
+ };
+ for (double[] f : features) {
+ best.add(MapUtils.findBest(f, net, dist));
+ }
+ Assert.assertEquals(1, best.size());
+ allBest.addAll(best);
+
+ best.clear();
+ features = new double[][] {
+ { 0.6 },
+ { 1.4 },
+ };
+ for (double[] f : features) {
+ best.add(MapUtils.findBest(f, net, dist));
+ }
+ Assert.assertEquals(1, best.size());
+ allBest.addAll(best);
+
+ best.clear();
+ features = new double[][] {
+ { 1.6 },
+ { 3 },
+ };
+ for (double[] f : features) {
+ best.add(MapUtils.findBest(f, net, dist));
+ }
+ Assert.assertEquals(1, best.size());
+ allBest.addAll(best);
+
+ Assert.assertEquals(3, allBest.size());
+ }
+ }
diff --git a/src/test/java/org/apache/commons/math3/ml/neuralnet/NetworkTest.java b/src/test/java/org/apache/commons/math3/ml/neuralnet/NetworkTest.java
new file mode 100644
index 000000000..e7056d9fe
--- /dev/null
+++ b/src/test/java/org/apache/commons/math3/ml/neuralnet/NetworkTest.java
@@ -0,0 +1,187 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet;
+
+import java.io.ByteArrayOutputStream;
+import java.io.ByteArrayInputStream;
+import java.io.ObjectOutputStream;
+import java.io.ObjectInputStream;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.Collection;
+import java.util.NoSuchElementException;
+import org.junit.Test;
+import org.junit.Assert;
+import org.junit.Ignore;
+import org.apache.commons.math3.exception.NumberIsTooSmallException;
+import org.apache.commons.math3.ml.neuralnet.twod.NeuronSquareMesh2D;
+import org.apache.commons.math3.random.Well44497b;
+
+/**
+ * Tests for {@link Network}.
+ */
+public class NetworkTest {
+ final FeatureInitializer init = FeatureInitializerFactory.uniform(0, 2);
+
+ @Test
+ public void testGetFeaturesSize() {
+ final FeatureInitializer[] initArray = { init, init, init };
+
+ final Network net = new NeuronSquareMesh2D(2, false,
+ 2, false,
+ SquareNeighbourhood.VON_NEUMANN,
+ initArray).getNetwork();
+ Assert.assertEquals(3, net.getFeaturesSize());
+ }
+
+ /*
+ * Test assumes that the network is
+ *
+ * 0-----1
+ * | |
+ * | |
+ * 2-----3
+ */
+ @Test
+ public void testDeleteLink() {
+ final FeatureInitializer[] initArray = { init };
+ final Network net = new NeuronSquareMesh2D(2, false,
+ 2, false,
+ SquareNeighbourhood.VON_NEUMANN,
+ initArray).getNetwork();
+ Collection neighbours;
+
+ // Delete 0 --> 1.
+ net.deleteLink(net.getNeuron(0),
+ net.getNeuron(1));
+
+ // Link from 0 to 1 was deleted.
+ neighbours = net.getNeighbours(net.getNeuron(0));
+ Assert.assertFalse(neighbours.contains(net.getNeuron(1)));
+ // Link from 1 to 0 still exists.
+ neighbours = net.getNeighbours(net.getNeuron(1));
+ Assert.assertTrue(neighbours.contains(net.getNeuron(0)));
+ }
+
+ /*
+ * Test assumes that the network is
+ *
+ * 0-----1
+ * | |
+ * | |
+ * 2-----3
+ */
+ @Test
+ public void testDeleteNeuron() {
+ final FeatureInitializer[] initArray = { init };
+ final Network net = new NeuronSquareMesh2D(2, false,
+ 2, false,
+ SquareNeighbourhood.VON_NEUMANN,
+ initArray).getNetwork();
+
+ Assert.assertEquals(2, net.getNeighbours(net.getNeuron(0)).size());
+ Assert.assertEquals(2, net.getNeighbours(net.getNeuron(3)).size());
+
+ // Delete neuron 1.
+ net.deleteNeuron(net.getNeuron(1));
+
+ try {
+ net.getNeuron(1);
+ } catch (NoSuchElementException expected) {}
+
+ Assert.assertEquals(1, net.getNeighbours(net.getNeuron(0)).size());
+ Assert.assertEquals(1, net.getNeighbours(net.getNeuron(3)).size());
+ }
+
+ @Test
+ public void testIterationOrder() {
+ final FeatureInitializer[] initArray = { init };
+ final Network net = new NeuronSquareMesh2D(4, false,
+ 3, true,
+ SquareNeighbourhood.VON_NEUMANN,
+ initArray).getNetwork();
+
+ boolean isUnspecifiedOrder = false;
+
+ // Check that the default iterator returns the neurons
+ // in an unspecified order.
+ long previousId = Long.MIN_VALUE;
+ for (Neuron n : net) {
+ final long currentId = n.getIdentifier();
+ if (currentId < previousId) {
+ isUnspecifiedOrder = true;
+ break;
+ }
+ previousId = currentId;
+ }
+ Assert.assertTrue(isUnspecifiedOrder);
+
+ // Check that the comparator provides a specific order.
+ isUnspecifiedOrder = false;
+ previousId = Long.MIN_VALUE;
+ for (Neuron n : net.getNeurons(new Network.NeuronIdentifierComparator())) {
+ final long currentId = n.getIdentifier();
+ if (currentId < previousId) {
+ isUnspecifiedOrder = true;
+ break;
+ }
+ previousId = currentId;
+ }
+ Assert.assertFalse(isUnspecifiedOrder);
+ }
+
+ @Test
+ public void testSerialize()
+ throws IOException,
+ ClassNotFoundException {
+ final FeatureInitializer[] initArray = { init };
+ final Network out = new NeuronSquareMesh2D(4, false,
+ 3, true,
+ SquareNeighbourhood.VON_NEUMANN,
+ initArray).getNetwork();
+
+ final ByteArrayOutputStream bos = new ByteArrayOutputStream();
+ final ObjectOutputStream oos = new ObjectOutputStream(bos);
+ oos.writeObject(out);
+
+ final ByteArrayInputStream bis = new ByteArrayInputStream(bos.toByteArray());
+ final ObjectInputStream ois = new ObjectInputStream(bis);
+ final Network in = (Network) ois.readObject();
+
+ for (Neuron nOut : out) {
+ final Neuron nIn = in.getNeuron(nOut.getIdentifier());
+
+ // Same values.
+ final double[] outF = nOut.getFeatures();
+ final double[] inF = nIn.getFeatures();
+ Assert.assertEquals(outF.length, inF.length);
+ for (int i = 0; i < outF.length; i++) {
+ Assert.assertEquals(outF[i], inF[i], 0d);
+ }
+
+ // Same neighbours.
+ final Collection outNeighbours = out.getNeighbours(nOut);
+ final Collection inNeighbours = in.getNeighbours(nIn);
+ Assert.assertEquals(outNeighbours.size(), inNeighbours.size());
+ for (Neuron oN : outNeighbours) {
+ Assert.assertTrue(inNeighbours.contains(in.getNeuron(oN.getIdentifier())));
+ }
+ }
+ }
+}
diff --git a/src/test/java/org/apache/commons/math3/ml/neuralnet/NeuronTest.java b/src/test/java/org/apache/commons/math3/ml/neuralnet/NeuronTest.java
new file mode 100644
index 000000000..b03f07d9d
--- /dev/null
+++ b/src/test/java/org/apache/commons/math3/ml/neuralnet/NeuronTest.java
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet;
+
+import java.io.ByteArrayOutputStream;
+import java.io.ByteArrayInputStream;
+import java.io.ObjectOutputStream;
+import java.io.ObjectInputStream;
+import java.io.IOException;
+import org.junit.Test;
+import org.junit.Assert;
+
+/**
+ * Tests for {@link Neuron}.
+ */
+public class NeuronTest {
+ @Test
+ public void testGetIdentifier() {
+ final long id = 1234567;
+ final Neuron n = new Neuron(id, new double[] { 0 });
+
+ Assert.assertEquals(id, n.getIdentifier());
+ }
+
+ @Test
+ public void testGetSize() {
+ final double[] features = { -1, -1e-97, 0, 23.456, 9.01e203 } ;
+ final Neuron n = new Neuron(1, features);
+ Assert.assertEquals(features.length, n.getSize());
+ }
+
+ @Test
+ public void testGetFeatures() {
+ final double[] features = { -1, -1e-97, 0, 23.456, 9.01e203 } ;
+ final Neuron n = new Neuron(1, features);
+
+ final double[] f = n.getFeatures();
+ // Accessor returns a copy.
+ Assert.assertFalse(f == features);
+
+ // Values are the same.
+ Assert.assertEquals(features.length, f.length);
+ for (int i = 0; i < features.length; i++) {
+ Assert.assertEquals(features[i], f[i], 0d);
+ }
+ }
+
+ @Test
+ public void testCompareAndSetFeatures() {
+ final Neuron n = new Neuron(1, new double[] { 0 });
+ double[] expect = n.getFeatures();
+ double[] update = new double[] { expect[0] + 1.23 };
+
+ // Test "success".
+ boolean ok = n.compareAndSetFeatures(expect, update);
+ // Check that the update is reported as successful.
+ Assert.assertTrue(ok);
+ // Check that the new value is correct.
+ Assert.assertEquals(update[0], n.getFeatures()[0], 0d);
+
+ // Test "failure".
+ double[] update1 = new double[] { update[0] + 4.56 };
+ // Must return "false" because the neuron has been
+ // updated: a new update can only succeed if "expect"
+ // is set to the new features.
+ ok = n.compareAndSetFeatures(expect, update1);
+ // Check that the update is reported as failed.
+ Assert.assertFalse(ok);
+ // Check that the value was not changed.
+ Assert.assertEquals(update[0], n.getFeatures()[0], 0d);
+ }
+
+ @Test
+ public void testSerialize()
+ throws IOException,
+ ClassNotFoundException {
+ final Neuron out = new Neuron(123, new double[] { -98.76, -1, 0, 1e-23, 543.21, 1e234 });
+ final ByteArrayOutputStream bos = new ByteArrayOutputStream();
+ final ObjectOutputStream oos = new ObjectOutputStream(bos);
+ oos.writeObject(out);
+
+ final ByteArrayInputStream bis = new ByteArrayInputStream(bos.toByteArray());
+ final ObjectInputStream ois = new ObjectInputStream(bis);
+ final Neuron in = (Neuron) ois.readObject();
+
+ // Same identifier.
+ Assert.assertEquals(out.getIdentifier(),
+ in.getIdentifier());
+ // Same values.
+ final double[] outF = out.getFeatures();
+ final double[] inF = in.getFeatures();
+ Assert.assertEquals(outF.length, inF.length);
+ for (int i = 0; i < outF.length; i++) {
+ Assert.assertEquals(outF[i], inF[i], 0d);
+ }
+ }
+}
diff --git a/src/test/java/org/apache/commons/math3/ml/neuralnet/OffsetFeatureInitializer.java b/src/test/java/org/apache/commons/math3/ml/neuralnet/OffsetFeatureInitializer.java
new file mode 100644
index 000000000..15ee1ad51
--- /dev/null
+++ b/src/test/java/org/apache/commons/math3/ml/neuralnet/OffsetFeatureInitializer.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet;
+
+import org.junit.Test;
+import org.junit.Assert;
+import org.apache.commons.math3.random.RandomGenerator;
+import org.apache.commons.math3.random.Well44497b;
+
+/**
+ * Wraps a given initializer.
+ */
+public class OffsetFeatureInitializer
+ implements FeatureInitializer {
+ /** Wrapped initializer. */
+ private final FeatureInitializer orig;
+ /** Offset. */
+ private int inc = 0;
+
+ /**
+ * Creates a new initializer whose {@link #value()} method
+ * will return {@code orig.value() + offset}, where
+ * {@code offset} is automatically incremented by one at
+ * each call.
+ *
+ * @param orig Original initializer.
+ */
+ public OffsetFeatureInitializer(FeatureInitializer orig) {
+ this.orig = orig;
+ }
+
+ /** {@inheritDoc} */
+ public double value() {
+ return orig.value() + inc++;
+ }
+}
diff --git a/src/test/java/org/apache/commons/math3/ml/neuralnet/oned/NeuronStringTest.java b/src/test/java/org/apache/commons/math3/ml/neuralnet/oned/NeuronStringTest.java
new file mode 100644
index 000000000..7f67130aa
--- /dev/null
+++ b/src/test/java/org/apache/commons/math3/ml/neuralnet/oned/NeuronStringTest.java
@@ -0,0 +1,187 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet.oned;
+
+import java.io.ByteArrayOutputStream;
+import java.io.ByteArrayInputStream;
+import java.io.ObjectOutputStream;
+import java.io.ObjectInputStream;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.Collection;
+import org.junit.Test;
+import org.junit.Assert;
+import org.junit.Ignore;
+import org.apache.commons.math3.exception.NumberIsTooSmallException;
+import org.apache.commons.math3.ml.neuralnet.FeatureInitializer;
+import org.apache.commons.math3.ml.neuralnet.FeatureInitializerFactory;
+import org.apache.commons.math3.ml.neuralnet.Network;
+import org.apache.commons.math3.ml.neuralnet.Neuron;
+import org.apache.commons.math3.random.Well44497b;
+
+/**
+ * Tests for {@link NeuronString} and {@link Network} functionality for
+ * a one-dimensional network.
+ */
+public class NeuronStringTest {
+ final FeatureInitializer init = FeatureInitializerFactory.uniform(0, 2);
+
+ /*
+ * Test assumes that the network is
+ *
+ * 0-----1-----2-----3
+ */
+ @Test
+ public void testSegmentNetwork() {
+ final FeatureInitializer[] initArray = { init };
+ final Network net = new NeuronString(4, false, initArray).getNetwork();
+
+ Collection neighbours;
+
+ // Neuron 0.
+ neighbours = net.getNeighbours(net.getNeuron(0));
+ for (long nId : new long[] { 1 }) {
+ Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
+ }
+ // Ensures that no other neuron is in the neighbourhood set.
+ Assert.assertEquals(1, neighbours.size());
+
+ // Neuron 1.
+ neighbours = net.getNeighbours(net.getNeuron(1));
+ for (long nId : new long[] { 0, 2 }) {
+ Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
+ }
+ // Ensures that no other neuron is in the neighbourhood set.
+ Assert.assertEquals(2, neighbours.size());
+
+ // Neuron 2.
+ neighbours = net.getNeighbours(net.getNeuron(2));
+ for (long nId : new long[] { 1, 3 }) {
+ Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
+ }
+ // Ensures that no other neuron is in the neighbourhood set.
+ Assert.assertEquals(2, neighbours.size());
+
+ // Neuron 3.
+ neighbours = net.getNeighbours(net.getNeuron(3));
+ for (long nId : new long[] { 2 }) {
+ Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
+ }
+ // Ensures that no other neuron is in the neighbourhood set.
+ Assert.assertEquals(1, neighbours.size());
+ }
+
+ /*
+ * Test assumes that the network is
+ *
+ * 0-----1-----2-----3 (and 3 is also linked back to 0)
+ */
+ @Test
+ public void testCircleNetwork() {
+ final FeatureInitializer[] initArray = { init };
+ final Network net = new NeuronString(4, true, initArray).getNetwork();
+
+ Collection neighbours;
+
+ // Neuron 0.
+ neighbours = net.getNeighbours(net.getNeuron(0));
+ for (long nId : new long[] { 1, 3 }) {
+ Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
+ }
+ // Ensures that no other neuron is in the neighbourhood set.
+ Assert.assertEquals(2, neighbours.size());
+
+ // Neuron 1.
+ neighbours = net.getNeighbours(net.getNeuron(1));
+ for (long nId : new long[] { 0, 2 }) {
+ Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
+ }
+ // Ensures that no other neuron is in the neighbourhood set.
+ Assert.assertEquals(2, neighbours.size());
+
+ // Neuron 2.
+ neighbours = net.getNeighbours(net.getNeuron(2));
+ for (long nId : new long[] { 1, 3 }) {
+ Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
+ }
+ // Ensures that no other neuron is in the neighbourhood set.
+ Assert.assertEquals(2, neighbours.size());
+
+ // Neuron 3.
+ neighbours = net.getNeighbours(net.getNeuron(3));
+ for (long nId : new long[] { 0, 2 }) {
+ Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
+ }
+ // Ensures that no other neuron is in the neighbourhood set.
+ Assert.assertEquals(2, neighbours.size());
+ }
+
+ /*
+ * Test assumes that the network is
+ *
+ * 0-----1-----2-----3-----4 (and 4 is also linked back to 0)
+ */
+ @Test
+ public void testGetNeighboursWithExclude() {
+ final FeatureInitializer[] initArray = { init };
+ final Network net = new NeuronString(5, true, initArray).getNetwork();
+ final Collection exclude = new ArrayList();
+ exclude.add(net.getNeuron(1));
+ final Collection neighbours = net.getNeighbours(net.getNeuron(0),
+ exclude);
+ Assert.assertTrue(neighbours.contains(net.getNeuron(4)));
+ Assert.assertEquals(1, neighbours.size());
+ }
+
+ @Test
+ public void testSerialize()
+ throws IOException,
+ ClassNotFoundException {
+ final FeatureInitializer[] initArray = { init };
+ final NeuronString out = new NeuronString(4, false, initArray);
+
+ final ByteArrayOutputStream bos = new ByteArrayOutputStream();
+ final ObjectOutputStream oos = new ObjectOutputStream(bos);
+ oos.writeObject(out);
+
+ final ByteArrayInputStream bis = new ByteArrayInputStream(bos.toByteArray());
+ final ObjectInputStream ois = new ObjectInputStream(bis);
+ final NeuronString in = (NeuronString) ois.readObject();
+
+ for (Neuron nOut : out.getNetwork()) {
+ final Neuron nIn = in.getNetwork().getNeuron(nOut.getIdentifier());
+
+ // Same values.
+ final double[] outF = nOut.getFeatures();
+ final double[] inF = nIn.getFeatures();
+ Assert.assertEquals(outF.length, inF.length);
+ for (int i = 0; i < outF.length; i++) {
+ Assert.assertEquals(outF[i], inF[i], 0d);
+ }
+
+ // Same neighbours.
+ final Collection outNeighbours = out.getNetwork().getNeighbours(nOut);
+ final Collection inNeighbours = in.getNetwork().getNeighbours(nIn);
+ Assert.assertEquals(outNeighbours.size(), inNeighbours.size());
+ for (Neuron oN : outNeighbours) {
+ Assert.assertTrue(inNeighbours.contains(in.getNetwork().getNeuron(oN.getIdentifier())));
+ }
+ }
+ }
+}
diff --git a/src/test/java/org/apache/commons/math3/ml/neuralnet/sofm/KohonenTrainingTaskTest.java b/src/test/java/org/apache/commons/math3/ml/neuralnet/sofm/KohonenTrainingTaskTest.java
new file mode 100644
index 000000000..1564a85c7
--- /dev/null
+++ b/src/test/java/org/apache/commons/math3/ml/neuralnet/sofm/KohonenTrainingTaskTest.java
@@ -0,0 +1,207 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet.sofm;
+
+import java.util.Set;
+import java.util.HashSet;
+import java.util.Collection;
+import java.util.List;
+import java.util.ArrayList;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
+import java.util.concurrent.ExecutionException;
+import java.io.PrintWriter;
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import org.junit.Ignore;
+import org.junit.Test;
+import org.junit.Assert;
+import org.junit.runner.RunWith;
+import org.apache.commons.math3.RetryRunner;
+import org.apache.commons.math3.Retry;
+import org.apache.commons.math3.util.FastMath;
+import org.apache.commons.math3.geometry.euclidean.threed.Vector3D;
+
+/**
+ * Tests for {@link KohonenTrainingTask}
+ */
+@RunWith(RetryRunner.class)
+public class KohonenTrainingTaskTest {
+ @Test
+ public void testTravellerSalesmanSquareTourSequentialSolver() {
+ // Cities (in optimal travel order).
+ final City[] squareOfCities = new City[] {
+ new City("o0", 0, 0),
+ new City("o1", 1, 0),
+ new City("o2", 2, 0),
+ new City("o3", 3, 0),
+ new City("o4", 3, 1),
+ new City("o5", 3, 2),
+ new City("o6", 3, 3),
+ new City("o7", 2, 3),
+ new City("o8", 1, 3),
+ new City("o9", 0, 3),
+ new City("i3", 1, 2),
+ new City("i2", 2, 2),
+ new City("i1", 2, 1),
+ new City("i0", 1, 1),
+ };
+
+ final TravellingSalesmanSolver solver = new TravellingSalesmanSolver(squareOfCities, 2);
+ // printSummary("before.travel.seq.dat", solver);
+ solver.createSequentialTask(15000).run();
+ // printSummary("after.travel.seq.dat", solver);
+ final City[] result = solver.getCityList();
+ Assert.assertEquals(squareOfCities.length,
+ uniqueCities(result).size());
+ final double ratio = computeTravelDistance(squareOfCities) / computeTravelDistance(result);
+ Assert.assertEquals(1, ratio, 1e-1); // We do not require the optimal travel.
+ }
+
+ // Test can sometimes fail: Run several times.
+ @Test
+ @Retry
+ public void testTravellerSalesmanSquareTourParallelSolver() throws ExecutionException {
+ // Cities (in optimal travel order).
+ final City[] squareOfCities = new City[] {
+ new City("o0", 0, 0),
+ new City("o1", 1, 0),
+ new City("o2", 2, 0),
+ new City("o3", 3, 0),
+ new City("o4", 3, 1),
+ new City("o5", 3, 2),
+ new City("o6", 3, 3),
+ new City("o7", 2, 3),
+ new City("o8", 1, 3),
+ new City("o9", 0, 3),
+ new City("i3", 1, 2),
+ new City("i2", 2, 2),
+ new City("i1", 2, 1),
+ new City("i0", 1, 1),
+ };
+
+ final TravellingSalesmanSolver solver = new TravellingSalesmanSolver(squareOfCities, 2);
+ // printSummary("before.travel.par.dat", solver);
+
+ // Parallel execution.
+ final ExecutorService service = Executors.newCachedThreadPool();
+ final Runnable[] tasks = solver.createParallelTasks(3, 5000);
+ final List> execOutput = new ArrayList>();
+ // Run tasks.
+ for (Runnable r : tasks) {
+ execOutput.add(service.submit(r));
+ }
+ // Wait for completion (ignoring return value).
+ try {
+ for (Future> f : execOutput) {
+ f.get();
+ }
+ } catch (InterruptedException ignored) {}
+ // Terminate all threads.
+ service.shutdown();
+
+ // printSummary("after.travel.par.dat", solver);
+ final City[] result = solver.getCityList();
+ Assert.assertEquals(squareOfCities.length,
+ uniqueCities(result).size());
+ final double ratio = computeTravelDistance(squareOfCities) / computeTravelDistance(result);
+ Assert.assertEquals(1, ratio, 1e-1); // We do not require the optimal travel.
+ }
+
+ /**
+ * Creates a map of the travel suggested by the solver.
+ *
+ * @param solver Solver.
+ * @return a 4-column table: {@code <x> <y> <closest city's x> <closest city's y>}.
+ */
+ private String travelCoordinatesTable(TravellingSalesmanSolver solver) {
+ final StringBuilder s = new StringBuilder();
+ for (double[] c : solver.getCoordinatesList()) {
+ s.append(c[0]).append(" ").append(c[1]).append(" ");
+ final City city = solver.getClosestCity(c[0], c[1]);
+ final double[] cityCoord = city.getCoordinates();
+ s.append(cityCoord[0]).append(" ").append(cityCoord[1]).append(" ");
+ s.append(" # ").append(city.getName()).append("\n");
+ }
+ return s.toString();
+ }
+
+ /**
+ * Computes the set of unique cities present in the given
+ * travel list (duplicates are counted once).
+ *
+ * @param cityList List of cities visited during the travel.
+ * @return the set of unique cities.
+ */
+ private Collection uniqueCities(City[] cityList) {
+ final Set unique = new HashSet();
+ for (City c : cityList) {
+ unique.add(c);
+ }
+ return unique;
+ }
+
+ /**
+ * Compute the distance covered by the salesman, including
+ * the trip back (from the last to first city).
+ *
+ * @param cityList List of cities visited during the travel.
+ * @return the total distance.
+ */
+ private double computeTravelDistance(City[] cityList) {
+ double dist = 0;
+ for (int i = 0; i < cityList.length; i++) {
+ final double[] currentCoord = cityList[i].getCoordinates();
+ final double[] nextCoord = cityList[(i + 1) % cityList.length].getCoordinates();
+
+ final double xDiff = currentCoord[0] - nextCoord[0];
+ final double yDiff = currentCoord[1] - nextCoord[1];
+
+ dist += FastMath.sqrt(xDiff * xDiff + yDiff * yDiff);
+ }
+
+ return dist;
+ }
+
+ /**
+ * Prints a summary of the current state of the solver to the
+ * given filename.
+ *
+ * @param filename File.
+ * @param solver Solver.
+ */
+ private void printSummary(String filename,
+ TravellingSalesmanSolver solver) {
+ PrintWriter out = null;
+ try {
+ out = new PrintWriter(filename);
+ out.println(travelCoordinatesTable(solver));
+
+ final City[] result = solver.getCityList();
+ out.println("# Number of unique cities: " + uniqueCities(result).size());
+ out.println("# Travel distance: " + computeTravelDistance(result));
+ } catch (Exception e) {
+ // Do nothing.
+ } finally {
+ if (out != null) {
+ out.close();
+ }
+ }
+ }
+}
diff --git a/src/test/java/org/apache/commons/math3/ml/neuralnet/sofm/KohonenUpdateActionTest.java b/src/test/java/org/apache/commons/math3/ml/neuralnet/sofm/KohonenUpdateActionTest.java
new file mode 100644
index 000000000..c3e22b961
--- /dev/null
+++ b/src/test/java/org/apache/commons/math3/ml/neuralnet/sofm/KohonenUpdateActionTest.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet.sofm;
+
+import org.apache.commons.math3.ml.neuralnet.Neuron;
+import org.apache.commons.math3.ml.neuralnet.Network;
+import org.apache.commons.math3.ml.neuralnet.MapUtils;
+import org.apache.commons.math3.ml.neuralnet.UpdateAction;
+import org.apache.commons.math3.ml.neuralnet.OffsetFeatureInitializer;
+import org.apache.commons.math3.ml.neuralnet.FeatureInitializer;
+import org.apache.commons.math3.ml.neuralnet.FeatureInitializerFactory;
+import org.apache.commons.math3.ml.distance.DistanceMeasure;
+import org.apache.commons.math3.ml.distance.EuclideanDistance;
+import org.apache.commons.math3.ml.neuralnet.oned.NeuronString;
+import org.junit.Test;
+import org.junit.Assert;
+
+/**
+ * Tests for {@link KohonenUpdateAction} class.
+ */
+public class KohonenUpdateActionTest {
+ /*
+ * Test assumes that the network is
+ *
+ * 0-----1-----2
+ */
+ @Test
+ public void testUpdate() {
+ final FeatureInitializer init
+ = new OffsetFeatureInitializer(FeatureInitializerFactory.uniform(0, 0.1));
+ final FeatureInitializer[] initArray = { init };
+
+ final int netSize = 3;
+ final Network net = new NeuronString(netSize, false, initArray).getNetwork();
+ final DistanceMeasure dist = new EuclideanDistance();
+ final LearningFactorFunction learning
+ = LearningFactorFunctionFactory.exponentialDecay(1, 0.1, 100);
+ final NeighbourhoodSizeFunction neighbourhood
+ = NeighbourhoodSizeFunctionFactory.exponentialDecay(3, 1, 100);
+ final UpdateAction update = new KohonenUpdateAction(dist, learning, neighbourhood);
+
+ // The following test ensures that, after one "update",
+ // 1. when the initial learning rate equal to 1, the best matching
+ // neuron's features are mapped to the input's features,
+ // 2. when the initial neighbourhood is larger than the network's size,
+ // all neuron's features get closer to the input's features.
+
+ final double[] features = new double[] { 0.3 };
+ final double[] distancesBefore = new double[netSize];
+ int count = 0;
+ for (Neuron n : net) {
+ distancesBefore[count++] = dist.compute(n.getFeatures(), features);
+ }
+ final Neuron bestBefore = MapUtils.findBest(features, net, dist);
+
+ // Initial distance from the best match is larger than zero.
+ Assert.assertTrue(dist.compute(bestBefore.getFeatures(), features) >= 0.2 * 0.2);
+
+ update.update(net, features);
+
+ final double[] distancesAfter = new double[netSize];
+ count = 0;
+ for (Neuron n : net) {
+ distancesAfter[count++] = dist.compute(n.getFeatures(), features);
+ }
+ final Neuron bestAfter = MapUtils.findBest(features, net, dist);
+
+ Assert.assertEquals(bestBefore, bestAfter);
+ // Distance is now zero.
+ Assert.assertEquals(0, dist.compute(bestAfter.getFeatures(), features), 0d);
+
+ for (int i = 0; i < netSize; i++) {
+ // All distances have decreased.
+ Assert.assertTrue(distancesAfter[i] < distancesBefore[i]);
+ }
+ }
+}
diff --git a/src/test/java/org/apache/commons/math3/ml/neuralnet/sofm/LearningFactorFunctionFactoryTest.java b/src/test/java/org/apache/commons/math3/ml/neuralnet/sofm/LearningFactorFunctionFactoryTest.java
new file mode 100644
index 000000000..93df5adee
--- /dev/null
+++ b/src/test/java/org/apache/commons/math3/ml/neuralnet/sofm/LearningFactorFunctionFactoryTest.java
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet.sofm;
+
+import org.apache.commons.math3.exception.NotStrictlyPositiveException;
+import org.apache.commons.math3.exception.OutOfRangeException;
+import org.apache.commons.math3.exception.NumberIsTooLargeException;
+import org.junit.Test;
+import org.junit.Assert;
+
+/**
+ * Tests for {@link LearningFactorFunctionFactory} class.
+ */
+public class LearningFactorFunctionFactoryTest {
+ @Test(expected=OutOfRangeException.class)
+ public void testExponentialDecayPrecondition0() {
+ LearningFactorFunctionFactory.exponentialDecay(0d, 0d, 2);
+ }
+ @Test(expected=OutOfRangeException.class)
+ public void testExponentialDecayPrecondition1() {
+ LearningFactorFunctionFactory.exponentialDecay(1 + 1e-10, 0d, 2);
+ }
+ @Test(expected=NotStrictlyPositiveException.class)
+ public void testExponentialDecayPrecondition2() {
+ LearningFactorFunctionFactory.exponentialDecay(1d, 0d, 2);
+ }
+ @Test(expected=NumberIsTooLargeException.class)
+ public void testExponentialDecayPrecondition3() {
+ LearningFactorFunctionFactory.exponentialDecay(1d, 1d, 100);
+ }
+ @Test(expected=NotStrictlyPositiveException.class)
+ public void testExponentialDecayPrecondition4() {
+ LearningFactorFunctionFactory.exponentialDecay(1d, 0.2, 0);
+ }
+
+ @Test
+ public void testExponentialDecayTrivial() {
+ final int n = 65;
+ final double init = 0.5;
+ final double valueAtN = 0.1;
+ final LearningFactorFunction f
+ = LearningFactorFunctionFactory.exponentialDecay(init, valueAtN, n);
+
+ Assert.assertEquals(init, f.value(0), 0d);
+ Assert.assertEquals(valueAtN, f.value(n), 0d);
+ Assert.assertEquals(0, f.value(Long.MAX_VALUE), 0d);
+ }
+
+ @Test(expected=OutOfRangeException.class)
+ public void testQuasiSigmoidDecayPrecondition0() {
+ LearningFactorFunctionFactory.quasiSigmoidDecay(0d, -1d, 2);
+ }
+ @Test(expected=OutOfRangeException.class)
+ public void testQuasiSigmoidDecayPrecondition1() {
+ LearningFactorFunctionFactory.quasiSigmoidDecay(1 + 1e-10, -1d, 2);
+ }
+ @Test(expected=NumberIsTooLargeException.class)
+ public void testQuasiSigmoidDecayPrecondition3() {
+ LearningFactorFunctionFactory.quasiSigmoidDecay(1d, 0d, 100);
+ }
+ @Test(expected=NotStrictlyPositiveException.class)
+ public void testQuasiSigmoidDecayPrecondition4() {
+ LearningFactorFunctionFactory.quasiSigmoidDecay(1d, -1d, 0);
+ }
+
+ @Test
+ public void testQuasiSigmoidDecayTrivial() {
+ final int n = 65;
+ final double init = 0.5;
+ final double slope = -1e-1;
+ final LearningFactorFunction f
+ = LearningFactorFunctionFactory.quasiSigmoidDecay(init, slope, n);
+
+ Assert.assertEquals(init, f.value(0), 0d);
+ // Very approximate derivative.
+ Assert.assertEquals(slope, f.value(n) - f.value(n - 1), 1e-2);
+ Assert.assertEquals(0, f.value(Long.MAX_VALUE), 0d);
+ }
+}
diff --git a/src/test/java/org/apache/commons/math3/ml/neuralnet/sofm/NeighbourhoodSizeFunctionFactoryTest.java b/src/test/java/org/apache/commons/math3/ml/neuralnet/sofm/NeighbourhoodSizeFunctionFactoryTest.java
new file mode 100644
index 000000000..4570fc8af
--- /dev/null
+++ b/src/test/java/org/apache/commons/math3/ml/neuralnet/sofm/NeighbourhoodSizeFunctionFactoryTest.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet.sofm;
+
+import org.apache.commons.math3.exception.NotStrictlyPositiveException;
+import org.apache.commons.math3.exception.NumberIsTooLargeException;
+import org.junit.Test;
+import org.junit.Assert;
+
+/**
+ * Tests for {@link NeighbourhoodSizeFunctionFactory} class.
+ */
+public class NeighbourhoodSizeFunctionFactoryTest {
+ @Test(expected=NotStrictlyPositiveException.class)
+ public void testExponentialDecayPrecondition1() {
+ NeighbourhoodSizeFunctionFactory.exponentialDecay(0, 0, 2);
+ }
+ @Test(expected=NotStrictlyPositiveException.class)
+ public void testExponentialDecayPrecondition2() {
+ NeighbourhoodSizeFunctionFactory.exponentialDecay(1, 0, 2);
+ }
+ @Test(expected=NumberIsTooLargeException.class)
+ public void testExponentialDecayPrecondition3() {
+ NeighbourhoodSizeFunctionFactory.exponentialDecay(1, 1, 100);
+ }
+ @Test(expected=NotStrictlyPositiveException.class)
+ public void testExponentialDecayPrecondition4() {
+ NeighbourhoodSizeFunctionFactory.exponentialDecay(2, 1, 0);
+ }
+
+ @Test
+ public void testExponentialDecayTrivial() {
+ final int n = 65;
+ final int init = 4;
+ final int valueAtN = 3;
+ final NeighbourhoodSizeFunction f
+ = NeighbourhoodSizeFunctionFactory.exponentialDecay(init, valueAtN, n);
+
+ Assert.assertEquals(init, f.value(0));
+ Assert.assertEquals(valueAtN, f.value(n));
+ Assert.assertEquals(0, f.value(Long.MAX_VALUE));
+ }
+
+ @Test(expected=NotStrictlyPositiveException.class)
+ public void testQuasiSigmoidDecayPrecondition1() {
+ NeighbourhoodSizeFunctionFactory.quasiSigmoidDecay(0d, -1d, 2);
+ }
+ @Test(expected=NumberIsTooLargeException.class)
+ public void testQuasiSigmoidDecayPrecondition3() {
+ NeighbourhoodSizeFunctionFactory.quasiSigmoidDecay(1d, 0d, 100);
+ }
+ @Test(expected=NotStrictlyPositiveException.class)
+ public void testQuasiSigmoidDecayPrecondition4() {
+ NeighbourhoodSizeFunctionFactory.quasiSigmoidDecay(1d, -1d, 0);
+ }
+
+ @Test
+ public void testQuasiSigmoidDecayTrivial() {
+ final int n = 65;
+ final double init = 4;
+ final double slope = -1e-1;
+ final NeighbourhoodSizeFunction f
+ = NeighbourhoodSizeFunctionFactory.quasiSigmoidDecay(init, slope, n);
+
+ Assert.assertEquals(init, f.value(0), 0d);
+ Assert.assertEquals(0, f.value(Long.MAX_VALUE), 0d);
+ }
+}
diff --git a/src/test/java/org/apache/commons/math3/ml/neuralnet/sofm/TravellingSalesmanSolver.java b/src/test/java/org/apache/commons/math3/ml/neuralnet/sofm/TravellingSalesmanSolver.java
new file mode 100644
index 000000000..e851dc841
--- /dev/null
+++ b/src/test/java/org/apache/commons/math3/ml/neuralnet/sofm/TravellingSalesmanSolver.java
@@ -0,0 +1,380 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet.sofm;
+
+import java.util.List;
+import java.util.ArrayList;
+import java.util.Set;
+import java.util.HashSet;
+import java.util.Collection;
+import java.util.Iterator;
+import org.apache.commons.math3.ml.neuralnet.Neuron;
+import org.apache.commons.math3.ml.neuralnet.Network;
+import org.apache.commons.math3.ml.neuralnet.FeatureInitializer;
+import org.apache.commons.math3.ml.neuralnet.FeatureInitializerFactory;
+import org.apache.commons.math3.ml.distance.DistanceMeasure;
+import org.apache.commons.math3.ml.distance.EuclideanDistance;
+import org.apache.commons.math3.ml.neuralnet.oned.NeuronString;
+import org.apache.commons.math3.random.RandomGenerator;
+import org.apache.commons.math3.random.Well44497b;
+import org.apache.commons.math3.exception.MathUnsupportedOperationException;
+import org.apache.commons.math3.util.FastMath;
+import org.apache.commons.math3.analysis.UnivariateFunction;
+import org.apache.commons.math3.analysis.FunctionUtils;
+import org.apache.commons.math3.analysis.function.HarmonicOscillator;
+import org.apache.commons.math3.analysis.function.Constant;
+import org.apache.commons.math3.distribution.RealDistribution;
+import org.apache.commons.math3.distribution.UniformRealDistribution;
+
+/**
+ * Solves the "Travelling Salesman's Problem" (i.e. trying to find the
+ * sequence of cities that minimizes the travel distance) using a 1D
+ * SOFM.
+ */
+public class TravellingSalesmanSolver {
+ private static final long FIRST_NEURON_ID = 0;
+ /** RNG. */
+ private final RandomGenerator random = new Well44497b();
+ /** Set of cities. */
+ private final Set cities = new HashSet();
+ /** SOFM. */
+ private final Network net;
+ /** Distance function. */
+ private final DistanceMeasure distance = new EuclideanDistance();
+ /** Total number of neurons. */
+ private final int numberOfNeurons;
+
+ /**
+ * @param cityList List of cities to visit in a single travel.
+ * @param numNeuronsPerCity Number of neurons per city.
+ */
+ public TravellingSalesmanSolver(City[] cityList,
+ double numNeuronsPerCity) {
+ final double[] xRange = {Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY};
+ final double[] yRange = {Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY};
+
+ // Make sure that each city will appear only once in the list.
+ for (City city : cityList) {
+ cities.add(city);
+ }
+
+ // Total number of neurons.
+ numberOfNeurons = (int) numNeuronsPerCity * cities.size();
+
+ // Create a network with circle topology.
+ net = new NeuronString(numberOfNeurons, true, makeInitializers()).getNetwork();
+ }
+
+ /**
+ * Creates training tasks.
+ *
+ * @param numTasks Number of tasks to create.
+ * @param numSamplesPerTask Number of training samples per task.
+ * @return the created tasks.
+ */
+ public Runnable[] createParallelTasks(int numTasks,
+ long numSamplesPerTask) {
+ final Runnable[] tasks = new Runnable[numTasks];
+ final LearningFactorFunction learning
+ = LearningFactorFunctionFactory.exponentialDecay(2e-1,
+ 5e-2,
+ numSamplesPerTask / 2);
+ final NeighbourhoodSizeFunction neighbourhood
+ = NeighbourhoodSizeFunctionFactory.exponentialDecay(0.5 * numberOfNeurons,
+ 0.1 * numberOfNeurons,
+ numSamplesPerTask / 2);
+
+ for (int i = 0; i < numTasks; i++) {
+ final KohonenUpdateAction action = new KohonenUpdateAction(distance,
+ learning,
+ neighbourhood);
+ tasks[i] = new KohonenTrainingTask(net,
+ createRandomIterator(numSamplesPerTask),
+ action);
+ }
+
+ return tasks;
+ }
+
+ /**
+ * Creates a training task.
+ *
+ * @param numSamples Number of training samples.
+ * @return the created task.
+ */
+ public Runnable createSequentialTask(long numSamples) {
+ return createParallelTasks(1, numSamples)[0];
+ }
+
+ /**
+ * Creates an iterator that will present a series of city's coordinates in
+ * a random order.
+ *
+ * @param numSamples Number of samples.
+ * @return the iterator.
+ */
+ private Iterator createRandomIterator(final long numSamples) {
+ final List cityList = new ArrayList();
+ cityList.addAll(cities);
+
+ return new Iterator() {
+ /** Number of samples. */
+ private long n = 0;
+ /** {@inheritDoc} */
+ public boolean hasNext() {
+ return n < numSamples;
+ }
+ /** {@inheritDoc} */
+ public double[] next() {
+ ++n;
+ return cityList.get(random.nextInt(cityList.size())).getCoordinates();
+ }
+ /** {@inheritDoc} */
+ public void remove() {
+ throw new MathUnsupportedOperationException();
+ }
+ };
+ }
+
+ /**
+ * @return the list of linked neurons (i.e. the one-dimensional
+ * SOFM).
+ */
+ private List getNeuronList() {
+ // Sequence of coordinates.
+ final List list = new ArrayList();
+
+ // First neuron.
+ Neuron current = net.getNeuron(FIRST_NEURON_ID);
+ while (true) {
+ list.add(current);
+ final Collection neighbours
+ = net.getNeighbours(current, list);
+
+ final Iterator iter = neighbours.iterator();
+ if (!iter.hasNext()) {
+ // All neurons have been visited.
+ break;
+ }
+
+ current = iter.next();
+ }
+
+ return list;
+ }
+
+ /**
+ * @return the list of features (coordinates) of linked neurons.
+ */
+ public List getCoordinatesList() {
+ // Sequence of coordinates.
+ final List coordinatesList = new ArrayList();
+
+ for (Neuron n : getNeuronList()) {
+ coordinatesList.add(n.getFeatures());
+ }
+
+ return coordinatesList;
+ }
+
+ /**
+ * Returns the travel proposed by the solver.
+ * Note: cities can be missing or duplicated.
+ *
+ * @return the list of cities in travel order.
+ */
+ public City[] getCityList() {
+ final List coord = getCoordinatesList();
+ final City[] cityList = new City[coord.size()];
+ for (int i = 0; i < cityList.length; i++) {
+ final double[] c = coord.get(i);
+ cityList[i] = getClosestCity(c[0], c[1]);
+ }
+ return cityList;
+ }
+
+ /**
+ * @param x x-coordinate.
+ * @param y y-coordinate.
+ * @return the city whose coordinates are closest to {@code (x, y)}.
+ */
+ public City getClosestCity(double x,
+ double y) {
+ City closest = null;
+ double min = Double.POSITIVE_INFINITY;
+ for (City c : cities) {
+ final double d = c.distance(x, y);
+ if (d < min) {
+ min = d;
+ closest = c;
+ }
+ }
+ return closest;
+ }
+
+ /**
+ * Computes the barycentre of all city locations.
+ *
+ * @param cities City list.
+ * @return the barycentre.
+ */
+ private static double[] barycentre(Set cities) {
+ double xB = 0;
+ double yB = 0;
+
+ int count = 0;
+ for (City c : cities) {
+ final double[] coord = c.getCoordinates();
+ xB += coord[0];
+ yB += coord[1];
+
+ ++count;
+ }
+
+ return new double[] { xB / count, yB / count };
+ }
+
+ /**
+ * Computes the largest distance between the point at coordinates
+ * {@code (x, y)} and any of the cities.
+ *
+     * @param x x-coordinate.
+     * @param y y-coordinate.
+ * @param cities City list.
+ * @return the largest distance.
+ */
+ private static double largestDistance(double x,
+ double y,
+ Set cities) {
+ double maxDist = 0;
+ for (City c : cities) {
+ final double dist = c.distance(x, y);
+ if (dist > maxDist) {
+ maxDist = dist;
+ }
+ }
+
+ return maxDist;
+ }
+
+ /**
+ * Creates the features' initializers: an approximate circle around the
+ * barycentre of the cities.
+ *
+ * @return an array containing the two initializers.
+ */
+ private FeatureInitializer[] makeInitializers() {
+ // Barycentre.
+ final double[] centre = barycentre(cities);
+ // Largest distance from centre.
+ final double radius = 0.5 * largestDistance(centre[0], centre[1], cities);
+
+ final double omega = 2 * Math.PI / numberOfNeurons;
+ final UnivariateFunction h1 = new HarmonicOscillator(radius, omega, 0);
+ final UnivariateFunction h2 = new HarmonicOscillator(radius, omega, 0.5 * Math.PI);
+
+ final UnivariateFunction f1 = FunctionUtils.add(h1, new Constant(centre[0]));
+ final UnivariateFunction f2 = FunctionUtils.add(h2, new Constant(centre[1]));
+
+ final RealDistribution u = new UniformRealDistribution(-0.05 * radius, 0.05 * radius);
+
+ return new FeatureInitializer[] {
+ FeatureInitializerFactory.randomize(u, FeatureInitializerFactory.function(f1, 0, 1)),
+ FeatureInitializerFactory.randomize(u, FeatureInitializerFactory.function(f2, 0, 1))
+ };
+ }
+}
+
+/**
+ * A city, represented by a name and two-dimensional coordinates.
+ */
+class City {
+ /** Identifier. */
+ final String name;
+ /** x-coordinate. */
+ final double x;
+ /** y-coordinate. */
+ final double y;
+
+ /**
+ * @param name Name.
+ * @param x Cartesian x-coordinate.
+ * @param y Cartesian y-coordinate.
+ */
+ public City(String name,
+ double x,
+ double y) {
+ this.name = name;
+ this.x = x;
+ this.y = y;
+ }
+
+ /**
+     * @return the name.
+ */
+ public String getName() {
+ return name;
+ }
+
+ /**
+ * @return the (x, y) coordinates.
+ */
+ public double[] getCoordinates() {
+ return new double[] { x, y };
+ }
+
+ /**
+ * Computes the distance between this city and
+ * the given point.
+ *
+     * @param x x-coordinate.
+     * @param y y-coordinate.
+ * @return the distance between {@code (x, y)} and this
+ * city.
+ */
+ public double distance(double x,
+ double y) {
+ final double xDiff = this.x - x;
+ final double yDiff = this.y - y;
+
+ return FastMath.sqrt(xDiff * xDiff + yDiff * yDiff);
+ }
+
+ /** {@inheritDoc} */
+ public boolean equals(Object o) {
+ if (o instanceof City) {
+ final City other = (City) o;
+ return x == other.x &&
+ y == other.y;
+ }
+ return false;
+ }
+
+ /** {@inheritDoc} */
+ public int hashCode() {
+ int result = 17;
+
+ final long c1 = Double.doubleToLongBits(x);
+ result = 31 * result + (int) (c1 ^ (c1 >>> 32));
+
+ final long c2 = Double.doubleToLongBits(y);
+ result = 31 * result + (int) (c2 ^ (c2 >>> 32));
+
+ return result;
+ }
+}
diff --git a/src/test/java/org/apache/commons/math3/ml/neuralnet/sofm/util/ExponentialDecayFunctionTest.java b/src/test/java/org/apache/commons/math3/ml/neuralnet/sofm/util/ExponentialDecayFunctionTest.java
new file mode 100644
index 000000000..ddbdcfcdb
--- /dev/null
+++ b/src/test/java/org/apache/commons/math3/ml/neuralnet/sofm/util/ExponentialDecayFunctionTest.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet.sofm.util;
+
+import org.apache.commons.math3.exception.NotStrictlyPositiveException;
+import org.apache.commons.math3.exception.NumberIsTooLargeException;
+import org.junit.Test;
+import org.junit.Assert;
+
+/**
+ * Tests for {@link ExponentialDecayFunction} class
+ */
+public class ExponentialDecayFunctionTest {
+ @Test(expected=NotStrictlyPositiveException.class)
+ public void testPrecondition1() {
+ new ExponentialDecayFunction(0d, 0d, 2);
+ }
+ @Test(expected=NotStrictlyPositiveException.class)
+ public void testPrecondition2() {
+ new ExponentialDecayFunction(1d, 0d, 2);
+ }
+ @Test(expected=NumberIsTooLargeException.class)
+ public void testPrecondition3() {
+ new ExponentialDecayFunction(1d, 1d, 100);
+ }
+ @Test(expected=NotStrictlyPositiveException.class)
+ public void testPrecondition4() {
+ new ExponentialDecayFunction(1d, 0.2, 0);
+ }
+
+ @Test
+ public void testTrivial() {
+ final int n = 65;
+ final double init = 4;
+ final double valueAtN = 3;
+ final ExponentialDecayFunction f = new ExponentialDecayFunction(init, valueAtN, n);
+
+ Assert.assertEquals(init, f.value(0), 0d);
+ Assert.assertEquals(valueAtN, f.value(n), 0d);
+ Assert.assertEquals(0, f.value(Long.MAX_VALUE), 0d);
+ }
+}
diff --git a/src/test/java/org/apache/commons/math3/ml/neuralnet/sofm/util/QuasiSigmoidDecayFunctionTest.java b/src/test/java/org/apache/commons/math3/ml/neuralnet/sofm/util/QuasiSigmoidDecayFunctionTest.java
new file mode 100644
index 000000000..49c9cda6a
--- /dev/null
+++ b/src/test/java/org/apache/commons/math3/ml/neuralnet/sofm/util/QuasiSigmoidDecayFunctionTest.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet.sofm.util;
+
+import org.apache.commons.math3.exception.NotStrictlyPositiveException;
+import org.apache.commons.math3.exception.NumberIsTooLargeException;
+import org.junit.Test;
+import org.junit.Assert;
+
+/**
+ * Tests for {@link QuasiSigmoidDecayFunction} class
+ */
+public class QuasiSigmoidDecayFunctionTest {
+ @Test(expected=NotStrictlyPositiveException.class)
+ public void testPrecondition1() {
+ new QuasiSigmoidDecayFunction(0d, -1d, 2);
+ }
+ @Test(expected=NumberIsTooLargeException.class)
+ public void testPrecondition3() {
+ new QuasiSigmoidDecayFunction(1d, 0d, 100);
+ }
+ @Test(expected=NotStrictlyPositiveException.class)
+ public void testPrecondition4() {
+ new QuasiSigmoidDecayFunction(1d, -1d, 0);
+ }
+
+ @Test
+ public void testTrivial() {
+ final int n = 65;
+ final double init = 4;
+ final double slope = -1e-1;
+ final QuasiSigmoidDecayFunction f = new QuasiSigmoidDecayFunction(init, slope, n);
+
+ Assert.assertEquals(init, f.value(0), 0d);
+ // Very approximate derivative.
+ Assert.assertEquals(slope, f.value(n + 1) - f.value(n), 1e-4);
+ Assert.assertEquals(0, f.value(Long.MAX_VALUE), 0d);
+ }
+}
diff --git a/src/test/java/org/apache/commons/math3/ml/neuralnet/twod/NeuronSquareMesh2DTest.java b/src/test/java/org/apache/commons/math3/ml/neuralnet/twod/NeuronSquareMesh2DTest.java
new file mode 100644
index 000000000..1067dcd7c
--- /dev/null
+++ b/src/test/java/org/apache/commons/math3/ml/neuralnet/twod/NeuronSquareMesh2DTest.java
@@ -0,0 +1,685 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.ml.neuralnet.twod;
+
+import java.io.ByteArrayOutputStream;
+import java.io.ByteArrayInputStream;
+import java.io.ObjectOutputStream;
+import java.io.ObjectInputStream;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.Collection;
+import org.junit.Test;
+import org.junit.Assert;
+import org.junit.Ignore;
+import org.apache.commons.math3.ml.neuralnet.FeatureInitializer;
+import org.apache.commons.math3.ml.neuralnet.FeatureInitializerFactory;
+import org.apache.commons.math3.ml.neuralnet.Network;
+import org.apache.commons.math3.ml.neuralnet.Neuron;
+import org.apache.commons.math3.ml.neuralnet.SquareNeighbourhood;
+import org.apache.commons.math3.exception.NumberIsTooSmallException;
+
+/**
+ * Tests for {@link NeuronSquareMesh2D} and {@link Network} functionality for
+ * a two-dimensional network.
+ */
+public class NeuronSquareMesh2DTest {
+ final FeatureInitializer init = FeatureInitializerFactory.uniform(0, 2);
+
+ @Test(expected=NumberIsTooSmallException.class)
+ public void testMinimalNetworkSize1() {
+ final FeatureInitializer[] initArray = { init };
+
+ new NeuronSquareMesh2D(1, false,
+ 2, false,
+ SquareNeighbourhood.VON_NEUMANN,
+ initArray);
+ }
+
+ @Test(expected=NumberIsTooSmallException.class)
+ public void testMinimalNetworkSize2() {
+ final FeatureInitializer[] initArray = { init };
+
+ new NeuronSquareMesh2D(2, false,
+ 0, false,
+ SquareNeighbourhood.VON_NEUMANN,
+ initArray);
+ }
+
+ @Test
+ public void testGetFeaturesSize() {
+ final FeatureInitializer[] initArray = { init, init, init };
+
+ final Network net = new NeuronSquareMesh2D(2, false,
+ 2, false,
+ SquareNeighbourhood.VON_NEUMANN,
+ initArray).getNetwork();
+ Assert.assertEquals(3, net.getFeaturesSize());
+ }
+
+
+ /*
+ * Test assumes that the network is
+ *
+ * 0-----1
+ * | |
+ * | |
+ * 2-----3
+ */
+ @Test
+ public void test2x2Network() {
+ final FeatureInitializer[] initArray = { init };
+ final Network net = new NeuronSquareMesh2D(2, false,
+ 2, false,
+ SquareNeighbourhood.VON_NEUMANN,
+ initArray).getNetwork();
+ Collection neighbours;
+
+ // Neurons 0 and 3.
+ for (long id : new long[] { 0, 3 }) {
+ neighbours = net.getNeighbours(net.getNeuron(id));
+ for (long nId : new long[] { 1, 2 }) {
+ Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
+ }
+            // Ensures that no other neuron is in the neighbourhood set.
+ Assert.assertEquals(2, neighbours.size());
+ }
+
+ // Neurons 1 and 2.
+ for (long id : new long[] { 1, 2 }) {
+ neighbours = net.getNeighbours(net.getNeuron(id));
+ for (long nId : new long[] { 0, 3 }) {
+ Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
+ }
+            // Ensures that no other neuron is in the neighbourhood set.
+ Assert.assertEquals(2, neighbours.size());
+ }
+ }
+
+ /*
+ * Test assumes that the network is
+ *
+ * 0-----1
+ * | |
+ * | |
+ * 2-----3
+ */
+ @Test
+ public void test2x2Network2() {
+ final FeatureInitializer[] initArray = { init };
+ final Network net = new NeuronSquareMesh2D(2, false,
+ 2, false,
+ SquareNeighbourhood.MOORE,
+ initArray).getNetwork();
+ Collection neighbours;
+
+ // All neurons
+ for (long id : new long[] { 0, 1, 2, 3 }) {
+ neighbours = net.getNeighbours(net.getNeuron(id));
+ for (long nId : new long[] { 0, 1, 2, 3 }) {
+ if (id != nId) {
+ Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
+ }
+ }
+ }
+ }
+
+ /*
+ * Test assumes that the network is
+ *
+ * 0-----1-----2
+ * | | |
+ * | | |
+ * 3-----4-----5
+ */
+ @Test
+ public void test3x2CylinderNetwork() {
+ final FeatureInitializer[] initArray = { init };
+ final Network net = new NeuronSquareMesh2D(2, false,
+ 3, true,
+ SquareNeighbourhood.VON_NEUMANN,
+ initArray).getNetwork();
+ Collection neighbours;
+
+ // Neuron 0.
+ neighbours = net.getNeighbours(net.getNeuron(0));
+ for (long nId : new long[] { 1, 2, 3 }) {
+ Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
+ }
+        // Ensures that no other neuron is in the neighbourhood set.
+ Assert.assertEquals(3, neighbours.size());
+
+ // Neuron 1.
+ neighbours = net.getNeighbours(net.getNeuron(1));
+ for (long nId : new long[] { 0, 2, 4 }) {
+ Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
+ }
+        // Ensures that no other neuron is in the neighbourhood set.
+ Assert.assertEquals(3, neighbours.size());
+
+ // Neuron 2.
+ neighbours = net.getNeighbours(net.getNeuron(2));
+ for (long nId : new long[] { 0, 1, 5 }) {
+ Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
+ }
+        // Ensures that no other neuron is in the neighbourhood set.
+ Assert.assertEquals(3, neighbours.size());
+
+ // Neuron 3.
+ neighbours = net.getNeighbours(net.getNeuron(3));
+ for (long nId : new long[] { 0, 4, 5 }) {
+ Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
+ }
+        // Ensures that no other neuron is in the neighbourhood set.
+ Assert.assertEquals(3, neighbours.size());
+
+ // Neuron 4.
+ neighbours = net.getNeighbours(net.getNeuron(4));
+ for (long nId : new long[] { 1, 3, 5 }) {
+ Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
+ }
+        // Ensures that no other neuron is in the neighbourhood set.
+ Assert.assertEquals(3, neighbours.size());
+
+ // Neuron 5.
+ neighbours = net.getNeighbours(net.getNeuron(5));
+ for (long nId : new long[] { 2, 3, 4 }) {
+ Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
+ }
+        // Ensures that no other neuron is in the neighbourhood set.
+ Assert.assertEquals(3, neighbours.size());
+ }
+
+ /*
+ * Test assumes that the network is
+ *
+ * 0-----1-----2
+ * | | |
+ * | | |
+ * 3-----4-----5
+ */
+ @Test
+ public void test3x2CylinderNetwork2() {
+ final FeatureInitializer[] initArray = { init };
+ final Network net = new NeuronSquareMesh2D(2, false,
+ 3, true,
+ SquareNeighbourhood.MOORE,
+ initArray).getNetwork();
+ Collection neighbours;
+
+ // All neurons.
+ for (long id : new long[] { 0, 1, 2, 3, 4, 5 }) {
+ neighbours = net.getNeighbours(net.getNeuron(id));
+ for (long nId : new long[] { 0, 1, 2, 3, 4, 5 }) {
+ if (id != nId) {
+ Assert.assertTrue("id=" + id + " nId=" + nId,
+ neighbours.contains(net.getNeuron(nId)));
+ }
+ }
+ }
+ }
+
+ /*
+ * Test assumes that the network is
+ *
+ * 0-----1-----2
+ * | | |
+ * | | |
+ * 3-----4-----5
+ * | | |
+ * | | |
+ * 6-----7-----8
+ */
+ @Test
+ public void test3x3TorusNetwork() {
+ final FeatureInitializer[] initArray = { init };
+ final Network net = new NeuronSquareMesh2D(3, true,
+ 3, true,
+ SquareNeighbourhood.VON_NEUMANN,
+ initArray).getNetwork();
+ Collection neighbours;
+
+ // Neuron 0.
+ neighbours = net.getNeighbours(net.getNeuron(0));
+ for (long nId : new long[] { 1, 2, 3, 6 }) {
+ Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
+ }
+        // Ensures that no other neuron is in the neighbourhood set.
+ Assert.assertEquals(4, neighbours.size());
+
+ // Neuron 1.
+ neighbours = net.getNeighbours(net.getNeuron(1));
+ for (long nId : new long[] { 0, 2, 4, 7 }) {
+ Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
+ }
+        // Ensures that no other neuron is in the neighbourhood set.
+ Assert.assertEquals(4, neighbours.size());
+
+ // Neuron 2.
+ neighbours = net.getNeighbours(net.getNeuron(2));
+ for (long nId : new long[] { 0, 1, 5, 8 }) {
+ Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
+ }
+        // Ensures that no other neuron is in the neighbourhood set.
+ Assert.assertEquals(4, neighbours.size());
+
+ // Neuron 3.
+ neighbours = net.getNeighbours(net.getNeuron(3));
+ for (long nId : new long[] { 0, 4, 5, 6 }) {
+ Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
+ }
+        // Ensures that no other neuron is in the neighbourhood set.
+ Assert.assertEquals(4, neighbours.size());
+
+ // Neuron 4.
+ neighbours = net.getNeighbours(net.getNeuron(4));
+ for (long nId : new long[] { 1, 3, 5, 7 }) {
+ Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
+ }
+        // Ensures that no other neuron is in the neighbourhood set.
+ Assert.assertEquals(4, neighbours.size());
+
+ // Neuron 5.
+ neighbours = net.getNeighbours(net.getNeuron(5));
+ for (long nId : new long[] { 2, 3, 4, 8 }) {
+ Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
+ }
+        // Ensures that no other neuron is in the neighbourhood set.
+ Assert.assertEquals(4, neighbours.size());
+
+ // Neuron 6.
+ neighbours = net.getNeighbours(net.getNeuron(6));
+ for (long nId : new long[] { 0, 3, 7, 8 }) {
+ Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
+ }
+        // Ensures that no other neuron is in the neighbourhood set.
+ Assert.assertEquals(4, neighbours.size());
+
+ // Neuron 7.
+ neighbours = net.getNeighbours(net.getNeuron(7));
+ for (long nId : new long[] { 1, 4, 6, 8 }) {
+ Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
+ }
+        // Ensures that no other neuron is in the neighbourhood set.
+ Assert.assertEquals(4, neighbours.size());
+
+ // Neuron 8.
+ neighbours = net.getNeighbours(net.getNeuron(8));
+ for (long nId : new long[] { 2, 5, 6, 7 }) {
+ Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
+ }
+        // Ensures that no other neuron is in the neighbourhood set.
+ Assert.assertEquals(4, neighbours.size());
+ }
+
+ /*
+ * Test assumes that the network is
+ *
+ * 0-----1-----2
+ * | | |
+ * | | |
+ * 3-----4-----5
+ * | | |
+ * | | |
+ * 6-----7-----8
+ */
+ @Test
+ public void test3x3TorusNetwork2() {
+ final FeatureInitializer[] initArray = { init };
+ final Network net = new NeuronSquareMesh2D(3, true,
+ 3, true,
+ SquareNeighbourhood.MOORE,
+ initArray).getNetwork();
+ Collection neighbours;
+
+ // All neurons.
+ for (long id : new long[] { 0, 1, 2, 3, 4, 5, 6, 7, 8 }) {
+ neighbours = net.getNeighbours(net.getNeuron(id));
+ for (long nId : new long[] { 0, 1, 2, 3, 4, 5, 6, 7, 8 }) {
+ if (id != nId) {
+ Assert.assertTrue("id=" + id + " nId=" + nId,
+ neighbours.contains(net.getNeuron(nId)));
+ }
+ }
+ }
+ }
+
+ /*
+ * Test assumes that the network is
+ *
+ * 0-----1-----2
+ * | | |
+ * | | |
+ * 3-----4-----5
+ * | | |
+ * | | |
+ * 6-----7-----8
+ */
+ @Test
+ public void test3x3CylinderNetwork() {
+ final FeatureInitializer[] initArray = { init };
+ final Network net = new NeuronSquareMesh2D(3, false,
+ 3, true,
+ SquareNeighbourhood.MOORE,
+ initArray).getNetwork();
+ Collection neighbours;
+
+ // Neuron 0.
+ neighbours = net.getNeighbours(net.getNeuron(0));
+ for (long nId : new long[] { 1, 2, 3, 4, 5}) {
+ Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
+ }
+        // Ensures that no other neuron is in the neighbourhood set.
+ Assert.assertEquals(5, neighbours.size());
+
+ // Neuron 1.
+ neighbours = net.getNeighbours(net.getNeuron(1));
+ for (long nId : new long[] { 0, 2, 3, 4, 5 }) {
+ Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
+ }
+        // Ensures that no other neuron is in the neighbourhood set.
+ Assert.assertEquals(5, neighbours.size());
+
+ // Neuron 2.
+ neighbours = net.getNeighbours(net.getNeuron(2));
+ for (long nId : new long[] { 0, 1, 3, 4, 5 }) {
+ Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
+ }
+        // Ensures that no other neuron is in the neighbourhood set.
+ Assert.assertEquals(5, neighbours.size());
+
+ // Neuron 3.
+ neighbours = net.getNeighbours(net.getNeuron(3));
+ for (long nId : new long[] { 0, 1, 2, 4, 5, 6, 7, 8 }) {
+ Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
+ }
+        // Ensures that no other neuron is in the neighbourhood set.
+ Assert.assertEquals(8, neighbours.size());
+
+ // Neuron 4.
+ neighbours = net.getNeighbours(net.getNeuron(4));
+ for (long nId : new long[] { 0, 1, 2, 3, 5, 6, 7, 8 }) {
+ Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
+ }
+        // Ensures that no other neuron is in the neighbourhood set.
+ Assert.assertEquals(8, neighbours.size());
+
+ // Neuron 5.
+ neighbours = net.getNeighbours(net.getNeuron(5));
+ for (long nId : new long[] { 0, 1, 2, 3, 4, 6, 7, 8 }) {
+ Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
+ }
+        // Ensures that no other neuron is in the neighbourhood set.
+ Assert.assertEquals(8, neighbours.size());
+
+ // Neuron 6.
+ neighbours = net.getNeighbours(net.getNeuron(6));
+ for (long nId : new long[] { 3, 4, 5, 7, 8 }) {
+ Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
+ }
+        // Ensures that no other neuron is in the neighbourhood set.
+ Assert.assertEquals(5, neighbours.size());
+
+ // Neuron 7.
+ neighbours = net.getNeighbours(net.getNeuron(7));
+ for (long nId : new long[] { 3, 4, 5, 6, 8 }) {
+ Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
+ }
+        // Ensures that no other neuron is in the neighbourhood set.
+ Assert.assertEquals(5, neighbours.size());
+
+ // Neuron 8.
+ neighbours = net.getNeighbours(net.getNeuron(8));
+ for (long nId : new long[] { 3, 4, 5, 6, 7 }) {
+ Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
+ }
+        // Ensures that no other neuron is in the neighbourhood set.
+ Assert.assertEquals(5, neighbours.size());
+ }
+
+ /*
+ * Test assumes that the network is
+ *
+ * 0-----1-----2
+ * | | |
+ * | | |
+ * 3-----4-----5
+ * | | |
+ * | | |
+ * 6-----7-----8
+ */
+ @Test
+ public void test3x3CylinderNetwork2() {
+ final FeatureInitializer[] initArray = { init };
+ final Network net = new NeuronSquareMesh2D(3, false,
+ 3, false,
+ SquareNeighbourhood.MOORE,
+ initArray).getNetwork();
+ Collection neighbours;
+
+ // Neuron 0.
+ neighbours = net.getNeighbours(net.getNeuron(0));
+ for (long nId : new long[] { 1, 3, 4}) {
+ Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
+ }
+        // Ensures that no other neuron is in the neighbourhood set.
+ Assert.assertEquals(3, neighbours.size());
+
+ // Neuron 1.
+ neighbours = net.getNeighbours(net.getNeuron(1));
+ for (long nId : new long[] { 0, 2, 3, 4, 5 }) {
+ Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
+ }
+        // Ensures that no other neuron is in the neighbourhood set.
+ Assert.assertEquals(5, neighbours.size());
+
+ // Neuron 2.
+ neighbours = net.getNeighbours(net.getNeuron(2));
+ for (long nId : new long[] { 1, 4, 5 }) {
+ Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
+ }
+        // Ensures that no other neuron is in the neighbourhood set.
+ Assert.assertEquals(3, neighbours.size());
+
+ // Neuron 3.
+ neighbours = net.getNeighbours(net.getNeuron(3));
+ for (long nId : new long[] { 0, 1, 4, 6, 7 }) {
+ Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
+ }
+        // Ensures that no other neuron is in the neighbourhood set.
+ Assert.assertEquals(5, neighbours.size());
+
+ // Neuron 4.
+ neighbours = net.getNeighbours(net.getNeuron(4));
+ for (long nId : new long[] { 0, 1, 2, 3, 5, 6, 7, 8 }) {
+ Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
+ }
+        // Ensures that no other neuron is in the neighbourhood set.
+ Assert.assertEquals(8, neighbours.size());
+
+ // Neuron 5.
+ neighbours = net.getNeighbours(net.getNeuron(5));
+ for (long nId : new long[] { 1, 2, 4, 7, 8 }) {
+ Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
+ }
+        // Ensures that no other neuron is in the neighbourhood set.
+ Assert.assertEquals(5, neighbours.size());
+
+ // Neuron 6.
+ neighbours = net.getNeighbours(net.getNeuron(6));
+ for (long nId : new long[] { 3, 4, 7 }) {
+ Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
+ }
+        // Ensures that no other neuron is in the neighbourhood set.
+ Assert.assertEquals(3, neighbours.size());
+
+ // Neuron 7.
+ neighbours = net.getNeighbours(net.getNeuron(7));
+ for (long nId : new long[] { 3, 4, 5, 6, 8 }) {
+ Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
+ }
+        // Ensures that no other neuron is in the neighbourhood set.
+ Assert.assertEquals(5, neighbours.size());
+
+ // Neuron 8.
+ neighbours = net.getNeighbours(net.getNeuron(8));
+ for (long nId : new long[] { 4, 5, 7 }) {
+ Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
+ }
+        // Ensures that no other neuron is in the neighbourhood set.
+ Assert.assertEquals(3, neighbours.size());
+ }
+
+ /*
+ * Test assumes that the network is
+ *
+ * 0-----1-----2-----3-----4
+ * | | | | |
+ * | | | | |
+ * 5-----6-----7-----8-----9
+ * | | | | |
+ * | | | | |
+ * 10----11----12----13---14
+ * | | | | |
+ * | | | | |
+ * 15----16----17----18---19
+ * | | | | |
+ * | | | | |
+ * 20----21----22----23---24
+ */
+ @Test
+ public void testConcentricNeighbourhood() {
+ final FeatureInitializer[] initArray = { init };
+ final Network net = new NeuronSquareMesh2D(5, true,
+ 5, true,
+ SquareNeighbourhood.VON_NEUMANN,
+ initArray).getNetwork();
+
+ Collection neighbours;
+ Collection exclude = new HashSet();
+
+ // Level-1 neighbourhood.
+ neighbours = net.getNeighbours(net.getNeuron(12));
+ for (long nId : new long[] { 7, 11, 13, 17 }) {
+ Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
+ }
+        // Ensures that no other neuron is in the neighbourhood set.
+ Assert.assertEquals(4, neighbours.size());
+
+ // 1. Add the neuron to the "exclude" list.
+ exclude.add(net.getNeuron(12));
+ // 2. Add all neurons from level-1 neighbourhood.
+ exclude.addAll(neighbours);
+ // 3. Retrieve level-2 neighbourhood.
+ neighbours = net.getNeighbours(neighbours, exclude);
+ for (long nId : new long[] { 6, 8, 16, 18, 2, 10, 14, 22 }) {
+ Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
+ }
+        // Ensures that no other neuron is in the neighbourhood set.
+ Assert.assertEquals(8, neighbours.size());
+ }
+
+ /*
+ * Test assumes that the network is
+ *
+ * 0-----1-----2-----3-----4
+ * | | | | |
+ * | | | | |
+ * 5-----6-----7-----8-----9
+ * | | | | |
+ * | | | | |
+ * 10----11----12----13---14
+ * | | | | |
+ * | | | | |
+ * 15----16----17----18---19
+ * | | | | |
+ * | | | | |
+ * 20----21----22----23---24
+ */
+ @Test
+ public void testConcentricNeighbourhood2() {
+ final FeatureInitializer[] initArray = { init };
+ final Network net = new NeuronSquareMesh2D(5, true,
+ 5, true,
+ SquareNeighbourhood.MOORE,
+ initArray).getNetwork();
+
+ Collection neighbours;
+ Collection exclude = new HashSet();
+
+ // Level-1 neighbourhood.
+ neighbours = net.getNeighbours(net.getNeuron(8));
+ for (long nId : new long[] { 2, 3, 4, 7, 9, 12, 13, 14 }) {
+ Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
+ }
+        // Ensures that no other neuron is in the neighbourhood set.
+ Assert.assertEquals(8, neighbours.size());
+
+ // 1. Add the neuron to the "exclude" list.
+ exclude.add(net.getNeuron(8));
+ // 2. Add all neurons from level-1 neighbourhood.
+ exclude.addAll(neighbours);
+ // 3. Retrieve level-2 neighbourhood.
+ neighbours = net.getNeighbours(neighbours, exclude);
+ for (long nId : new long[] { 1, 6, 11, 16, 17, 18, 19, 15, 10, 5, 0, 20, 24, 23, 22, 21 }) {
+ Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
+ }
+        // Ensures that no other neuron is in the neighbourhood set.
+ Assert.assertEquals(16, neighbours.size());
+ }
+
+ @Test
+ public void testSerialize()
+ throws IOException,
+ ClassNotFoundException {
+ final FeatureInitializer[] initArray = { init };
+ final NeuronSquareMesh2D out = new NeuronSquareMesh2D(4, false,
+ 3, true,
+ SquareNeighbourhood.VON_NEUMANN,
+ initArray);
+
+ final ByteArrayOutputStream bos = new ByteArrayOutputStream();
+ final ObjectOutputStream oos = new ObjectOutputStream(bos);
+ oos.writeObject(out);
+
+ final ByteArrayInputStream bis = new ByteArrayInputStream(bos.toByteArray());
+ final ObjectInputStream ois = new ObjectInputStream(bis);
+ final NeuronSquareMesh2D in = (NeuronSquareMesh2D) ois.readObject();
+
+ for (Neuron nOut : out.getNetwork()) {
+ final Neuron nIn = in.getNetwork().getNeuron(nOut.getIdentifier());
+
+ // Same values.
+ final double[] outF = nOut.getFeatures();
+ final double[] inF = nIn.getFeatures();
+ Assert.assertEquals(outF.length, inF.length);
+ for (int i = 0; i < outF.length; i++) {
+ Assert.assertEquals(outF[i], inF[i], 0d);
+ }
+
+ // Same neighbours.
+ final Collection outNeighbours = out.getNetwork().getNeighbours(nOut);
+ final Collection inNeighbours = in.getNetwork().getNeighbours(nIn);
+ Assert.assertEquals(outNeighbours.size(), inNeighbours.size());
+ for (Neuron oN : outNeighbours) {
+ Assert.assertTrue(inNeighbours.contains(in.getNetwork().getNeuron(oN.getIdentifier())));
+ }
+ }
+ }
+}
diff --git a/src/userguide/java/org/apache/commons/math3/userguide/sofm/ChineseRings.java b/src/userguide/java/org/apache/commons/math3/userguide/sofm/ChineseRings.java
new file mode 100644
index 000000000..f99b26593
--- /dev/null
+++ b/src/userguide/java/org/apache/commons/math3/userguide/sofm/ChineseRings.java
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.userguide.sofm;
+
+import java.util.Iterator;
+import org.apache.commons.math3.geometry.euclidean.threed.Vector3D;
+import org.apache.commons.math3.geometry.euclidean.threed.Rotation;
+import org.apache.commons.math3.random.UnitSphereRandomVectorGenerator;
+import org.apache.commons.math3.distribution.RealDistribution;
+import org.apache.commons.math3.distribution.UniformRealDistribution;
+
+/**
+ * Class that creates two intertwined rings.
+ * Each ring is composed of a cloud of points.
+ */
+public class ChineseRings {
+ /** Points in the two rings. */
+ private final Vector3D[] points;
+
+ /**
+     * @param orientationRing1 Vector orthogonal to the plane containing the
+ * first ring.
+ * @param radiusRing1 Radius of the first ring.
+ * @param halfWidthRing1 Half-width of the first ring.
+ * @param radiusRing2 Radius of the second ring.
+ * @param halfWidthRing2 Half-width of the second ring.
+ * @param numPointsRing1 Number of points in the first ring.
+ * @param numPointsRing2 Number of points in the second ring.
+ */
+ public ChineseRings(Vector3D orientationRing1,
+ double radiusRing1,
+ double halfWidthRing1,
+ double radiusRing2,
+ double halfWidthRing2,
+ int numPointsRing1,
+ int numPointsRing2) {
+ // First ring (centered at the origin).
+ final Vector3D[] firstRing = new Vector3D[numPointsRing1];
+ // Second ring (centered around the first ring).
+ final Vector3D[] secondRing = new Vector3D[numPointsRing2];
+
+ // Create two rings lying in xy-plane.
+ final UnitSphereRandomVectorGenerator unit
+ = new UnitSphereRandomVectorGenerator(2);
+
+ final RealDistribution radius1
+ = new UniformRealDistribution(radiusRing1 - halfWidthRing1,
+ radiusRing1 + halfWidthRing1);
+ final RealDistribution widthRing1
+ = new UniformRealDistribution(-halfWidthRing1, halfWidthRing1);
+
+ for (int i = 0; i < numPointsRing1; i++) {
+ final double[] v = unit.nextVector();
+ final double r = radius1.sample();
+ // First ring is in the xy-plane, centered at (0, 0, 0).
+ firstRing[i] = new Vector3D(v[0] * r,
+ v[1] * r,
+ widthRing1.sample());
+ }
+
+ final RealDistribution radius2
+ = new UniformRealDistribution(radiusRing2 - halfWidthRing2,
+ radiusRing2 + halfWidthRing2);
+ final RealDistribution widthRing2
+ = new UniformRealDistribution(-halfWidthRing2, halfWidthRing2);
+
+ for (int i = 0; i < numPointsRing2; i++) {
+ final double[] v = unit.nextVector();
+ final double r = radius2.sample();
+ // Second ring is in the xz-plane, centered at (radiusRing1, 0, 0).
+ secondRing[i] = new Vector3D(radiusRing1 + v[0] * r,
+ widthRing2.sample(),
+ v[1] * r);
+ }
+
+ // Move first and second rings into position.
+ final Rotation rot = new Rotation(Vector3D.PLUS_K,
+ orientationRing1.normalize());
+ int count = 0;
+ points = new Vector3D[numPointsRing1 + numPointsRing2];
+ for (int i = 0; i < numPointsRing1; i++) {
+ points[count++] = rot.applyTo(firstRing[i]);
+ }
+ for (int i = 0; i < numPointsRing2; i++) {
+ points[count++] = rot.applyTo(secondRing[i]);
+ }
+ }
+
+ /**
+ * Gets all the points.
+ */
+ public Vector3D[] getPoints() {
+ return points.clone();
+ }
+}
diff --git a/src/userguide/java/org/apache/commons/math3/userguide/sofm/ChineseRingsClassifier.java b/src/userguide/java/org/apache/commons/math3/userguide/sofm/ChineseRingsClassifier.java
new file mode 100644
index 000000000..d27d0d9c4
--- /dev/null
+++ b/src/userguide/java/org/apache/commons/math3/userguide/sofm/ChineseRingsClassifier.java
@@ -0,0 +1,335 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.userguide.sofm;
+
+import java.util.Iterator;
+import java.io.PrintWriter;
+import java.io.IOException;
+import org.apache.commons.math3.ml.neuralnet.SquareNeighbourhood;
+import org.apache.commons.math3.ml.neuralnet.FeatureInitializer;
+import org.apache.commons.math3.ml.neuralnet.FeatureInitializerFactory;
+import org.apache.commons.math3.ml.neuralnet.MapUtils;
+import org.apache.commons.math3.ml.neuralnet.twod.NeuronSquareMesh2D;
+import org.apache.commons.math3.ml.neuralnet.sofm.LearningFactorFunction;
+import org.apache.commons.math3.ml.neuralnet.sofm.LearningFactorFunctionFactory;
+import org.apache.commons.math3.ml.neuralnet.sofm.NeighbourhoodSizeFunction;
+import org.apache.commons.math3.ml.neuralnet.sofm.NeighbourhoodSizeFunctionFactory;
+import org.apache.commons.math3.ml.neuralnet.sofm.KohonenUpdateAction;
+import org.apache.commons.math3.ml.neuralnet.sofm.KohonenTrainingTask;
+import org.apache.commons.math3.ml.distance.DistanceMeasure;
+import org.apache.commons.math3.ml.distance.EuclideanDistance;
+import org.apache.commons.math3.random.RandomGenerator;
+import org.apache.commons.math3.random.Well19937c;
+import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
+import org.apache.commons.math3.geometry.euclidean.threed.Vector3D;
+import org.apache.commons.math3.util.FastMath;
+import org.apache.commons.math3.exception.MathUnsupportedOperationException;
+
+/**
+ * SOFM for categorizing points that belong to each of two intertwined rings.
+ *
+ * The output currently consists in 3 text files:
+ *
+ * - "before.chinese.U.seq.dat": U-matrix of the SOFM before training
+ * - "after.chinese.U.seq.dat": U-matrix of the SOFM after training
+ * - "after.chinese.hit.seq.dat": Hit histogram after training
+ *
+ */
+public class ChineseRingsClassifier {
+ /** SOFM. */
+ private final NeuronSquareMesh2D sofm;
+ /** Rings. */
+ private final ChineseRings rings;
+ /** Distance function. */
+ private final DistanceMeasure distance = new EuclideanDistance();
+
+ public static void main(String[] args) {
+ final ChineseRings rings = new ChineseRings(new Vector3D(1, 2, 3),
+ 25, 2,
+ 20, 1,
+ 2000, 1500);
+ final ChineseRingsClassifier classifier = new ChineseRingsClassifier(rings, 15, 15);
+ printU("before.chinese.U.seq.dat", classifier);
+ classifier.createSequentialTask(100000).run();
+ printU("after.chinese.U.seq.dat", classifier);
+ printHit("after.chinese.hit.seq.dat", classifier);
+ }
+
+ /**
+ * @param rings Training data.
+ * @param dim1 Number of rows of the SOFM.
+ * @param dim2 Number of columns of the SOFM.
+ */
+ public ChineseRingsClassifier(ChineseRings rings,
+ int dim1,
+ int dim2) {
+ this.rings = rings;
+ sofm = new NeuronSquareMesh2D(dim1, false,
+ dim2, false,
+ SquareNeighbourhood.MOORE,
+ makeInitializers());
+ }
+
+ /**
+ * Creates training tasks.
+ *
+ * @param numTasks Number of tasks to create.
+ * @param numSamplesPerTask Number of training samples per task.
+ * @return the created tasks.
+ */
+ public Runnable[] createParallelTasks(int numTasks,
+ long numSamplesPerTask) {
+ final Runnable[] tasks = new Runnable[numTasks];
+ final LearningFactorFunction learning
+ = LearningFactorFunctionFactory.exponentialDecay(1e-1,
+ 5e-2,
+ numSamplesPerTask / 2);
+ final double numNeurons = FastMath.sqrt(sofm.getNumberOfRows() * sofm.getNumberOfColumns());
+ final NeighbourhoodSizeFunction neighbourhood
+ = NeighbourhoodSizeFunctionFactory.exponentialDecay(0.5 * numNeurons,
+ 0.2 * numNeurons,
+ numSamplesPerTask / 2);
+
+ for (int i = 0; i < numTasks; i++) {
+ final KohonenUpdateAction action = new KohonenUpdateAction(distance,
+ learning,
+ neighbourhood);
+ tasks[i] = new KohonenTrainingTask(sofm.getNetwork(),
+ createRandomIterator(numSamplesPerTask),
+ action);
+ }
+
+ return tasks;
+ }
+
+ /**
+ * Creates a training task.
+ *
+ * @param numSamples Number of training samples.
+ * @return the created task.
+ */
+ public Runnable createSequentialTask(long numSamples) {
+ return createParallelTasks(1, numSamples)[0];
+ }
+
+ /**
+ * Computes the U-matrix.
+ *
+ * @return the U-matrix of the network.
+ */
+ public double[][] computeU() {
+ return MapUtils.computeU(sofm, distance);
+ }
+
+ /**
+ * Computes the hit histogram.
+ *
+ * @return the histogram.
+ */
+ public int[][] computeHitHistogram() {
+ return MapUtils.computeHitHistogram(createIterable(),
+ sofm,
+ distance);
+ }
+
+ /**
+ * Computes the quantization error.
+ *
+ * @return the quantization error.
+ */
+ public double computeQuantizationError() {
+ return MapUtils.computeQuantizationError(createIterable(),
+ sofm.getNetwork(),
+ distance);
+ }
+
+ /**
+ * Computes the topographic error.
+ *
+ * @return the topographic error.
+ */
+ public double computeTopographicError() {
+ return MapUtils.computeTopographicError(createIterable(),
+ sofm.getNetwork(),
+ distance);
+ }
+
+ /**
+ * Creates the features' initializers.
+ * They are sampled from a uniform distribution around the barycentre of
+ * the rings.
+ *
+ * @return an array containing the initializers for the x, y and
+ * z coordinates of the features array of the neurons.
+ */
+ private FeatureInitializer[] makeInitializers() {
+ final SummaryStatistics[] centre = new SummaryStatistics[] {
+ new SummaryStatistics(),
+ new SummaryStatistics(),
+ new SummaryStatistics()
+ };
+ for (Vector3D p : rings.getPoints()) {
+ centre[0].addValue(p.getX());
+ centre[1].addValue(p.getY());
+ centre[2].addValue(p.getZ());
+ }
+
+ final double[] mean = new double[] {
+ centre[0].getMean(),
+ centre[1].getMean(),
+ centre[2].getMean()
+ };
+ final double s = 0.1;
+ final double[] dev = new double[] {
+ s * centre[0].getStandardDeviation(),
+ s * centre[1].getStandardDeviation(),
+ s * centre[2].getStandardDeviation()
+ };
+
+ return new FeatureInitializer[] {
+ FeatureInitializerFactory.uniform(mean[0] - dev[0], mean[0] + dev[0]),
+ FeatureInitializerFactory.uniform(mean[1] - dev[1], mean[1] + dev[1]),
+ FeatureInitializerFactory.uniform(mean[2] - dev[2], mean[2] + dev[2])
+ };
+ }
+
+ /**
+ * Creates an iterable that will present the points coordinates.
+ *
+ * @return the iterable.
+ */
+ private Iterable createIterable() {
+ return new Iterable() {
+ public Iterator iterator() {
+ return new Iterator() {
+ /** Data. */
+ final Vector3D[] points = rings.getPoints();
+ /** Number of samples. */
+ private int n = 0;
+
+ /** {@inheritDoc} */
+ public boolean hasNext() {
+ return n < points.length;
+ }
+
+ /** {@inheritDoc} */
+ public double[] next() {
+ return points[n++].toArray();
+ }
+
+ /** {@inheritDoc} */
+ public void remove() {
+ throw new MathUnsupportedOperationException();
+ }
+ };
+ }
+ };
+ }
+
+ /**
+ * Creates an iterator that will present a series of points coordinates in
+ * a random order.
+ *
+ * @param numSamples Number of samples.
+ * @return the iterator.
+ */
+ private Iterator createRandomIterator(final long numSamples) {
+ return new Iterator() {
+ /** Data. */
+ final Vector3D[] points = rings.getPoints();
+ /** RNG. */
+ final RandomGenerator rng = new Well19937c();
+ /** Number of samples. */
+ private long n = 0;
+
+ /** {@inheritDoc} */
+ public boolean hasNext() {
+ return n < numSamples;
+ }
+
+ /** {@inheritDoc} */
+ public double[] next() {
+ ++n;
+ return points[rng.nextInt(points.length)].toArray();
+ }
+
+ /** {@inheritDoc} */
+ public void remove() {
+ throw new MathUnsupportedOperationException();
+ }
+ };
+ }
+
+ /**
+ * Prints the U-matrix of the map to the given filename.
+ *
+ * @param filename File.
+ * @param sofm Classifier.
+ */
+ private static void printU(String filename,
+ ChineseRingsClassifier sofm) {
+ PrintWriter out = null;
+ try {
+ out = new PrintWriter(filename);
+
+ final double[][] uMatrix = sofm.computeU();
+ for (int i = 0; i < uMatrix.length; i++) {
+ for (int j = 0; j < uMatrix[0].length; j++) {
+ out.print(uMatrix[i][j] + " ");
+ }
+ out.println();
+ }
+ out.println("# Quantization error: " + sofm.computeQuantizationError());
+ out.println("# Topographic error: " + sofm.computeTopographicError());
+ } catch (IOException e) {
+ // Do nothing.
+ } finally {
+ if (out != null) {
+ out.close();
+ }
+ }
+ }
+
+ /**
+ * Prints the hit histogram of the map to the given filename.
+ *
+ * @param filename File.
+ * @param sofm Classifier.
+ */
+ private static void printHit(String filename,
+ ChineseRingsClassifier sofm) {
+ PrintWriter out = null;
+ try {
+ out = new PrintWriter(filename);
+
+ final int[][] histo = sofm.computeHitHistogram();
+ for (int i = 0; i < histo.length; i++) {
+ for (int j = 0; j < histo[0].length; j++) {
+ out.print(histo[i][j] + " ");
+ }
+ out.println();
+ }
+ } catch (IOException e) {
+ // Do nothing.
+ } finally {
+ if (out != null) {
+ out.close();
+ }
+ }
+ }
+}