MATH-923
Implementation of Kohonen's Self-Organizing Feature Map (SOFM). New package "o.a.c.m.ml.neuralnet" contains base functionality for implementing different map types, i.e. methods that project a high-dimensional space onto one with a low dimension (typically 1D or 2D). The SOFM-specific code is in "o.a.c.m.ml.neuralnet.sofm". git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1557267 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
0a75cbc380
commit
aad194a346
|
@ -434,3 +434,8 @@ ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
|||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
===============================================================================
|
||||
|
||||
The initial commit of package "org.apache.commons.math3.ml.neuralnet" is
|
||||
an adapted version of code developed in the context of the Data Processing
|
||||
and Analysis Consortium (DPAC) of the "Gaia" project of the European Space
|
||||
Agency (ESA).
|
||||
===============================================================================
|
||||
|
|
|
@ -363,4 +363,11 @@
|
|||
<Bug pattern="ICAST_IDIV_CAST_TO_DOUBLE" />
|
||||
</Match>
|
||||
|
||||
<!-- The following switch fall-through is intended. -->
|
||||
<Match>
|
||||
<Class name="org.apache.commons.math3.ml.neuralnet.twod.NeuronSquareMesh2D" />
|
||||
<Method name="createLinks" />
|
||||
<Bug pattern="SF_SWITCH_FALLTHROUGH" />
|
||||
</Match>
|
||||
|
||||
</FindBugsFilter>
|
||||
|
|
|
@ -51,6 +51,10 @@ If the output is not quite correct, check for invisible trailing spaces!
|
|||
</properties>
|
||||
<body>
|
||||
<release version="3.3" date="TBD" description="TBD">
|
||||
<action dev="erans" type="add" issue="MATH-923">
|
||||
Utilities for creating artificial neural networks (package "o.a.c.m.ml.neuralnet").
|
||||
Implementation of Kohonen's Self-Organizing Feature Map (SOFM).
|
||||
</action>
|
||||
<action dev="tn" type="fix" issue="MATH-1082">
|
||||
The cutOff mechanism of the "SimplexSolver" in package o.a.c.math3.optim.linear
|
||||
could lead to invalid solutions. The mechanism has been improved in a way that
|
||||
|
|
|
@ -0,0 +1,32 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.ml.neuralnet;
|
||||
|
||||
/**
|
||||
* Defines how to assign the first value of a neuron's feature.
|
||||
*
|
||||
* @version $Id$
|
||||
*/
|
||||
public interface FeatureInitializer {
|
||||
/**
|
||||
* Selects the initial value.
|
||||
*
|
||||
* @return the initial value.
|
||||
*/
|
||||
double value();
|
||||
}
|
|
@ -0,0 +1,94 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.ml.neuralnet;
|
||||
|
||||
import org.apache.commons.math3.distribution.RealDistribution;
|
||||
import org.apache.commons.math3.distribution.UniformRealDistribution;
|
||||
import org.apache.commons.math3.analysis.UnivariateFunction;
|
||||
import org.apache.commons.math3.analysis.function.Constant;
|
||||
|
||||
/**
|
||||
* Creates functions that will select the initial values of a neuron's
|
||||
* features.
|
||||
*
|
||||
* @version $Id$
|
||||
*/
|
||||
public class FeatureInitializerFactory {
|
||||
/** Class contains only static methods. */
|
||||
private FeatureInitializerFactory() {}
|
||||
|
||||
/**
|
||||
* Uniform sampling of the given range.
|
||||
*
|
||||
* @param min Lower bound of the range.
|
||||
* @param max Upper bound of the range.
|
||||
* @return an initializer such that the features will be initialized with
|
||||
* values within the given range.
|
||||
* @throws org.apache.commons.math3.exception.NumberIsTooLargeException
|
||||
* if {@code min >= max}.
|
||||
*/
|
||||
public static FeatureInitializer uniform(final double min,
|
||||
final double max) {
|
||||
return randomize(new UniformRealDistribution(min, max),
|
||||
function(new Constant(0), 0, 0));
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an initializer from a univariate function {@code f(x)}.
|
||||
* The argument {@code x} is set to {@code init} at the first call
|
||||
* and will be incremented at each call.
|
||||
*
|
||||
* @param f Function.
|
||||
* @param init Initial value.
|
||||
* @param inc Increment
|
||||
* @return the initializer.
|
||||
*/
|
||||
public static FeatureInitializer function(final UnivariateFunction f,
|
||||
final double init,
|
||||
final double inc) {
|
||||
return new FeatureInitializer() {
|
||||
/** Argument. */
|
||||
private double arg = init;
|
||||
|
||||
/** {@inheritDoc} */
|
||||
public double value() {
|
||||
final double result = f.value(arg);
|
||||
arg += inc;
|
||||
return result;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds some amount of random data to the given initializer.
|
||||
*
|
||||
* @param random Random variable distribution.
|
||||
* @param orig Original initializer.
|
||||
* @return an initializer whose {@link FeatureInitializer#value() value}
|
||||
* method will return {@code orig.value() + random.sample()}.
|
||||
*/
|
||||
public static FeatureInitializer randomize(final RealDistribution random,
|
||||
final FeatureInitializer orig) {
|
||||
return new FeatureInitializer() {
|
||||
/** {@inheritDoc} */
|
||||
public double value() {
|
||||
return orig.value() + random.sample();
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
|
@ -0,0 +1,247 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.ml.neuralnet;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Collection;
|
||||
import org.apache.commons.math3.ml.distance.DistanceMeasure;
|
||||
import org.apache.commons.math3.ml.neuralnet.twod.NeuronSquareMesh2D;
|
||||
import org.apache.commons.math3.exception.NoDataException;
|
||||
import org.apache.commons.math3.util.Pair;
|
||||
|
||||
/**
|
||||
* Utilities for network maps.
|
||||
*
|
||||
* @version $Id$
|
||||
*/
|
||||
public class MapUtils {
|
||||
/**
|
||||
* Class contains only static methods.
|
||||
*/
|
||||
private MapUtils() {}
|
||||
|
||||
/**
|
||||
* Finds the neuron that best matches the given features.
|
||||
*
|
||||
* @param features Data.
|
||||
* @param neurons List of neurons to scan. If the list is empty
|
||||
* {@code null} will be returned.
|
||||
* @param distance Distance function. The neuron's features are
|
||||
* passed as the first argument to {@link DistanceMeasure#compute(double[],double[])}.
|
||||
* @return the neuron whose features are closest to the given data.
|
||||
* @throws org.apache.commons.math3.exception.DimensionMismatchException
|
||||
* if the size of the input is not compatible with the neurons features
|
||||
* size.
|
||||
*/
|
||||
public static Neuron findBest(double[] features,
|
||||
Iterable<Neuron> neurons,
|
||||
DistanceMeasure distance) {
|
||||
Neuron best = null;
|
||||
double min = Double.POSITIVE_INFINITY;
|
||||
for (final Neuron n : neurons) {
|
||||
final double d = distance.compute(n.getFeatures(), features);
|
||||
if (d < min) {
|
||||
min = d;
|
||||
best = n;
|
||||
}
|
||||
}
|
||||
|
||||
return best;
|
||||
}
|
||||
|
||||
/**
|
||||
* Finds the two neurons that best match the given features.
|
||||
*
|
||||
* @param features Data.
|
||||
* @param neurons List of neurons to scan. If the list is empty
|
||||
* {@code null} will be returned.
|
||||
* @param distance Distance function. The neuron's features are
|
||||
* passed as the first argument to {@link DistanceMeasure#compute(double[],double[])}.
|
||||
* @return the two neurons whose features are closest to the given data.
|
||||
* @throws org.apache.commons.math3.exception.DimensionMismatchException
|
||||
* if the size of the input is not compatible with the neurons features
|
||||
* size.
|
||||
*/
|
||||
public static Pair<Neuron, Neuron> findBestAndSecondBest(double[] features,
|
||||
Iterable<Neuron> neurons,
|
||||
DistanceMeasure distance) {
|
||||
Neuron[] best = { null, null };
|
||||
double[] min = { Double.POSITIVE_INFINITY,
|
||||
Double.POSITIVE_INFINITY };
|
||||
for (final Neuron n : neurons) {
|
||||
final double d = distance.compute(n.getFeatures(), features);
|
||||
if (d < min[0]) {
|
||||
// Replace second best with old best.
|
||||
min[1] = min[0];
|
||||
best[1] = best[0];
|
||||
|
||||
// Store current as new best.
|
||||
min[0] = d;
|
||||
best[0] = n;
|
||||
} else if (d < min[1]) {
|
||||
// Replace old second best with current.
|
||||
min[1] = d;
|
||||
best[1] = n;
|
||||
}
|
||||
}
|
||||
|
||||
return new Pair<Neuron, Neuron>(best[0], best[1]);
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the <a href="http://en.wikipedia.org/wiki/U-Matrix">
|
||||
* U-matrix</a> of a two-dimensional map.
|
||||
*
|
||||
* @param map Network.
|
||||
* @param distance Function to use for computing the average
|
||||
* distance from a neuron to its neighbours.
|
||||
* @return the matrix of average distances.
|
||||
*/
|
||||
public static double[][] computeU(NeuronSquareMesh2D map,
|
||||
DistanceMeasure distance) {
|
||||
final int numRows = map.getNumberOfRows();
|
||||
final int numCols = map.getNumberOfColumns();
|
||||
final double[][] uMatrix = new double[numRows][numCols];
|
||||
|
||||
final Network net = map.getNetwork();
|
||||
|
||||
for (int i = 0; i < numRows; i++) {
|
||||
for (int j = 0; j < numCols; j++) {
|
||||
final Neuron neuron = map.getNeuron(i, j);
|
||||
final Collection<Neuron> neighbours = net.getNeighbours(neuron);
|
||||
final double[] features = neuron.getFeatures();
|
||||
|
||||
double d = 0;
|
||||
int count = 0;
|
||||
for (Neuron n : neighbours) {
|
||||
++count;
|
||||
d += distance.compute(features, n.getFeatures());
|
||||
}
|
||||
|
||||
uMatrix[i][j] = d / count;
|
||||
}
|
||||
}
|
||||
|
||||
return uMatrix;
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the "hit" histogram of a two-dimensional map.
|
||||
*
|
||||
* @param data Feature vectors.
|
||||
* @param map Network.
|
||||
* @param distance Function to use for determining the best matching unit.
|
||||
* @return the number of hits for each neuron in the map.
|
||||
*/
|
||||
public static int[][] computeHitHistogram(Iterable<double[]> data,
|
||||
NeuronSquareMesh2D map,
|
||||
DistanceMeasure distance) {
|
||||
final HashMap<Neuron, Integer> hit = new HashMap<Neuron, Integer>();
|
||||
final Network net = map.getNetwork();
|
||||
|
||||
for (double[] f : data) {
|
||||
final Neuron best = findBest(f, net, distance);
|
||||
final Integer count = hit.get(best);
|
||||
if (count == null) {
|
||||
hit.put(best, 1);
|
||||
} else {
|
||||
hit.put(best, count + 1);
|
||||
}
|
||||
}
|
||||
|
||||
// Copy the histogram data into a 2D map.
|
||||
final int numRows = map.getNumberOfRows();
|
||||
final int numCols = map.getNumberOfColumns();
|
||||
final int[][] histo = new int[numRows][numCols];
|
||||
|
||||
for (int i = 0; i < numRows; i++) {
|
||||
for (int j = 0; j < numCols; j++) {
|
||||
final Neuron neuron = map.getNeuron(i, j);
|
||||
final Integer count = hit.get(neuron);
|
||||
if (count == null) {
|
||||
histo[i][j] = 0;
|
||||
} else {
|
||||
histo[i][j] = count;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return histo;
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the quantization error.
|
||||
* The quantization error is the average distance between a feature vector
|
||||
* and its "best matching unit" (closest neuron).
|
||||
*
|
||||
* @param data Feature vectors.
|
||||
* @param neurons List of neurons to scan.
|
||||
* @param distance Distance function.
|
||||
* @return the error.
|
||||
* @throws NoDataException if {@code data} is empty.
|
||||
*/
|
||||
public static double computeQuantizationError(Iterable<double[]> data,
|
||||
Iterable<Neuron> neurons,
|
||||
DistanceMeasure distance) {
|
||||
double d = 0;
|
||||
int count = 0;
|
||||
for (double[] f : data) {
|
||||
++count;
|
||||
d += distance.compute(f, findBest(f, neurons, distance).getFeatures());
|
||||
}
|
||||
|
||||
if (count == 0) {
|
||||
throw new NoDataException();
|
||||
}
|
||||
|
||||
return d / count;
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the topographic error.
|
||||
* The topographic error is the proportion of data for which first and
|
||||
* second best matching units are not adjacent in the map.
|
||||
*
|
||||
* @param data Feature vectors.
|
||||
* @param net Network.
|
||||
* @param distance Distance function.
|
||||
* @return the error.
|
||||
* @throws NoDataException if {@code data} is empty.
|
||||
*/
|
||||
public static double computeTopographicError(Iterable<double[]> data,
|
||||
Network net,
|
||||
DistanceMeasure distance) {
|
||||
int notAdjacentCount = 0;
|
||||
int count = 0;
|
||||
for (double[] f : data) {
|
||||
++count;
|
||||
final Pair<Neuron, Neuron> p = findBestAndSecondBest(f, net, distance);
|
||||
if (!net.getNeighbours(p.getFirst()).contains(p.getSecond())) {
|
||||
// Increment count if first and second best matching units
|
||||
// are not neighbours.
|
||||
++notAdjacentCount;
|
||||
}
|
||||
}
|
||||
|
||||
if (count == 0) {
|
||||
throw new NoDataException();
|
||||
}
|
||||
|
||||
return ((double) notAdjacentCount) / count;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,476 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.ml.neuralnet;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.io.ObjectInputStream;
|
||||
import java.util.NoSuchElementException;
|
||||
import java.util.List;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Set;
|
||||
import java.util.HashSet;
|
||||
import java.util.Collection;
|
||||
import java.util.Iterator;
|
||||
import java.util.Comparator;
|
||||
import java.util.Collections;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.atomic.AtomicLong;
|
||||
import org.apache.commons.math3.exception.DimensionMismatchException;
|
||||
import org.apache.commons.math3.exception.MathIllegalStateException;
|
||||
|
||||
/**
|
||||
* Neural network, composed of {@link Neuron} instances and the links
|
||||
* between them.
|
||||
*
|
||||
* Although updating a neuron's state is thread-safe, modifying the
|
||||
* network's topology (adding or removing links) is not.
|
||||
*
|
||||
* @version $Id$
|
||||
*/
|
||||
public class Network
|
||||
implements Iterable<Neuron>,
|
||||
Serializable {
|
||||
/** Serializable. */
|
||||
private static final long serialVersionUID = 20130207L;
|
||||
/** Neurons. */
|
||||
private final ConcurrentHashMap<Long, Neuron> neuronMap
|
||||
= new ConcurrentHashMap<Long, Neuron>();
|
||||
/** Next available neuron identifier. */
|
||||
private final AtomicLong nextId;
|
||||
/** Neuron's features set size. */
|
||||
private final int featureSize;
|
||||
/** Links. */
|
||||
private final ConcurrentHashMap<Long, Set<Long>> linkMap
|
||||
= new ConcurrentHashMap<Long, Set<Long>>();
|
||||
|
||||
/**
|
||||
* Comparator that prescribes an order of the neurons according
|
||||
* to the increasing order of their identifier.
|
||||
*/
|
||||
public static class NeuronIdentifierComparator
|
||||
implements Comparator<Neuron>,
|
||||
Serializable {
|
||||
/** Version identifier. */
|
||||
private static final long serialVersionUID = 20130207L;
|
||||
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
public int compare(Neuron a,
|
||||
Neuron b) {
|
||||
final long aId = a.getIdentifier();
|
||||
final long bId = b.getIdentifier();
|
||||
return aId < bId ? -1 :
|
||||
aId > bId ? 1 : 0;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructor with restricted access, solely used for deserialization.
|
||||
*
|
||||
* @param nextId Next available identifier.
|
||||
* @param featureSize Number of features.
|
||||
* @param neuronList Neurons.
|
||||
* @param neighbourIdList Links associated to each of the neurons in
|
||||
* {@code neuronList}.
|
||||
* @throws MathIllegalStateException if an inconsistency is detected
|
||||
* (which probably means that the serialized form has been corrupted).
|
||||
*/
|
||||
Network(long nextId,
|
||||
int featureSize,
|
||||
Neuron[] neuronList,
|
||||
long[][] neighbourIdList) {
|
||||
final int numNeurons = neuronList.length;
|
||||
if (numNeurons != neighbourIdList.length) {
|
||||
throw new MathIllegalStateException();
|
||||
}
|
||||
|
||||
for (int i = 0; i < numNeurons; i++) {
|
||||
final Neuron n = neuronList[i];
|
||||
final long id = n.getIdentifier();
|
||||
if (id >= nextId) {
|
||||
throw new MathIllegalStateException();
|
||||
}
|
||||
neuronMap.put(id, n);
|
||||
linkMap.put(id, new HashSet<Long>());
|
||||
}
|
||||
|
||||
for (int i = 0; i < numNeurons; i++) {
|
||||
final long aId = neuronList[i].getIdentifier();
|
||||
final Set<Long> aLinks = linkMap.get(aId);
|
||||
for (Long bId : neighbourIdList[i]) {
|
||||
if (neuronMap.get(bId) == null) {
|
||||
throw new MathIllegalStateException();
|
||||
}
|
||||
addLinkToLinkSet(aLinks, bId);
|
||||
}
|
||||
}
|
||||
|
||||
this.nextId = new AtomicLong(nextId);
|
||||
this.featureSize = featureSize;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param initialIdentifier Identifier for the first neuron that
|
||||
* will be added to this network.
|
||||
* @param featureSize Size of the neuron's features.
|
||||
*/
|
||||
public Network(long initialIdentifier,
|
||||
int featureSize) {
|
||||
nextId = new AtomicLong(initialIdentifier);
|
||||
this.featureSize = featureSize;
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*/
|
||||
@Override
|
||||
public Iterator<Neuron> iterator() {
|
||||
return neuronMap.values().iterator();
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a list of the neurons, sorted in a custom order.
|
||||
*
|
||||
* @param comparator {@link Comparator} used for sorting the neurons.
|
||||
* @return a list of neurons, sorted in the order prescribed by the
|
||||
* given {@code comparator}.
|
||||
* @see NeuronIdentifierComparator
|
||||
*/
|
||||
public Collection<Neuron> getNeurons(Comparator<Neuron> comparator) {
|
||||
final List<Neuron> neurons = new ArrayList<Neuron>();
|
||||
neurons.addAll(neuronMap.values());
|
||||
|
||||
Collections.sort(neurons, comparator);
|
||||
|
||||
return neurons;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a neuron and assigns it a unique identifier.
|
||||
*
|
||||
* @param features Initial values for the neuron's features.
|
||||
* @return the neuron's identifier.
|
||||
* @throws DimensionMismatchException if the length of {@code features}
|
||||
* is different from the expected size (as set by the
|
||||
* {@link #Network(long,int) constructor}).
|
||||
*/
|
||||
public long createNeuron(double[] features) {
|
||||
if (features.length != featureSize) {
|
||||
throw new DimensionMismatchException(features.length, featureSize);
|
||||
}
|
||||
|
||||
final long id = createNextId();
|
||||
neuronMap.put(id, new Neuron(id, features));
|
||||
linkMap.put(id, new HashSet<Long>());
|
||||
return id;
|
||||
}
|
||||
|
||||
/**
|
||||
* Deletes a neuron.
|
||||
* Links from all neighbours to the removed neuron will also be
|
||||
* {@link #deleteLink(Neuron,Neuron) deleted}.
|
||||
*
|
||||
* @param neuron Neuron to be removed from this network.
|
||||
* @throws NoSuchElementException if {@code n} does not belong to
|
||||
* this network.
|
||||
*/
|
||||
public void deleteNeuron(Neuron neuron) {
|
||||
final Collection<Neuron> neighbours = getNeighbours(neuron);
|
||||
|
||||
// Delete links to from neighbours.
|
||||
for (Neuron n : neighbours) {
|
||||
deleteLink(n, neuron);
|
||||
}
|
||||
|
||||
// Remove neuron.
|
||||
neuronMap.remove(neuron.getIdentifier());
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the size of the neurons' features set.
|
||||
*
|
||||
* @return the size of the features set.
|
||||
*/
|
||||
public int getFeaturesSize() {
|
||||
return featureSize;
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds a link from neuron {@code a} to neuron {@code b}.
|
||||
* Note: the link is not bi-directional; if a bi-directional link is
|
||||
* required, an additional call must be made with {@code a} and
|
||||
* {@code b} exchanged in the argument list.
|
||||
*
|
||||
* @param a Neuron.
|
||||
* @param b Neuron.
|
||||
* @throws NoSuchElementException if the neurons do not exist in the
|
||||
* network.
|
||||
*/
|
||||
public void addLink(Neuron a,
|
||||
Neuron b) {
|
||||
final long aId = a.getIdentifier();
|
||||
final long bId = b.getIdentifier();
|
||||
|
||||
// Check that the neurons belong to this network.
|
||||
if (a != getNeuron(aId)) {
|
||||
throw new NoSuchElementException(Long.toString(aId));
|
||||
}
|
||||
if (b != getNeuron(bId)) {
|
||||
throw new NoSuchElementException(Long.toString(bId));
|
||||
}
|
||||
|
||||
// Add link from "a" to "b".
|
||||
addLinkToLinkSet(linkMap.get(aId), bId);
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds a link to neuron {@code id} in given {@code linkSet}.
|
||||
* Note: no check verifies that the identifier indeed belongs
|
||||
* to this network.
|
||||
*
|
||||
* @param linkSet Neuron identifier.
|
||||
* @param id Neuron identifier.
|
||||
*/
|
||||
private void addLinkToLinkSet(Set<Long> linkSet,
|
||||
long id) {
|
||||
linkSet.add(id);
|
||||
}
|
||||
|
||||
/**
|
||||
* Deletes the link between neurons {@code a} and {@code b}.
|
||||
*
|
||||
* @param a Neuron.
|
||||
* @param b Neuron.
|
||||
* @throws NoSuchElementException if the neurons do not exist in the
|
||||
* network.
|
||||
*/
|
||||
public void deleteLink(Neuron a,
|
||||
Neuron b) {
|
||||
final long aId = a.getIdentifier();
|
||||
final long bId = b.getIdentifier();
|
||||
|
||||
// Check that the neurons belong to this network.
|
||||
if (a != getNeuron(aId)) {
|
||||
throw new NoSuchElementException(Long.toString(aId));
|
||||
}
|
||||
if (b != getNeuron(bId)) {
|
||||
throw new NoSuchElementException(Long.toString(bId));
|
||||
}
|
||||
|
||||
// Delete link from "a" to "b".
|
||||
deleteLinkFromLinkSet(linkMap.get(aId), bId);
|
||||
}
|
||||
|
||||
/**
|
||||
* Deletes a link to neuron {@code id} in given {@code linkSet}.
|
||||
* Note: no check verifies that the identifier indeed belongs
|
||||
* to this network.
|
||||
*
|
||||
* @param linkSet Neuron identifier.
|
||||
* @param id Neuron identifier.
|
||||
*/
|
||||
private void deleteLinkFromLinkSet(Set<Long> linkSet,
|
||||
long id) {
|
||||
linkSet.remove(id);
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves the neuron with the given (unique) {@code id}.
|
||||
*
|
||||
* @param id Identifier.
|
||||
* @return the neuron associated with the given {@code id}.
|
||||
* @throws NoSuchElementException if the neuron does not exist in the
|
||||
* network.
|
||||
*/
|
||||
public Neuron getNeuron(long id) {
|
||||
final Neuron n = neuronMap.get(id);
|
||||
if (n == null) {
|
||||
throw new NoSuchElementException(Long.toString(id));
|
||||
}
|
||||
return n;
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves the neurons in the neighbourhood of any neuron in the
|
||||
* {@code neurons} list.
|
||||
* @param neurons Neurons for which to retrieve the neighbours.
|
||||
* @return the list of neighbours.
|
||||
* @see #getNeighbours(Iterable,Iterable)
|
||||
*/
|
||||
public Collection<Neuron> getNeighbours(Iterable<Neuron> neurons) {
|
||||
return getNeighbours(neurons, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves the neurons in the neighbourhood of any neuron in the
|
||||
* {@code neurons} list.
|
||||
* The {@code exclude} list allows to retrieve the "concentric"
|
||||
* neighbourhoods by removing the neurons that belong to the inner
|
||||
* "circles".
|
||||
*
|
||||
* @param neurons Neurons for which to retrieve the neighbours.
|
||||
* @param exclude Neurons to exclude from the returned list.
|
||||
* Can be {@code null}.
|
||||
* @return the list of neighbours.
|
||||
*/
|
||||
public Collection<Neuron> getNeighbours(Iterable<Neuron> neurons,
|
||||
Iterable<Neuron> exclude) {
|
||||
final Set<Long> idList = new HashSet<Long>();
|
||||
|
||||
for (Neuron n : neurons) {
|
||||
idList.addAll(linkMap.get(n.getIdentifier()));
|
||||
}
|
||||
if (exclude != null) {
|
||||
for (Neuron n : exclude) {
|
||||
idList.remove(n.getIdentifier());
|
||||
}
|
||||
}
|
||||
|
||||
final List<Neuron> neuronList = new ArrayList<Neuron>();
|
||||
for (Long id : idList) {
|
||||
neuronList.add(getNeuron(id));
|
||||
}
|
||||
|
||||
return neuronList;
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves the neighbours of the given neuron.
|
||||
*
|
||||
* @param neuron Neuron for which to retrieve the neighbours.
|
||||
* @return the list of neighbours.
|
||||
* @see #getNeighbours(Neuron,Iterable)
|
||||
*/
|
||||
public Collection<Neuron> getNeighbours(Neuron neuron) {
|
||||
return getNeighbours(neuron, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves the neighbours of the given neuron.
|
||||
*
|
||||
* @param neuron Neuron for which to retrieve the neighbours.
|
||||
* @param exclude Neurons to exclude from the returned list.
|
||||
* Can be {@code null}.
|
||||
* @return the list of neighbours.
|
||||
*/
|
||||
public Collection<Neuron> getNeighbours(Neuron neuron,
|
||||
Iterable<Neuron> exclude) {
|
||||
final Set<Long> idList = linkMap.get(neuron.getIdentifier());
|
||||
if (exclude != null) {
|
||||
for (Neuron n : exclude) {
|
||||
idList.remove(n.getIdentifier());
|
||||
}
|
||||
}
|
||||
|
||||
final List<Neuron> neuronList = new ArrayList<Neuron>();
|
||||
for (Long id : idList) {
|
||||
neuronList.add(getNeuron(id));
|
||||
}
|
||||
|
||||
return neuronList;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a neuron identifier.
|
||||
*
|
||||
* @return a value that will serve as a unique identifier.
|
||||
*/
|
||||
private Long createNextId() {
|
||||
return nextId.getAndIncrement();
|
||||
}
|
||||
|
||||
/**
|
||||
* Prevents proxy bypass.
|
||||
*
|
||||
* @param in Input stream.
|
||||
*/
|
||||
private void readObject(ObjectInputStream in) {
|
||||
throw new IllegalStateException();
|
||||
}
|
||||
|
||||
/**
|
||||
* Custom serialization.
|
||||
*
|
||||
* @return the proxy instance that will be actually serialized.
|
||||
*/
|
||||
private Object writeReplace() {
|
||||
final Neuron[] neuronList = neuronMap.values().toArray(new Neuron[0]);
|
||||
final long[][] neighbourIdList = new long[neuronList.length][];
|
||||
|
||||
for (int i = 0; i < neuronList.length; i++) {
|
||||
final Collection<Neuron> neighbours = getNeighbours(neuronList[i]);
|
||||
final long[] neighboursId = new long[neighbours.size()];
|
||||
int count = 0;
|
||||
for (Neuron n : neighbours) {
|
||||
neighboursId[count] = n.getIdentifier();
|
||||
++count;
|
||||
}
|
||||
neighbourIdList[i] = neighboursId;
|
||||
}
|
||||
|
||||
return new SerializationProxy(nextId.get(),
|
||||
featureSize,
|
||||
neuronList,
|
||||
neighbourIdList);
|
||||
}
|
||||
|
||||
/**
|
||||
* Serialization.
|
||||
*/
|
||||
private static class SerializationProxy implements Serializable {
|
||||
/** Serializable. */
|
||||
private static final long serialVersionUID = 20130207L;
|
||||
/** Next identifier. */
|
||||
private final long nextId;
|
||||
/** Number of features. */
|
||||
private final int featureSize;
|
||||
/** Neurons. */
|
||||
private final Neuron[] neuronList;
|
||||
/** Links. */
|
||||
private final long[][] neighbourIdList;
|
||||
|
||||
/**
|
||||
* @param nextId Next available identifier.
|
||||
* @param featureSize Number of features.
|
||||
* @param neuronList Neurons.
|
||||
* @param neighbourIdList Links associated to each of the neurons in
|
||||
* {@code neuronList}.
|
||||
*/
|
||||
SerializationProxy(long nextId,
|
||||
int featureSize,
|
||||
Neuron[] neuronList,
|
||||
long[][] neighbourIdList) {
|
||||
this.nextId = nextId;
|
||||
this.featureSize = featureSize;
|
||||
this.neuronList = neuronList;
|
||||
this.neighbourIdList = neighbourIdList;
|
||||
}
|
||||
|
||||
/**
|
||||
* Custom serialization.
|
||||
*
|
||||
* @return the {@link Network} for which this instance is the proxy.
|
||||
*/
|
||||
private Object readResolve() {
|
||||
return new Network(nextId,
|
||||
featureSize,
|
||||
neuronList,
|
||||
neighbourIdList);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,215 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.ml.neuralnet;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.io.ObjectInputStream;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
import org.apache.commons.math3.util.Precision;
|
||||
import org.apache.commons.math3.exception.DimensionMismatchException;
|
||||
|
||||
|
||||
/**
|
||||
* Describes a neuron element of a neural network.
|
||||
*
|
||||
* This class aims to be thread-safe.
|
||||
*
|
||||
* @version $Id$
|
||||
*/
|
||||
public class Neuron implements Serializable {
|
||||
/** Serializable. */
|
||||
private static final long serialVersionUID = 20130207L;
|
||||
/** Identifier. */
|
||||
private final long identifier;
|
||||
/** Length of the feature set. */
|
||||
private final int size;
|
||||
/** Neuron data. */
|
||||
private final AtomicReference<double[]> features;
|
||||
|
||||
/**
|
||||
* Creates a neuron.
|
||||
* The size of the feature set is fixed to the length of the given
|
||||
* argument.
|
||||
* <br/>
|
||||
* Constructor is package-private: Neurons must be
|
||||
* {@link Network#createNeuron(double[]) created} by the network
|
||||
* instance to which they will belong.
|
||||
*
|
||||
* @param identifier Identifier (assigned by the {@link Network}).
|
||||
* @param features Initial values of the feature set.
|
||||
*/
|
||||
Neuron(long identifier,
|
||||
double[] features) {
|
||||
this.identifier = identifier;
|
||||
this.size = features.length;
|
||||
this.features = new AtomicReference<double[]>(features.clone());
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the neuron's identifier.
|
||||
*
|
||||
* @return the identifier.
|
||||
*/
|
||||
public long getIdentifier() {
|
||||
return identifier;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the length of the feature set.
|
||||
*
|
||||
* @return the number of features.
|
||||
*/
|
||||
public int getSize() {
|
||||
return size;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the neuron's features.
|
||||
*
|
||||
* @return a copy of the neuron's features.
|
||||
*/
|
||||
public double[] getFeatures() {
|
||||
return features.get().clone();
|
||||
}
|
||||
|
||||
/**
|
||||
* Tries to atomically update the neuron's features.
|
||||
* Update will be performed only if the expected values match the
|
||||
* current values.<br/>
|
||||
* In effect, when concurrent threads call this method, the state
|
||||
* could be modified by one, so that it does not correspond to the
|
||||
* the state assumed by another.
|
||||
* Typically, a caller {@link #getFeatures() retrieves the current state},
|
||||
* and uses it to compute the new state.
|
||||
* During this computation, another thread might have done the same
|
||||
* thing, and updated the state: If the current thread were to proceed
|
||||
* with its own update, it would overwrite the new state (which might
|
||||
* already have been used by yet other threads).
|
||||
* To prevent this, the method does not perform the update when a
|
||||
* concurrent modification has been detected, and returns {@code false}.
|
||||
* When this happens, the caller should fetch the new current state,
|
||||
* redo its computation, and call this method again.
|
||||
*
|
||||
* @param expect Current values of the features, as assumed by the caller.
|
||||
* Update will never succeed if the contents of this array does not match
|
||||
* the values returned by {@link #getFeatures()}.
|
||||
* @param update Features's new values.
|
||||
* @return {@code true} if the update was successful, {@code false}
|
||||
* otherwise.
|
||||
* @throws DimensionMismatchException if the length of {@code update} is
|
||||
* not the same as specified in the {@link #Neuron(long,double[])
|
||||
* constructor}.
|
||||
*/
|
||||
public boolean compareAndSetFeatures(double[] expect,
|
||||
double[] update) {
|
||||
if (update.length != size) {
|
||||
throw new DimensionMismatchException(update.length, size);
|
||||
}
|
||||
|
||||
// Get the internal reference. Note that this must not be a copy;
|
||||
// otherwise the "compareAndSet" below will always fail.
|
||||
final double[] current = features.get();
|
||||
if (!containSameValues(current, expect)) {
|
||||
// Some other thread already modified the state.
|
||||
return false;
|
||||
}
|
||||
|
||||
if (features.compareAndSet(current, update.clone())) {
|
||||
// The current thread could atomically update the state.
|
||||
return true;
|
||||
} else {
|
||||
// Some other thread came first.
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks whether the contents of both arrays is the same.
|
||||
*
|
||||
* @param current Current values.
|
||||
* @param expect Expected values.
|
||||
* @throws DimensionMismatchException if the length of {@code expected}
|
||||
* is not the same as specified in the {@link #Neuron(long,double[])
|
||||
* constructor}.
|
||||
* @return {@code true} if the arrays contain the same values.
|
||||
*/
|
||||
private boolean containSameValues(double[] current,
|
||||
double[] expect) {
|
||||
if (expect.length != size) {
|
||||
throw new DimensionMismatchException(expect.length, size);
|
||||
}
|
||||
|
||||
for (int i = 0; i < size; i++) {
|
||||
if (!Precision.equals(current[i], expect[i])) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Prevents proxy bypass.
|
||||
*
|
||||
* @param in Input stream.
|
||||
*/
|
||||
private void readObject(ObjectInputStream in) {
|
||||
throw new IllegalStateException();
|
||||
}
|
||||
|
||||
/**
|
||||
* Custom serialization.
|
||||
*
|
||||
* @return the proxy instance that will be actually serialized.
|
||||
*/
|
||||
private Object writeReplace() {
|
||||
return new SerializationProxy(identifier,
|
||||
features.get());
|
||||
}
|
||||
|
||||
/**
|
||||
* Serialization.
|
||||
*/
|
||||
private static class SerializationProxy implements Serializable {
|
||||
/** Serializable. */
|
||||
private static final long serialVersionUID = 20130207L;
|
||||
/** Features. */
|
||||
private final double[] features;
|
||||
/** Identifier. */
|
||||
private final long identifier;
|
||||
|
||||
/**
|
||||
* @param identifier Identifier.
|
||||
* @param features Features.
|
||||
*/
|
||||
SerializationProxy(long identifier,
|
||||
double[] features) {
|
||||
this.identifier = identifier;
|
||||
this.features = features;
|
||||
}
|
||||
|
||||
/**
|
||||
* Custom serialization.
|
||||
*
|
||||
* @return the {@link Neuron} for which this instance is the proxy.
|
||||
*/
|
||||
private Object readResolve() {
|
||||
return new Neuron(identifier,
|
||||
features);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,38 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.ml.neuralnet;
|
||||
|
||||
/**
|
||||
* Defines neighbourhood types.
|
||||
*
|
||||
* @version $Id$
|
||||
*/
|
||||
public enum SquareNeighbourhood {
|
||||
/**
|
||||
* <a href="http://en.wikipedia.org/wiki/Von_Neumann_neighborhood"
|
||||
* Von Neumann neighbourhood</a>: in two dimensions, each (internal)
|
||||
* neuron has four neighbours.
|
||||
*/
|
||||
VON_NEUMANN,
|
||||
/**
|
||||
* <a href="http://en.wikipedia.org/wiki/Moore_neighborhood"
|
||||
* Moore neighbourhood</a>: in two dimensions, each (internal)
|
||||
* neuron has eight neighbours.
|
||||
*/
|
||||
MOORE,
|
||||
}
|
|
@ -0,0 +1,34 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.ml.neuralnet;
|
||||
|
||||
/**
|
||||
* Describes how to update the network in response to a training
|
||||
* sample.
|
||||
*
|
||||
* @version $Id$
|
||||
*/
|
||||
public interface UpdateAction {
|
||||
/**
|
||||
* Updates the network in response to the sample {@code features}.
|
||||
*
|
||||
* @param net Network.
|
||||
* @param features Training data.
|
||||
*/
|
||||
void update(Network net, double[] features);
|
||||
}
|
|
@ -0,0 +1,236 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.ml.neuralnet.oned;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.io.ObjectInputStream;
|
||||
import org.apache.commons.math3.ml.neuralnet.Network;
|
||||
import org.apache.commons.math3.ml.neuralnet.FeatureInitializer;
|
||||
import org.apache.commons.math3.exception.NumberIsTooSmallException;
|
||||
import org.apache.commons.math3.exception.OutOfRangeException;
|
||||
|
||||
/**
|
||||
* Neural network with the topology of a one-dimensional line.
|
||||
* Each neuron defines one point on the line.
|
||||
*
|
||||
* @version $Id$
|
||||
*/
|
||||
public class NeuronString implements Serializable {
|
||||
/** Underlying network. */
|
||||
private final Network network;
|
||||
/** Number of neurons. */
|
||||
private final int size;
|
||||
/** Wrap. */
|
||||
private final boolean wrap;
|
||||
|
||||
/**
|
||||
* Mapping of the 1D coordinate to the neuron identifiers
|
||||
* (attributed by the {@link #network} instance).
|
||||
*/
|
||||
private final long[] identifiers;
|
||||
|
||||
/**
|
||||
* Constructor with restricted access, solely used for deserialization.
|
||||
*
|
||||
* @param wrap Whether to wrap the dimension (i.e the first and last
|
||||
* neurons will be linked together).
|
||||
* @param featuresList Arrays that will initialize the features sets of
|
||||
* the network's neurons.
|
||||
* @throws NumberIsTooSmallException if {@code num < 2}.
|
||||
*/
|
||||
NeuronString(boolean wrap,
|
||||
double[][] featuresList) {
|
||||
size = featuresList.length;
|
||||
|
||||
if (size < 2) {
|
||||
throw new NumberIsTooSmallException(size, 2, true);
|
||||
}
|
||||
|
||||
this.wrap = wrap;
|
||||
|
||||
final int fLen = featuresList[0].length;
|
||||
network = new Network(0, fLen);
|
||||
identifiers = new long[size];
|
||||
|
||||
// Add neurons.
|
||||
for (int i = 0; i < size; i++) {
|
||||
identifiers[i] = network.createNeuron(featuresList[i]);
|
||||
}
|
||||
|
||||
// Add links.
|
||||
createLinks();
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a one-dimensional network:
|
||||
* Each neuron not located on the border of the mesh has two
|
||||
* neurons linked to it.
|
||||
* <br/>
|
||||
* The links are bi-directional.
|
||||
* Neurons created successively are neighbours (i.e. there are
|
||||
* links between them).
|
||||
* <br/>
|
||||
* The topology of the network can also be a circle (if the
|
||||
* dimension is wrapped).
|
||||
*
|
||||
* @param num Number of neurons.
|
||||
* @param wrap Whether to wrap the dimension (i.e the first and last
|
||||
* neurons will be linked together).
|
||||
* @param featureInit Arrays that will initialize the features sets of
|
||||
* the network's neurons.
|
||||
* @throws NumberIsTooSmallException if {@code num < 2}.
|
||||
*/
|
||||
public NeuronString(int num,
|
||||
boolean wrap,
|
||||
FeatureInitializer[] featureInit) {
|
||||
if (num < 2) {
|
||||
throw new NumberIsTooSmallException(num, 2, true);
|
||||
}
|
||||
|
||||
size = num;
|
||||
this.wrap = wrap;
|
||||
identifiers = new long[num];
|
||||
|
||||
final int fLen = featureInit.length;
|
||||
network = new Network(0, fLen);
|
||||
|
||||
// Add neurons.
|
||||
for (int i = 0; i < num; i++) {
|
||||
final double[] features = new double[fLen];
|
||||
for (int fIndex = 0; fIndex < fLen; fIndex++) {
|
||||
features[fIndex] = featureInit[fIndex].value();
|
||||
}
|
||||
identifiers[i] = network.createNeuron(features);
|
||||
}
|
||||
|
||||
// Add links.
|
||||
createLinks();
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves the underlying network.
|
||||
* A reference is returned (enabling, for example, the network to be
|
||||
* trained).
|
||||
* This also implies that calling methods that modify the {@link Network}
|
||||
* topology may cause this class to become inconsistent.
|
||||
*
|
||||
* @return the network.
|
||||
*/
|
||||
public Network getNetwork() {
|
||||
return network;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the number of neurons.
|
||||
*
|
||||
* @return the number of neurons.
|
||||
*/
|
||||
public int getSize() {
|
||||
return size;
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves the features set from the neuron at location
|
||||
* {@code i} in the map.
|
||||
*
|
||||
* @param i Neuron index.
|
||||
* @return the features of the neuron at index {@code i}.
|
||||
* @throws OutOfRangeException if {@code i} is out of range.
|
||||
*/
|
||||
public double[] getFeatures(int i) {
|
||||
if (i < 0 ||
|
||||
i >= size) {
|
||||
throw new OutOfRangeException(i, 0, size - 1);
|
||||
}
|
||||
|
||||
return network.getNeuron(identifiers[i]).getFeatures();
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates the neighbour relationships between neurons.
|
||||
*/
|
||||
private void createLinks() {
|
||||
for (int i = 0; i < size - 1; i++) {
|
||||
network.addLink(network.getNeuron(i), network.getNeuron(i + 1));
|
||||
}
|
||||
for (int i = size - 1; i > 0; i--) {
|
||||
network.addLink(network.getNeuron(i), network.getNeuron(i - 1));
|
||||
}
|
||||
if (wrap) {
|
||||
network.addLink(network.getNeuron(0), network.getNeuron(size - 1));
|
||||
network.addLink(network.getNeuron(size - 1), network.getNeuron(0));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Prevents proxy bypass.
|
||||
*
|
||||
* @param in Input stream.
|
||||
*/
|
||||
private void readObject(ObjectInputStream in) {
|
||||
throw new IllegalStateException();
|
||||
}
|
||||
|
||||
/**
|
||||
* Custom serialization.
|
||||
*
|
||||
* @return the proxy instance that will be actually serialized.
|
||||
*/
|
||||
private Object writeReplace() {
|
||||
final double[][] featuresList = new double[size][];
|
||||
for (int i = 0; i < size; i++) {
|
||||
featuresList[i] = getFeatures(i);
|
||||
}
|
||||
|
||||
return new SerializationProxy(wrap,
|
||||
featuresList);
|
||||
}
|
||||
|
||||
/**
|
||||
* Serialization.
|
||||
*/
|
||||
private static class SerializationProxy implements Serializable {
|
||||
/** Serializable. */
|
||||
private static final long serialVersionUID = 20130226L;
|
||||
/** Wrap. */
|
||||
private final boolean wrap;
|
||||
/** Neurons' features. */
|
||||
private final double[][] featuresList;
|
||||
|
||||
/**
|
||||
* @param wrap Whether the dimension is wrapped.
|
||||
* @param featuresList List of neurons features.
|
||||
* {@code neuronList}.
|
||||
*/
|
||||
SerializationProxy(boolean wrap,
|
||||
double[][] featuresList) {
|
||||
this.wrap = wrap;
|
||||
this.featuresList = featuresList;
|
||||
}
|
||||
|
||||
/**
|
||||
* Custom serialization.
|
||||
*
|
||||
* @return the {@link Neuron} for which this instance is the proxy.
|
||||
*/
|
||||
private Object readResolve() {
|
||||
return new NeuronString(wrap,
|
||||
featuresList);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,22 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/**
|
||||
* One-dimensional neural networks.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.ml.neuralnet.oned;
|
|
@ -0,0 +1,22 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/**
|
||||
* Neural networks.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.ml.neuralnet;
|
|
@ -0,0 +1,59 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.ml.neuralnet.sofm;
|
||||
|
||||
import java.util.Iterator;
|
||||
import org.apache.commons.math3.ml.neuralnet.Network;
|
||||
|
||||
/**
|
||||
* Trainer for Kohonen's Self-Organizing Map.
|
||||
*
|
||||
* @version $Id$
|
||||
*/
|
||||
public class KohonenTrainingTask implements Runnable {
|
||||
/** SOFM to be trained. */
|
||||
private final Network net;
|
||||
/** Training data. */
|
||||
private final Iterator<double[]> featuresIterator;
|
||||
/** Update procedure. */
|
||||
private final KohonenUpdateAction updateAction;
|
||||
|
||||
/**
|
||||
* Creates a (sequential) trainer for the given network.
|
||||
*
|
||||
* @param net Network to be trained with the SOFM algorithm.
|
||||
* @param featuresIterator Training data iterator.
|
||||
* @param updateAction SOFM update procedure.
|
||||
*/
|
||||
public KohonenTrainingTask(Network net,
|
||||
Iterator<double[]> featuresIterator,
|
||||
KohonenUpdateAction updateAction) {
|
||||
this.net = net;
|
||||
this.featuresIterator = featuresIterator;
|
||||
this.updateAction = updateAction;
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*/
|
||||
public void run() {
|
||||
while (featuresIterator.hasNext()) {
|
||||
updateAction.update(net, featuresIterator.next());
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,213 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.ml.neuralnet.sofm;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.HashSet;
|
||||
import java.util.concurrent.atomic.AtomicLong;
|
||||
import org.apache.commons.math3.ml.neuralnet.Network;
|
||||
import org.apache.commons.math3.ml.neuralnet.MapUtils;
|
||||
import org.apache.commons.math3.ml.neuralnet.Neuron;
|
||||
import org.apache.commons.math3.ml.neuralnet.UpdateAction;
|
||||
import org.apache.commons.math3.ml.distance.DistanceMeasure;
|
||||
import org.apache.commons.math3.linear.ArrayRealVector;
|
||||
import org.apache.commons.math3.analysis.function.Gaussian;
|
||||
|
||||
/**
|
||||
* Update formula for <a href="http://en.wikipedia.org/wiki/Kohonen">
|
||||
* Kohonen's Self-Organizing Map</a>.
|
||||
* <br/>
|
||||
* The {@link #update(Network,double[]) update} method modifies the
|
||||
* features {@code w} of the "winning" neuron and its neighbours
|
||||
* according to the following rule:
|
||||
* <code>
|
||||
* w<sub>new</sub> = w<sub>old</sub> + α e<sup>(-d / σ)</sup> * (sample - w<sub>old</sub>)
|
||||
* </code>
|
||||
* where
|
||||
* <ul>
|
||||
* <li>α is the current <em>learning rate</em>, </li>
|
||||
* <li>σ is the current <em>neighbourhood size</em>, and</li>
|
||||
* <li>{@code d} is the number of links to traverse in order to reach
|
||||
* the neuron from the winning neuron.</li>
|
||||
* </ul>
|
||||
* <br/>
|
||||
* This class is thread-safe as long as the arguments passed to the
|
||||
* {@link #KohonenUpdateAction(DistanceMeasure,LearningFactorFunction,
|
||||
* NeighbourhoodSizeFunction) constructor} are instances of thread-safe
|
||||
* classes.
|
||||
* <br/>
|
||||
* Each call to the {@link #update(Network,double[]) update} method
|
||||
* will increment the internal counter used to compute the current
|
||||
* values for
|
||||
* <ul>
|
||||
* <li>the <em>learning rate</em>, and</li>
|
||||
* <li>the <em>neighbourhood size</em>.</li>
|
||||
* </ul>
|
||||
* Consequently, the function instances that compute those values (passed
|
||||
* to the constructor of this class) must take into account whether this
|
||||
* class's instance will be shared by multiple threads, as this will impact
|
||||
* the training process.
|
||||
*
|
||||
* @version $Id$
|
||||
*/
|
||||
public class KohonenUpdateAction implements UpdateAction {
|
||||
/** Distance function. */
|
||||
private final DistanceMeasure distance;
|
||||
/** Learning factor update function. */
|
||||
private final LearningFactorFunction learningFactor;
|
||||
/** Neighbourhood size update function. */
|
||||
private final NeighbourhoodSizeFunction neighbourhoodSize;
|
||||
/** Number of calls to {@link #update(Network,double[])}. */
|
||||
private final AtomicLong numberOfCalls = new AtomicLong(-1);
|
||||
|
||||
/**
|
||||
* @param distance Distance function.
|
||||
* @param learningFactor Learning factor update function.
|
||||
* @param neighbourhoodSize Neighbourhood size update function.
|
||||
*/
|
||||
public KohonenUpdateAction(DistanceMeasure distance,
|
||||
LearningFactorFunction learningFactor,
|
||||
NeighbourhoodSizeFunction neighbourhoodSize) {
|
||||
this.distance = distance;
|
||||
this.learningFactor = learningFactor;
|
||||
this.neighbourhoodSize = neighbourhoodSize;
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*/
|
||||
@Override
|
||||
public void update(Network net,
|
||||
double[] features) {
|
||||
final long numCalls = numberOfCalls.incrementAndGet();
|
||||
final double currentLearning = learningFactor.value(numCalls);
|
||||
final Neuron best = findAndUpdateBestNeuron(net,
|
||||
features,
|
||||
currentLearning);
|
||||
|
||||
final int currentNeighbourhood = neighbourhoodSize.value(numCalls);
|
||||
// The farther away the neighbour is from the winning neuron, the
|
||||
// smaller the learning rate will become.
|
||||
final Gaussian neighbourhoodDecay
|
||||
= new Gaussian(currentLearning,
|
||||
0,
|
||||
1d / currentNeighbourhood);
|
||||
|
||||
if (currentNeighbourhood > 0) {
|
||||
// Initial set of neurons only contains the winning neuron.
|
||||
Collection<Neuron> neighbours = new HashSet<Neuron>();
|
||||
neighbours.add(best);
|
||||
// Winning neuron must be excluded from the neighbours.
|
||||
final HashSet<Neuron> exclude = new HashSet<Neuron>();
|
||||
exclude.add(best);
|
||||
|
||||
int radius = 1;
|
||||
do {
|
||||
// Retrieve immediate neighbours of the current set of neurons.
|
||||
neighbours = net.getNeighbours(neighbours, exclude);
|
||||
|
||||
// Update all the neighbours.
|
||||
for (Neuron n : neighbours) {
|
||||
updateNeighbouringNeuron(n, features, neighbourhoodDecay.value(radius));
|
||||
}
|
||||
|
||||
// Add the neighbours to the exclude list so that they will
|
||||
// not be update more than once per training step.
|
||||
exclude.addAll(neighbours);
|
||||
++radius;
|
||||
} while (radius <= currentNeighbourhood);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves the number of calls to the {@link #update(Network,double[]) update}
|
||||
* method.
|
||||
*
|
||||
* @return the current number of calls.
|
||||
*/
|
||||
public long getNumberOfCalls() {
|
||||
return numberOfCalls.get();
|
||||
}
|
||||
|
||||
/**
|
||||
* Atomically updates the given neuron.
|
||||
*
|
||||
* @param n Neuron to be updated.
|
||||
* @param features Training data.
|
||||
* @param learningRate Learning factor.
|
||||
*/
|
||||
private void updateNeighbouringNeuron(Neuron n,
|
||||
double[] features,
|
||||
double learningRate) {
|
||||
while (true) {
|
||||
final double[] expect = n.getFeatures();
|
||||
final double[] update = computeFeatures(expect,
|
||||
features,
|
||||
learningRate);
|
||||
if (n.compareAndSetFeatures(expect, update)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Searches for the neuron whose features are closest to the given
|
||||
* sample, and atomically updates its features.
|
||||
*
|
||||
* @param net Network.
|
||||
* @param features Sample data.
|
||||
* @param learningRate Current learning factor.
|
||||
* @return the winning neuron.
|
||||
*/
|
||||
private Neuron findAndUpdateBestNeuron(Network net,
|
||||
double[] features,
|
||||
double learningRate) {
|
||||
while (true) {
|
||||
final Neuron best = MapUtils.findBest(features, net, distance);
|
||||
|
||||
final double[] expect = best.getFeatures();
|
||||
final double[] update = computeFeatures(expect,
|
||||
features,
|
||||
learningRate);
|
||||
if (best.compareAndSetFeatures(expect, update)) {
|
||||
return best;
|
||||
}
|
||||
|
||||
// If another thread modified the state of the winning neuron,
|
||||
// it may not be the best match anymore for the given training
|
||||
// sample: Hence, the winner search is performed again.
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the new value of the features set.
|
||||
*
|
||||
* @param current Current values of the features.
|
||||
* @param sample Training data.
|
||||
* @param learningRate Learning factor.
|
||||
* @return the new values for the features.
|
||||
*/
|
||||
private double[] computeFeatures(double[] current,
|
||||
double[] sample,
|
||||
double learningRate) {
|
||||
final ArrayRealVector c = new ArrayRealVector(current, false);
|
||||
final ArrayRealVector s = new ArrayRealVector(sample, false);
|
||||
// c + learningRate * (s - c)
|
||||
return s.subtract(c).mapMultiplyToSelf(learningRate).add(c).toArray();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,34 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.ml.neuralnet.sofm;
|
||||
|
||||
/**
|
||||
* Provides the learning rate as a function of the number of calls
|
||||
* already performed during the learning task.
|
||||
*
|
||||
* @version $Id$
|
||||
*/
|
||||
public interface LearningFactorFunction {
|
||||
/**
|
||||
* Computes the learning rate at the current call.
|
||||
*
|
||||
* @param numCall Current step of the training task.
|
||||
* @return the value of the function at {@code numCall}.
|
||||
*/
|
||||
double value(long numCall);
|
||||
}
|
|
@ -0,0 +1,119 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.ml.neuralnet.sofm;
|
||||
|
||||
import org.apache.commons.math3.ml.neuralnet.sofm.util.ExponentialDecayFunction;
|
||||
import org.apache.commons.math3.ml.neuralnet.sofm.util.QuasiSigmoidDecayFunction;
|
||||
import org.apache.commons.math3.exception.OutOfRangeException;
|
||||
|
||||
/**
|
||||
* Factory for creating instances of {@link LearningFactorFunction}.
|
||||
*
|
||||
* @version $Id$
|
||||
*/
|
||||
public class LearningFactorFunctionFactory {
|
||||
/** Class contains only static methods. */
|
||||
private LearningFactorFunctionFactory() {}
|
||||
|
||||
/**
|
||||
* Creates an exponential decay {@link LearningFactorFunction function}.
|
||||
* It will compute <code>a e<sup>-x / b</sup></code>,
|
||||
* where {@code x} is the (integer) independent variable and
|
||||
* <ul>
|
||||
* <li><code>a = initValue</code>
|
||||
* <li><code>b = -numCall / ln(valueAtNumCall / initValue)</code>
|
||||
* </ul>
|
||||
*
|
||||
* @param initValue Initial value, i.e.
|
||||
* {@link LearningFactorFunction#value(long) value(0)}.
|
||||
* @param valueAtNumCall Value of the function at {@code numCall}.
|
||||
* @param numCall Argument for which the function returns
|
||||
* {@code valueAtNumCall}.
|
||||
* @return the learning factor function.
|
||||
* @throws org.apache.commons.math3.exception.OutOfRangeException
|
||||
* if {@code initValue <= 0} or {@code initValue > 1}.
|
||||
* @throws org.apache.commons.math3.exception.NotStrictlyPositiveException
|
||||
* if {@code valueAtNumCall <= 0}.
|
||||
* @throws org.apache.commons.math3.exception.NumberIsTooLargeException
|
||||
* if {@code valueAtNumCall >= initValue}.
|
||||
* @throws org.apache.commons.math3.exception.NotStrictlyPositiveException
|
||||
* if {@code numCall <= 0}.
|
||||
*/
|
||||
public static LearningFactorFunction exponentialDecay(final double initValue,
|
||||
final double valueAtNumCall,
|
||||
final long numCall) {
|
||||
if (initValue <= 0 ||
|
||||
initValue > 1) {
|
||||
throw new OutOfRangeException(initValue, 0, 1);
|
||||
}
|
||||
|
||||
return new LearningFactorFunction() {
|
||||
/** DecayFunction. */
|
||||
private final ExponentialDecayFunction decay
|
||||
= new ExponentialDecayFunction(initValue, valueAtNumCall, numCall);
|
||||
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
public double value(long n) {
|
||||
return decay.value(n);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an sigmoid-like {@code LearningFactorFunction function}.
|
||||
* The function {@code f} will have the following properties:
|
||||
* <ul>
|
||||
* <li>{@code f(0) = initValue}</li>
|
||||
* <li>{@code numCall} is the inflexion point</li>
|
||||
* <li>{@code slope = f'(numCall)}</li>
|
||||
* </ul>
|
||||
*
|
||||
* @param initValue Initial value, i.e.
|
||||
* {@link LearningFactorFunction#value(long) value(0)}.
|
||||
* @param slope Value of the function derivative at {@code numCall}.
|
||||
* @param numCall Inflexion point.
|
||||
* @return the learning factor function.
|
||||
* @throws org.apache.commons.math3.exception.OutOfRangeException
|
||||
* if {@code initValue <= 0} or {@code initValue > 1}.
|
||||
* @throws org.apache.commons.math3.exception.NumberIsTooLargeException
|
||||
* if {@code slope >= 0}.
|
||||
* @throws org.apache.commons.math3.exception.NotStrictlyPositiveException
|
||||
* if {@code numCall <= 0}.
|
||||
*/
|
||||
public static LearningFactorFunction quasiSigmoidDecay(final double initValue,
|
||||
final double slope,
|
||||
final long numCall) {
|
||||
if (initValue <= 0 ||
|
||||
initValue > 1) {
|
||||
throw new OutOfRangeException(initValue, 0, 1);
|
||||
}
|
||||
|
||||
return new LearningFactorFunction() {
|
||||
/** DecayFunction. */
|
||||
private final QuasiSigmoidDecayFunction decay
|
||||
= new QuasiSigmoidDecayFunction(initValue, slope, numCall);
|
||||
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
public double value(long n) {
|
||||
return decay.value(n);
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
|
@ -0,0 +1,37 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.ml.neuralnet.sofm;
|
||||
|
||||
/**
|
||||
* Provides the network neighbourhood's size as a function of the
|
||||
* number of calls already performed during the learning task.
|
||||
* The "neighbourhood" is the set of neurons that can be reached
|
||||
* by traversing at most the number of links returned by this
|
||||
* function.
|
||||
*
|
||||
* @version $Id$
|
||||
*/
|
||||
public interface NeighbourhoodSizeFunction {
|
||||
/**
|
||||
* Computes the neighbourhood size at the current call.
|
||||
*
|
||||
* @param numCall Current step of the training task.
|
||||
* @return the value of the function at {@code numCall}.
|
||||
*/
|
||||
int value(long numCall);
|
||||
}
|
|
@ -0,0 +1,109 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.ml.neuralnet.sofm;
|
||||
|
||||
import org.apache.commons.math3.ml.neuralnet.sofm.util.ExponentialDecayFunction;
|
||||
import org.apache.commons.math3.ml.neuralnet.sofm.util.QuasiSigmoidDecayFunction;
|
||||
import org.apache.commons.math3.util.FastMath;
|
||||
|
||||
/**
|
||||
* Factory for creating instances of {@link NeighbourhoodSizeFunction}.
|
||||
*
|
||||
* @version $Id$
|
||||
*/
|
||||
public class NeighbourhoodSizeFunctionFactory {
|
||||
/** Class contains only static methods. */
|
||||
private NeighbourhoodSizeFunctionFactory() {}
|
||||
|
||||
/**
|
||||
* Creates an exponential decay {@link NeighbourhoodSizeFunction function}.
|
||||
* It will compute <code>a e<sup>-x / b</sup></code>,
|
||||
* where {@code x} is the (integer) independent variable and
|
||||
* <ul>
|
||||
* <li><code>a = initValue</code>
|
||||
* <li><code>b = -numCall / ln(valueAtNumCall / initValue)</code>
|
||||
* </ul>
|
||||
*
|
||||
* @param initValue Initial value, i.e.
|
||||
* {@link NeighbourhoodSizeFunction#value(long) value(0)}.
|
||||
* @param valueAtNumCall Value of the function at {@code numCall}.
|
||||
* @param numCall Argument for which the function returns
|
||||
* {@code valueAtNumCall}.
|
||||
* @return the neighbourhood size function.
|
||||
* @throws org.apache.commons.math3.exception.NotStrictlyPositiveException
|
||||
* if {@code initValue <= 0}.
|
||||
* @throws org.apache.commons.math3.exception.NotStrictlyPositiveException
|
||||
* if {@code valueAtNumCall <= 0}.
|
||||
* @throws org.apache.commons.math3.exception.NumberIsTooLargeException
|
||||
* if {@code valueAtNumCall >= initValue}.
|
||||
* @throws org.apache.commons.math3.exception.NotStrictlyPositiveException
|
||||
* if {@code numCall <= 0}.
|
||||
*/
|
||||
public static NeighbourhoodSizeFunction exponentialDecay(final double initValue,
|
||||
final double valueAtNumCall,
|
||||
final long numCall) {
|
||||
return new NeighbourhoodSizeFunction() {
|
||||
/** DecayFunction. */
|
||||
private final ExponentialDecayFunction decay
|
||||
= new ExponentialDecayFunction(initValue, valueAtNumCall, numCall);
|
||||
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
public int value(long n) {
|
||||
return (int) FastMath.rint(decay.value(n));
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an sigmoid-like {@code NeighbourhoodSizeFunction function}.
|
||||
* The function {@code f} will have the following properties:
|
||||
* <ul>
|
||||
* <li>{@code f(0) = initValue}</li>
|
||||
* <li>{@code numCall} is the inflexion point</li>
|
||||
* <li>{@code slope = f'(numCall)}</li>
|
||||
* </ul>
|
||||
*
|
||||
* @param initValue Initial value, i.e.
|
||||
* {@link NeighbourhoodSizeFunction#value(long) value(0)}.
|
||||
* @param slope Value of the function derivative at {@code numCall}.
|
||||
* @param numCall Inflexion point.
|
||||
* @return the neighbourhood size function.
|
||||
* @throws org.apache.commons.math3.exception.NotStrictlyPositiveException
|
||||
* if {@code initValue <= 0}.
|
||||
* @throws org.apache.commons.math3.exception.NumberIsTooLargeException
|
||||
* if {@code slope >= 0}.
|
||||
* @throws org.apache.commons.math3.exception.NotStrictlyPositiveException
|
||||
* if {@code numCall <= 0}.
|
||||
*/
|
||||
public static NeighbourhoodSizeFunction quasiSigmoidDecay(final double initValue,
|
||||
final double slope,
|
||||
final long numCall) {
|
||||
return new NeighbourhoodSizeFunction() {
|
||||
/** DecayFunction. */
|
||||
private final QuasiSigmoidDecayFunction decay
|
||||
= new QuasiSigmoidDecayFunction(initValue, slope, numCall);
|
||||
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
public int value(long n) {
|
||||
return (int) FastMath.rint(decay.value(n));
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
|
@ -0,0 +1,22 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/**
|
||||
* Self Organizing Feature Map.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.ml.neuralnet.sofm;
|
|
@ -0,0 +1,83 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.ml.neuralnet.sofm.util;
|
||||
|
||||
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
|
||||
import org.apache.commons.math3.exception.NumberIsTooLargeException;
|
||||
import org.apache.commons.math3.util.FastMath;
|
||||
|
||||
/**
|
||||
* Exponential decay function: <code>a e<sup>-x / b</sup></code>,
|
||||
* where {@code x} is the (integer) independent variable.
|
||||
* <br/>
|
||||
* Class is immutable.
|
||||
*
|
||||
* @version $Id$
|
||||
*/
|
||||
public class ExponentialDecayFunction {
|
||||
/** Factor {@code a}. */
|
||||
private final double a;
|
||||
/** Factor {@code 1 / b}. */
|
||||
private final double oneOverB;
|
||||
|
||||
/**
|
||||
* Creates an instance. It will be such that
|
||||
* <ul>
|
||||
* <li>{@code a = initValue}</li>
|
||||
* <li>{@code b = -numCall / ln(valueAtNumCall / initValue)}</li>
|
||||
* </ul>
|
||||
*
|
||||
* @param initValue Initial value, i.e. {@link #value(long) value(0)}.
|
||||
* @param valueAtNumCall Value of the function at {@code numCall}.
|
||||
* @param numCall Argument for which the function returns
|
||||
* {@code valueAtNumCall}.
|
||||
* @throws NotStrictlyPositiveException if {@code initValue <= 0}.
|
||||
* @throws NotStrictlyPositiveException if {@code valueAtNumCall <= 0}.
|
||||
* @throws NumberIsTooLargeException if {@code valueAtNumCall >= initValue}.
|
||||
* @throws NotStrictlyPositiveException if {@code numCall <= 0}.
|
||||
*/
|
||||
public ExponentialDecayFunction(double initValue,
|
||||
double valueAtNumCall,
|
||||
long numCall) {
|
||||
if (initValue <= 0) {
|
||||
throw new NotStrictlyPositiveException(initValue);
|
||||
}
|
||||
if (valueAtNumCall <= 0) {
|
||||
throw new NotStrictlyPositiveException(valueAtNumCall);
|
||||
}
|
||||
if (valueAtNumCall >= initValue) {
|
||||
throw new NumberIsTooLargeException(valueAtNumCall, initValue, false);
|
||||
}
|
||||
if (numCall <= 0) {
|
||||
throw new NotStrictlyPositiveException(numCall);
|
||||
}
|
||||
|
||||
a = initValue;
|
||||
oneOverB = -FastMath.log(valueAtNumCall / initValue) / numCall;
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes <code>a e<sup>-numCall / b</sup></code>.
|
||||
*
|
||||
* @param numCall Current step of the training task.
|
||||
* @return the value of the function at {@code numCall}.
|
||||
*/
|
||||
public double value(long numCall) {
|
||||
return a * FastMath.exp(-numCall * oneOverB);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,87 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.ml.neuralnet.sofm.util;
|
||||
|
||||
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
|
||||
import org.apache.commons.math3.exception.NumberIsTooLargeException;
|
||||
import org.apache.commons.math3.analysis.function.Logistic;
|
||||
|
||||
/**
|
||||
* Decay function whose shape is similar to a sigmoid.
|
||||
* <br/>
|
||||
* Class is immutable.
|
||||
*
|
||||
* @version $Id$
|
||||
*/
|
||||
public class QuasiSigmoidDecayFunction {
|
||||
/** Sigmoid. */
|
||||
private final Logistic sigmoid;
|
||||
/** See {@link #value(long)}. */
|
||||
private final double scale;
|
||||
|
||||
/**
|
||||
* Creates an instance.
|
||||
* The function {@code f} will have the following properties:
|
||||
* <ul>
|
||||
* <li>{@code f(0) = initValue}</li>
|
||||
* <li>{@code numCall} is the inflexion point</li>
|
||||
* <li>{@code slope = f'(numCall)}</li>
|
||||
* </ul>
|
||||
*
|
||||
* @param initValue Initial value, i.e. {@link #value(long) value(0)}.
|
||||
* @param slope Value of the function derivative at {@code numCall}.
|
||||
* @param numCall Inflexion point.
|
||||
* @throws NotStrictlyPositiveException if {@code initValue <= 0}.
|
||||
* @throws NumberIsTooLargeException if {@code slope >= 0}.
|
||||
* @throws NotStrictlyPositiveException if {@code numCall <= 0}.
|
||||
*/
|
||||
public QuasiSigmoidDecayFunction(double initValue,
|
||||
double slope,
|
||||
long numCall) {
|
||||
if (initValue <= 0) {
|
||||
throw new NotStrictlyPositiveException(initValue);
|
||||
}
|
||||
if (slope >= 0) {
|
||||
throw new NumberIsTooLargeException(slope, 0, false);
|
||||
}
|
||||
if (numCall <= 1) {
|
||||
throw new NotStrictlyPositiveException(numCall);
|
||||
}
|
||||
|
||||
final double k = initValue;
|
||||
final double m = numCall;
|
||||
final double b = 4 * slope / initValue;
|
||||
final double q = 1;
|
||||
final double a = 0;
|
||||
final double n = 1;
|
||||
sigmoid = new Logistic(k, m, b, q, a, n);
|
||||
|
||||
final double y0 = sigmoid.value(0);
|
||||
scale = k / y0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the value of the learning factor.
|
||||
*
|
||||
* @param numCall Current step of the training task.
|
||||
* @return the value of the function at {@code numCall}.
|
||||
*/
|
||||
public double value(long numCall) {
|
||||
return scale * sigmoid.value(numCall);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,22 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/**
|
||||
* Miscellaneous utilities.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.ml.neuralnet.sofm.util;
|
|
@ -0,0 +1,433 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.ml.neuralnet.twod;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.ArrayList;
|
||||
import java.io.Serializable;
|
||||
import java.io.ObjectInputStream;
|
||||
import org.apache.commons.math3.ml.neuralnet.Neuron;
|
||||
import org.apache.commons.math3.ml.neuralnet.Network;
|
||||
import org.apache.commons.math3.ml.neuralnet.FeatureInitializer;
|
||||
import org.apache.commons.math3.ml.neuralnet.SquareNeighbourhood;
|
||||
import org.apache.commons.math3.exception.NumberIsTooSmallException;
|
||||
import org.apache.commons.math3.exception.OutOfRangeException;
|
||||
import org.apache.commons.math3.exception.MathInternalError;
|
||||
|
||||
/**
|
||||
* Neural network with the topology of a two-dimensional surface.
|
||||
* Each neuron defines one surface element.
|
||||
* <br/>
|
||||
* This network is primarily intended to represent a
|
||||
* <a href="http://en.wikipedia.org/wiki/Kohonen">
|
||||
* Self Organizing Feature Map</a>.
|
||||
*
|
||||
* @see org.apache.commons.math3.ml.neuralnet.sofm
|
||||
* @version $Id$
|
||||
*/
|
||||
public class NeuronSquareMesh2D implements Serializable {
|
||||
/** Underlying network. */
|
||||
private final Network network;
|
||||
/** Number of rows. */
|
||||
private final int numberOfRows;
|
||||
/** Number of columns. */
|
||||
private final int numberOfColumns;
|
||||
/** Wrap. */
|
||||
private final boolean wrapRows;
|
||||
/** Wrap. */
|
||||
private final boolean wrapColumns;
|
||||
/** Neighbourhood type. */
|
||||
private final SquareNeighbourhood neighbourhood;
|
||||
/**
|
||||
* Mapping of the 2D coordinates (in the rectangular mesh) to
|
||||
* the neuron identifiers (attributed by the {@link #network}
|
||||
* instance).
|
||||
*/
|
||||
private final long[][] identifiers;
|
||||
|
||||
/**
|
||||
* Constructor with restricted access, solely used for deserialization.
|
||||
*
|
||||
* @param wrapRowDim Whether to wrap the first dimension (i.e the first
|
||||
* and last neurons will be linked together).
|
||||
* @param wrapColDim Whether to wrap the second dimension (i.e the first
|
||||
* and last neurons will be linked together).
|
||||
* @param neighbourhoodType Neighbourhood type.
|
||||
* @param featuresList Arrays that will initialize the features sets of
|
||||
* the network's neurons.
|
||||
* @throws NumberIsTooSmallException if {@code numRows < 2} or
|
||||
* {@code numCols < 2}.
|
||||
*/
|
||||
NeuronSquareMesh2D(boolean wrapRowDim,
|
||||
boolean wrapColDim,
|
||||
SquareNeighbourhood neighbourhoodType,
|
||||
double[][][] featuresList) {
|
||||
numberOfRows = featuresList.length;
|
||||
numberOfColumns = featuresList[0].length;
|
||||
|
||||
if (numberOfRows < 2) {
|
||||
throw new NumberIsTooSmallException(numberOfRows, 2, true);
|
||||
}
|
||||
if (numberOfColumns < 2) {
|
||||
throw new NumberIsTooSmallException(numberOfColumns, 2, true);
|
||||
}
|
||||
|
||||
wrapRows = wrapRowDim;
|
||||
wrapColumns = wrapColDim;
|
||||
neighbourhood = neighbourhoodType;
|
||||
|
||||
final int fLen = featuresList[0][0].length;
|
||||
network = new Network(0, fLen);
|
||||
identifiers = new long[numberOfRows][numberOfColumns];
|
||||
|
||||
// Add neurons.
|
||||
for (int i = 0; i < numberOfRows; i++) {
|
||||
for (int j = 0; j < numberOfColumns; j++) {
|
||||
identifiers[i][j] = network.createNeuron(featuresList[i][j]);
|
||||
}
|
||||
}
|
||||
|
||||
// Add links.
|
||||
createLinks();
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a two-dimensional network composed of square cells:
|
||||
* Each neuron not located on the border of the mesh has four
|
||||
* neurons linked to it.
|
||||
* <br/>
|
||||
* The links are bi-directional.
|
||||
* <br/>
|
||||
* The topology of the network can also be a cylinder (if one
|
||||
* of the dimensions is wrapped) or a torus (if both dimensions
|
||||
* are wrapped).
|
||||
*
|
||||
* @param numRows Number of neurons in the first dimension.
|
||||
* @param wrapRowDim Whether to wrap the first dimension (i.e the first
|
||||
* and last neurons will be linked together).
|
||||
* @param numCols Number of neurons in the second dimension.
|
||||
* @param wrapColDim Whether to wrap the second dimension (i.e the first
|
||||
* and last neurons will be linked together).
|
||||
* @param neighbourhoodType Neighbourhood type.
|
||||
* @param featureInit Array of functions that will initialize the
|
||||
* corresponding element of the features set of each newly created
|
||||
* neuron. In particular, the size of this array defines the size of
|
||||
* feature set.
|
||||
* @throws NumberIsTooSmallException if {@code numRows < 2} or
|
||||
* {@code numCols < 2}.
|
||||
*/
|
||||
public NeuronSquareMesh2D(int numRows,
|
||||
boolean wrapRowDim,
|
||||
int numCols,
|
||||
boolean wrapColDim,
|
||||
SquareNeighbourhood neighbourhoodType,
|
||||
FeatureInitializer[] featureInit) {
|
||||
if (numRows < 2) {
|
||||
throw new NumberIsTooSmallException(numRows, 2, true);
|
||||
}
|
||||
if (numCols < 2) {
|
||||
throw new NumberIsTooSmallException(numCols, 2, true);
|
||||
}
|
||||
|
||||
numberOfRows = numRows;
|
||||
wrapRows = wrapRowDim;
|
||||
numberOfColumns = numCols;
|
||||
wrapColumns = wrapColDim;
|
||||
neighbourhood = neighbourhoodType;
|
||||
identifiers = new long[numberOfRows][numberOfColumns];
|
||||
|
||||
final int fLen = featureInit.length;
|
||||
network = new Network(0, fLen);
|
||||
|
||||
// Add neurons.
|
||||
for (int i = 0; i < numRows; i++) {
|
||||
for (int j = 0; j < numCols; j++) {
|
||||
final double[] features = new double[fLen];
|
||||
for (int fIndex = 0; fIndex < fLen; fIndex++) {
|
||||
features[fIndex] = featureInit[fIndex].value();
|
||||
}
|
||||
identifiers[i][j] = network.createNeuron(features);
|
||||
}
|
||||
}
|
||||
|
||||
// Add links.
|
||||
createLinks();
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves the underlying network.
|
||||
* A reference is returned (enabling, for example, the network to be
|
||||
* trained).
|
||||
* This also implies that calling methods that modify the {@link Network}
|
||||
* topology may cause this class to become inconsistent.
|
||||
*
|
||||
* @return the network.
|
||||
*/
|
||||
public Network getNetwork() {
|
||||
return network;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the number of neurons in each row of this map.
|
||||
*
|
||||
* @return the number of rows.
|
||||
*/
|
||||
public int getNumberOfRows() {
|
||||
return numberOfRows;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the number of neurons in each column of this map.
|
||||
*
|
||||
* @return the number of column.
|
||||
*/
|
||||
public int getNumberOfColumns() {
|
||||
return numberOfColumns;
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves the neuron at location {@code (i, j)} in the map.
|
||||
*
|
||||
* @param i Row index.
|
||||
* @param j Column index.
|
||||
* @return the neuron at {@code (i, j)}.
|
||||
* @throws OutOfRangeException if {@code i} or {@code j} is
|
||||
* out of range.
|
||||
*/
|
||||
public Neuron getNeuron(int i,
|
||||
int j) {
|
||||
if (i < 0 ||
|
||||
i >= numberOfRows) {
|
||||
throw new OutOfRangeException(i, 0, numberOfRows - 1);
|
||||
}
|
||||
if (j < 0 ||
|
||||
j >= numberOfColumns) {
|
||||
throw new OutOfRangeException(j, 0, numberOfColumns - 1);
|
||||
}
|
||||
|
||||
return network.getNeuron(identifiers[i][j]);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates the neighbour relationships between neurons.
|
||||
*/
|
||||
private void createLinks() {
|
||||
// "linkEnd" will store the identifiers of the "neighbours".
|
||||
final List<Long> linkEnd = new ArrayList<Long>();
|
||||
final int iLast = numberOfRows - 1;
|
||||
final int jLast = numberOfColumns - 1;
|
||||
for (int i = 0; i < numberOfRows; i++) {
|
||||
for (int j = 0; j < numberOfColumns; j++) {
|
||||
linkEnd.clear();
|
||||
|
||||
switch (neighbourhood) {
|
||||
|
||||
case MOORE:
|
||||
// Add links to "diagonal" neighbours.
|
||||
if (i > 0) {
|
||||
if (j > 0) {
|
||||
linkEnd.add(identifiers[i - 1][j - 1]);
|
||||
}
|
||||
if (j < jLast) {
|
||||
linkEnd.add(identifiers[i - 1][j + 1]);
|
||||
}
|
||||
}
|
||||
if (i < iLast) {
|
||||
if (j > 0) {
|
||||
linkEnd.add(identifiers[i + 1][j - 1]);
|
||||
}
|
||||
if (j < jLast) {
|
||||
linkEnd.add(identifiers[i + 1][j + 1]);
|
||||
}
|
||||
}
|
||||
if (wrapRows) {
|
||||
if (i == 0) {
|
||||
if (j > 0) {
|
||||
linkEnd.add(identifiers[iLast][j - 1]);
|
||||
}
|
||||
if (j < jLast) {
|
||||
linkEnd.add(identifiers[iLast][j + 1]);
|
||||
}
|
||||
} else if (i == iLast) {
|
||||
if (j > 0) {
|
||||
linkEnd.add(identifiers[0][j - 1]);
|
||||
}
|
||||
if (j < jLast) {
|
||||
linkEnd.add(identifiers[0][j + 1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (wrapColumns) {
|
||||
if (j == 0) {
|
||||
if (i > 0) {
|
||||
linkEnd.add(identifiers[i - 1][jLast]);
|
||||
}
|
||||
if (i < iLast) {
|
||||
linkEnd.add(identifiers[i + 1][jLast]);
|
||||
}
|
||||
} else if (j == jLast) {
|
||||
if (i > 0) {
|
||||
linkEnd.add(identifiers[i - 1][0]);
|
||||
}
|
||||
if (i < iLast) {
|
||||
linkEnd.add(identifiers[i + 1][0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (wrapRows &&
|
||||
wrapColumns) {
|
||||
if (i == 0 &&
|
||||
j == 0) {
|
||||
linkEnd.add(identifiers[iLast][jLast]);
|
||||
} else if (i == 0 &&
|
||||
j == jLast) {
|
||||
linkEnd.add(identifiers[iLast][0]);
|
||||
} else if (i == iLast &&
|
||||
j == 0) {
|
||||
linkEnd.add(identifiers[0][jLast]);
|
||||
} else if (i == iLast &&
|
||||
j == jLast) {
|
||||
linkEnd.add(identifiers[0][0]);
|
||||
}
|
||||
}
|
||||
|
||||
// Case falls through since the "Moore" neighbourhood
|
||||
// also contains the neurons that belong to the "Von
|
||||
// Neumann" neighbourhood.
|
||||
|
||||
// fallthru (CheckStyle)
|
||||
case VON_NEUMANN:
|
||||
// Links to preceding and following "row".
|
||||
if (i > 0) {
|
||||
linkEnd.add(identifiers[i - 1][j]);
|
||||
}
|
||||
if (i < iLast) {
|
||||
linkEnd.add(identifiers[i + 1][j]);
|
||||
}
|
||||
if (wrapRows) {
|
||||
if (i == 0) {
|
||||
linkEnd.add(identifiers[iLast][j]);
|
||||
} else if (i == iLast) {
|
||||
linkEnd.add(identifiers[0][j]);
|
||||
}
|
||||
}
|
||||
|
||||
// Links to preceding and following "column".
|
||||
if (j > 0) {
|
||||
linkEnd.add(identifiers[i][j - 1]);
|
||||
}
|
||||
if (j < jLast) {
|
||||
linkEnd.add(identifiers[i][j + 1]);
|
||||
}
|
||||
if (wrapColumns) {
|
||||
if (j == 0) {
|
||||
linkEnd.add(identifiers[i][jLast]);
|
||||
} else if (j == jLast) {
|
||||
linkEnd.add(identifiers[i][0]);
|
||||
}
|
||||
}
|
||||
break;
|
||||
|
||||
default:
|
||||
throw new MathInternalError(); // Cannot happen.
|
||||
}
|
||||
|
||||
final Neuron aNeuron = network.getNeuron(identifiers[i][j]);
|
||||
for (long b : linkEnd) {
|
||||
final Neuron bNeuron = network.getNeuron(b);
|
||||
// Link to all neighbours.
|
||||
// The reverse links will be added as the loop proceeds.
|
||||
network.addLink(aNeuron, bNeuron);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Prevents proxy bypass.
|
||||
*
|
||||
* @param in Input stream.
|
||||
*/
|
||||
private void readObject(ObjectInputStream in) {
|
||||
throw new IllegalStateException();
|
||||
}
|
||||
|
||||
/**
|
||||
* Custom serialization.
|
||||
*
|
||||
* @return the proxy instance that will be actually serialized.
|
||||
*/
|
||||
private Object writeReplace() {
|
||||
final double[][][] featuresList = new double[numberOfRows][numberOfColumns][];
|
||||
for (int i = 0; i < numberOfRows; i++) {
|
||||
for (int j = 0; j < numberOfColumns; j++) {
|
||||
featuresList[i][j] = getNeuron(i, j).getFeatures();
|
||||
}
|
||||
}
|
||||
|
||||
return new SerializationProxy(wrapRows,
|
||||
wrapColumns,
|
||||
neighbourhood,
|
||||
featuresList);
|
||||
}
|
||||
|
||||
/**
|
||||
* Serialization.
|
||||
*/
|
||||
private static class SerializationProxy implements Serializable {
|
||||
/** Serializable. */
|
||||
private static final long serialVersionUID = 20130226L;
|
||||
/** Wrap. */
|
||||
private final boolean wrapRows;
|
||||
/** Wrap. */
|
||||
private final boolean wrapColumns;
|
||||
/** Neighbourhood type. */
|
||||
private final SquareNeighbourhood neighbourhood;
|
||||
/** Neurons' features. */
|
||||
private final double[][][] featuresList;
|
||||
|
||||
/**
|
||||
* @param wrapRows Whether the row dimension is wrapped.
|
||||
* @param wrapColumns Whether the column dimension is wrapped.
|
||||
* @param neighbourhood Neighbourhood type.
|
||||
* @param featuresList List of neurons features.
|
||||
* {@code neuronList}.
|
||||
*/
|
||||
SerializationProxy(boolean wrapRows,
|
||||
boolean wrapColumns,
|
||||
SquareNeighbourhood neighbourhood,
|
||||
double[][][] featuresList) {
|
||||
this.wrapRows = wrapRows;
|
||||
this.wrapColumns = wrapColumns;
|
||||
this.neighbourhood = neighbourhood;
|
||||
this.featuresList = featuresList;
|
||||
}
|
||||
|
||||
/**
|
||||
* Custom serialization.
|
||||
*
|
||||
* @return the {@link Neuron} for which this instance is the proxy.
|
||||
*/
|
||||
private Object readResolve() {
|
||||
return new NeuronSquareMesh2D(wrapRows,
|
||||
wrapColumns,
|
||||
neighbourhood,
|
||||
featuresList);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,22 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/**
|
||||
* Two-dimensional neural networks.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.ml.neuralnet.twod;
|
|
@ -0,0 +1,91 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.ml.neuralnet;
|
||||
|
||||
import java.util.Set;
|
||||
import java.util.HashSet;
|
||||
import org.apache.commons.math3.ml.distance.DistanceMeasure;
|
||||
import org.apache.commons.math3.ml.distance.EuclideanDistance;
|
||||
import org.apache.commons.math3.ml.neuralnet.oned.NeuronString;
|
||||
import org.junit.Test;
|
||||
import org.junit.Assert;
|
||||
|
||||
/**
|
||||
* Tests for {@link MapUtils} class.
|
||||
*/
|
||||
public class MapUtilsTest {
|
||||
/*
|
||||
* Test assumes that the network is
|
||||
*
|
||||
* 0-----1-----2
|
||||
*/
|
||||
@Test
|
||||
public void testFindClosestNeuron() {
|
||||
final FeatureInitializer init
|
||||
= new OffsetFeatureInitializer(FeatureInitializerFactory.uniform(-0.1, 0.1));
|
||||
final FeatureInitializer[] initArray = { init };
|
||||
|
||||
final Network net = new NeuronString(3, false, initArray).getNetwork();
|
||||
final DistanceMeasure dist = new EuclideanDistance();
|
||||
|
||||
final Set<Neuron> allBest = new HashSet<Neuron>();
|
||||
final Set<Neuron> best = new HashSet<Neuron>();
|
||||
double[][] features;
|
||||
|
||||
// The following tests ensures that
|
||||
// 1. the same neuron is always selected when the input feature is
|
||||
// in the range of the initializer,
|
||||
// 2. different network's neuron have been selected by inputs features
|
||||
// that belong to different ranges.
|
||||
|
||||
best.clear();
|
||||
features = new double[][] {
|
||||
{ -1 },
|
||||
{ 0.4 },
|
||||
};
|
||||
for (double[] f : features) {
|
||||
best.add(MapUtils.findBest(f, net, dist));
|
||||
}
|
||||
Assert.assertEquals(1, best.size());
|
||||
allBest.addAll(best);
|
||||
|
||||
best.clear();
|
||||
features = new double[][] {
|
||||
{ 0.6 },
|
||||
{ 1.4 },
|
||||
};
|
||||
for (double[] f : features) {
|
||||
best.add(MapUtils.findBest(f, net, dist));
|
||||
}
|
||||
Assert.assertEquals(1, best.size());
|
||||
allBest.addAll(best);
|
||||
|
||||
best.clear();
|
||||
features = new double[][] {
|
||||
{ 1.6 },
|
||||
{ 3 },
|
||||
};
|
||||
for (double[] f : features) {
|
||||
best.add(MapUtils.findBest(f, net, dist));
|
||||
}
|
||||
Assert.assertEquals(1, best.size());
|
||||
allBest.addAll(best);
|
||||
|
||||
Assert.assertEquals(3, allBest.size());
|
||||
}
|
||||
}
|
|
@ -0,0 +1,187 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.ml.neuralnet;
|
||||
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.ObjectOutputStream;
|
||||
import java.io.ObjectInputStream;
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.Collection;
|
||||
import java.util.NoSuchElementException;
|
||||
import org.junit.Test;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Ignore;
|
||||
import org.apache.commons.math3.exception.NumberIsTooSmallException;
|
||||
import org.apache.commons.math3.ml.neuralnet.twod.NeuronSquareMesh2D;
|
||||
import org.apache.commons.math3.random.Well44497b;
|
||||
|
||||
/**
|
||||
* Tests for {@link Network}.
|
||||
*/
|
||||
public class NetworkTest {
|
||||
final FeatureInitializer init = FeatureInitializerFactory.uniform(0, 2);
|
||||
|
||||
@Test
|
||||
public void testGetFeaturesSize() {
|
||||
final FeatureInitializer[] initArray = { init, init, init };
|
||||
|
||||
final Network net = new NeuronSquareMesh2D(2, false,
|
||||
2, false,
|
||||
SquareNeighbourhood.VON_NEUMANN,
|
||||
initArray).getNetwork();
|
||||
Assert.assertEquals(3, net.getFeaturesSize());
|
||||
}
|
||||
|
||||
/*
|
||||
* Test assumes that the network is
|
||||
*
|
||||
* 0-----1
|
||||
* | |
|
||||
* | |
|
||||
* 2-----3
|
||||
*/
|
||||
@Test
|
||||
public void testDeleteLink() {
|
||||
final FeatureInitializer[] initArray = { init };
|
||||
final Network net = new NeuronSquareMesh2D(2, false,
|
||||
2, false,
|
||||
SquareNeighbourhood.VON_NEUMANN,
|
||||
initArray).getNetwork();
|
||||
Collection<Neuron> neighbours;
|
||||
|
||||
// Delete 0 --> 1.
|
||||
net.deleteLink(net.getNeuron(0),
|
||||
net.getNeuron(1));
|
||||
|
||||
// Link from 0 to 1 was deleted.
|
||||
neighbours = net.getNeighbours(net.getNeuron(0));
|
||||
Assert.assertFalse(neighbours.contains(net.getNeuron(1)));
|
||||
// Link from 1 to 0 still exists.
|
||||
neighbours = net.getNeighbours(net.getNeuron(1));
|
||||
Assert.assertTrue(neighbours.contains(net.getNeuron(0)));
|
||||
}
|
||||
|
||||
/*
|
||||
* Test assumes that the network is
|
||||
*
|
||||
* 0-----1
|
||||
* | |
|
||||
* | |
|
||||
* 2-----3
|
||||
*/
|
||||
@Test
|
||||
public void testDeleteNeuron() {
|
||||
final FeatureInitializer[] initArray = { init };
|
||||
final Network net = new NeuronSquareMesh2D(2, false,
|
||||
2, false,
|
||||
SquareNeighbourhood.VON_NEUMANN,
|
||||
initArray).getNetwork();
|
||||
|
||||
Assert.assertEquals(2, net.getNeighbours(net.getNeuron(0)).size());
|
||||
Assert.assertEquals(2, net.getNeighbours(net.getNeuron(3)).size());
|
||||
|
||||
// Delete neuron 1.
|
||||
net.deleteNeuron(net.getNeuron(1));
|
||||
|
||||
try {
|
||||
net.getNeuron(1);
|
||||
} catch (NoSuchElementException expected) {}
|
||||
|
||||
Assert.assertEquals(1, net.getNeighbours(net.getNeuron(0)).size());
|
||||
Assert.assertEquals(1, net.getNeighbours(net.getNeuron(3)).size());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testIterationOrder() {
|
||||
final FeatureInitializer[] initArray = { init };
|
||||
final Network net = new NeuronSquareMesh2D(4, false,
|
||||
3, true,
|
||||
SquareNeighbourhood.VON_NEUMANN,
|
||||
initArray).getNetwork();
|
||||
|
||||
boolean isUnspecifiedOrder = false;
|
||||
|
||||
// Check that the default iterator returns the neurons
|
||||
// in an unspecified order.
|
||||
long previousId = Long.MIN_VALUE;
|
||||
for (Neuron n : net) {
|
||||
final long currentId = n.getIdentifier();
|
||||
if (currentId < previousId) {
|
||||
isUnspecifiedOrder = true;
|
||||
break;
|
||||
}
|
||||
previousId = currentId;
|
||||
}
|
||||
Assert.assertTrue(isUnspecifiedOrder);
|
||||
|
||||
// Check that the comparator provides a specific order.
|
||||
isUnspecifiedOrder = false;
|
||||
previousId = Long.MIN_VALUE;
|
||||
for (Neuron n : net.getNeurons(new Network.NeuronIdentifierComparator())) {
|
||||
final long currentId = n.getIdentifier();
|
||||
if (currentId < previousId) {
|
||||
isUnspecifiedOrder = true;
|
||||
break;
|
||||
}
|
||||
previousId = currentId;
|
||||
}
|
||||
Assert.assertFalse(isUnspecifiedOrder);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSerialize()
|
||||
throws IOException,
|
||||
ClassNotFoundException {
|
||||
final FeatureInitializer[] initArray = { init };
|
||||
final Network out = new NeuronSquareMesh2D(4, false,
|
||||
3, true,
|
||||
SquareNeighbourhood.VON_NEUMANN,
|
||||
initArray).getNetwork();
|
||||
|
||||
final ByteArrayOutputStream bos = new ByteArrayOutputStream();
|
||||
final ObjectOutputStream oos = new ObjectOutputStream(bos);
|
||||
oos.writeObject(out);
|
||||
|
||||
final ByteArrayInputStream bis = new ByteArrayInputStream(bos.toByteArray());
|
||||
final ObjectInputStream ois = new ObjectInputStream(bis);
|
||||
final Network in = (Network) ois.readObject();
|
||||
|
||||
for (Neuron nOut : out) {
|
||||
final Neuron nIn = in.getNeuron(nOut.getIdentifier());
|
||||
|
||||
// Same values.
|
||||
final double[] outF = nOut.getFeatures();
|
||||
final double[] inF = nIn.getFeatures();
|
||||
Assert.assertEquals(outF.length, inF.length);
|
||||
for (int i = 0; i < outF.length; i++) {
|
||||
Assert.assertEquals(outF[i], inF[i], 0d);
|
||||
}
|
||||
|
||||
// Same neighbours.
|
||||
final Collection<Neuron> outNeighbours = out.getNeighbours(nOut);
|
||||
final Collection<Neuron> inNeighbours = in.getNeighbours(nIn);
|
||||
Assert.assertEquals(outNeighbours.size(), inNeighbours.size());
|
||||
for (Neuron oN : outNeighbours) {
|
||||
Assert.assertTrue(inNeighbours.contains(in.getNeuron(oN.getIdentifier())));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,112 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.ml.neuralnet;
|
||||
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.ObjectOutputStream;
|
||||
import java.io.ObjectInputStream;
|
||||
import java.io.IOException;
|
||||
import org.junit.Test;
|
||||
import org.junit.Assert;
|
||||
|
||||
/**
|
||||
* Tests for {@link Neuron}.
|
||||
*/
|
||||
public class NeuronTest {
|
||||
@Test
|
||||
public void testGetIdentifier() {
|
||||
final long id = 1234567;
|
||||
final Neuron n = new Neuron(id, new double[] { 0 });
|
||||
|
||||
Assert.assertEquals(id, n.getIdentifier());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGetSize() {
|
||||
final double[] features = { -1, -1e-97, 0, 23.456, 9.01e203 } ;
|
||||
final Neuron n = new Neuron(1, features);
|
||||
Assert.assertEquals(features.length, n.getSize());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGetFeatures() {
|
||||
final double[] features = { -1, -1e-97, 0, 23.456, 9.01e203 } ;
|
||||
final Neuron n = new Neuron(1, features);
|
||||
|
||||
final double[] f = n.getFeatures();
|
||||
// Accessor returns a copy.
|
||||
Assert.assertFalse(f == features);
|
||||
|
||||
// Values are the same.
|
||||
Assert.assertEquals(features.length, f.length);
|
||||
for (int i = 0; i < features.length; i++) {
|
||||
Assert.assertEquals(features[i], f[i], 0d);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCompareAndSetFeatures() {
|
||||
final Neuron n = new Neuron(1, new double[] { 0 });
|
||||
double[] expect = n.getFeatures();
|
||||
double[] update = new double[] { expect[0] + 1.23 };
|
||||
|
||||
// Test "success".
|
||||
boolean ok = n.compareAndSetFeatures(expect, update);
|
||||
// Check that the update is reported as successful.
|
||||
Assert.assertTrue(ok);
|
||||
// Check that the new value is correct.
|
||||
Assert.assertEquals(update[0], n.getFeatures()[0], 0d);
|
||||
|
||||
// Test "failure".
|
||||
double[] update1 = new double[] { update[0] + 4.56 };
|
||||
// Must return "false" because the neuron has been
|
||||
// updated: a new update can only succeed if "expect"
|
||||
// is set to the new features.
|
||||
ok = n.compareAndSetFeatures(expect, update1);
|
||||
// Check that the update is reported as failed.
|
||||
Assert.assertFalse(ok);
|
||||
// Check that the value was not changed.
|
||||
Assert.assertEquals(update[0], n.getFeatures()[0], 0d);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSerialize()
|
||||
throws IOException,
|
||||
ClassNotFoundException {
|
||||
final Neuron out = new Neuron(123, new double[] { -98.76, -1, 0, 1e-23, 543.21, 1e234 });
|
||||
final ByteArrayOutputStream bos = new ByteArrayOutputStream();
|
||||
final ObjectOutputStream oos = new ObjectOutputStream(bos);
|
||||
oos.writeObject(out);
|
||||
|
||||
final ByteArrayInputStream bis = new ByteArrayInputStream(bos.toByteArray());
|
||||
final ObjectInputStream ois = new ObjectInputStream(bis);
|
||||
final Neuron in = (Neuron) ois.readObject();
|
||||
|
||||
// Same identifier.
|
||||
Assert.assertEquals(out.getIdentifier(),
|
||||
in.getIdentifier());
|
||||
// Same values.
|
||||
final double[] outF = out.getFeatures();
|
||||
final double[] inF = in.getFeatures();
|
||||
Assert.assertEquals(outF.length, inF.length);
|
||||
for (int i = 0; i < outF.length; i++) {
|
||||
Assert.assertEquals(outF[i], inF[i], 0d);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,51 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.ml.neuralnet;
|
||||
|
||||
import org.junit.Test;
|
||||
import org.junit.Assert;
|
||||
import org.apache.commons.math3.random.RandomGenerator;
|
||||
import org.apache.commons.math3.random.Well44497b;
|
||||
|
||||
/**
|
||||
* Wraps a given initializer.
|
||||
*/
|
||||
public class OffsetFeatureInitializer
|
||||
implements FeatureInitializer {
|
||||
/** Wrapped initializer. */
|
||||
private final FeatureInitializer orig;
|
||||
/** Offset. */
|
||||
private int inc = 0;
|
||||
|
||||
/**
|
||||
* Creates a new initializer whose {@link #value()} method
|
||||
* will return {@code orig.value() + offset}, where
|
||||
* {@code offset} is automatically incremented by one at
|
||||
* each call.
|
||||
*
|
||||
* @param orig Original initializer.
|
||||
*/
|
||||
public OffsetFeatureInitializer(FeatureInitializer orig) {
|
||||
this.orig = orig;
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
public double value() {
|
||||
return orig.value() + inc++;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,187 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.ml.neuralnet.oned;
|
||||
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.ObjectOutputStream;
|
||||
import java.io.ObjectInputStream;
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.Collection;
|
||||
import org.junit.Test;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Ignore;
|
||||
import org.apache.commons.math3.exception.NumberIsTooSmallException;
|
||||
import org.apache.commons.math3.ml.neuralnet.FeatureInitializer;
|
||||
import org.apache.commons.math3.ml.neuralnet.FeatureInitializerFactory;
|
||||
import org.apache.commons.math3.ml.neuralnet.Network;
|
||||
import org.apache.commons.math3.ml.neuralnet.Neuron;
|
||||
import org.apache.commons.math3.random.Well44497b;
|
||||
|
||||
/**
|
||||
* Tests for {@link NeuronString} and {@link Network} functionality for
|
||||
* a one-dimensional network.
|
||||
*/
|
||||
public class NeuronStringTest {
|
||||
final FeatureInitializer init = FeatureInitializerFactory.uniform(0, 2);
|
||||
|
||||
/*
|
||||
* Test assumes that the network is
|
||||
*
|
||||
* 0-----1-----2-----3
|
||||
*/
|
||||
@Test
|
||||
public void testSegmentNetwork() {
|
||||
final FeatureInitializer[] initArray = { init };
|
||||
final Network net = new NeuronString(4, false, initArray).getNetwork();
|
||||
|
||||
Collection<Neuron> neighbours;
|
||||
|
||||
// Neuron 0.
|
||||
neighbours = net.getNeighbours(net.getNeuron(0));
|
||||
for (long nId : new long[] { 1 }) {
|
||||
Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
|
||||
}
|
||||
// Ensures that no other neurons is in the neihbourhood set.
|
||||
Assert.assertEquals(1, neighbours.size());
|
||||
|
||||
// Neuron 1.
|
||||
neighbours = net.getNeighbours(net.getNeuron(1));
|
||||
for (long nId : new long[] { 0, 2 }) {
|
||||
Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
|
||||
}
|
||||
// Ensures that no other neurons is in the neihbourhood set.
|
||||
Assert.assertEquals(2, neighbours.size());
|
||||
|
||||
// Neuron 2.
|
||||
neighbours = net.getNeighbours(net.getNeuron(2));
|
||||
for (long nId : new long[] { 1, 3 }) {
|
||||
Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
|
||||
}
|
||||
// Ensures that no other neurons is in the neihbourhood set.
|
||||
Assert.assertEquals(2, neighbours.size());
|
||||
|
||||
// Neuron 3.
|
||||
neighbours = net.getNeighbours(net.getNeuron(3));
|
||||
for (long nId : new long[] { 2 }) {
|
||||
Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
|
||||
}
|
||||
// Ensures that no other neurons is in the neihbourhood set.
|
||||
Assert.assertEquals(1, neighbours.size());
|
||||
}
|
||||
|
||||
/*
|
||||
* Test assumes that the network is
|
||||
*
|
||||
* 0-----1-----2-----3
|
||||
*/
|
||||
@Test
|
||||
public void testCircleNetwork() {
|
||||
final FeatureInitializer[] initArray = { init };
|
||||
final Network net = new NeuronString(4, true, initArray).getNetwork();
|
||||
|
||||
Collection<Neuron> neighbours;
|
||||
|
||||
// Neuron 0.
|
||||
neighbours = net.getNeighbours(net.getNeuron(0));
|
||||
for (long nId : new long[] { 1, 3 }) {
|
||||
Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
|
||||
}
|
||||
// Ensures that no other neurons is in the neihbourhood set.
|
||||
Assert.assertEquals(2, neighbours.size());
|
||||
|
||||
// Neuron 1.
|
||||
neighbours = net.getNeighbours(net.getNeuron(1));
|
||||
for (long nId : new long[] { 0, 2 }) {
|
||||
Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
|
||||
}
|
||||
// Ensures that no other neurons is in the neihbourhood set.
|
||||
Assert.assertEquals(2, neighbours.size());
|
||||
|
||||
// Neuron 2.
|
||||
neighbours = net.getNeighbours(net.getNeuron(2));
|
||||
for (long nId : new long[] { 1, 3 }) {
|
||||
Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
|
||||
}
|
||||
// Ensures that no other neurons is in the neihbourhood set.
|
||||
Assert.assertEquals(2, neighbours.size());
|
||||
|
||||
// Neuron 3.
|
||||
neighbours = net.getNeighbours(net.getNeuron(3));
|
||||
for (long nId : new long[] { 0, 2 }) {
|
||||
Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
|
||||
}
|
||||
// Ensures that no other neurons is in the neihbourhood set.
|
||||
Assert.assertEquals(2, neighbours.size());
|
||||
}
|
||||
|
||||
/*
|
||||
* Test assumes that the network is
|
||||
*
|
||||
* 0-----1-----2-----3-----4
|
||||
*/
|
||||
@Test
|
||||
public void testGetNeighboursWithExclude() {
|
||||
final FeatureInitializer[] initArray = { init };
|
||||
final Network net = new NeuronString(5, true, initArray).getNetwork();
|
||||
final Collection<Neuron> exclude = new ArrayList<Neuron>();
|
||||
exclude.add(net.getNeuron(1));
|
||||
final Collection<Neuron> neighbours = net.getNeighbours(net.getNeuron(0),
|
||||
exclude);
|
||||
Assert.assertTrue(neighbours.contains(net.getNeuron(4)));
|
||||
Assert.assertEquals(1, neighbours.size());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSerialize()
|
||||
throws IOException,
|
||||
ClassNotFoundException {
|
||||
final FeatureInitializer[] initArray = { init };
|
||||
final NeuronString out = new NeuronString(4, false, initArray);
|
||||
|
||||
final ByteArrayOutputStream bos = new ByteArrayOutputStream();
|
||||
final ObjectOutputStream oos = new ObjectOutputStream(bos);
|
||||
oos.writeObject(out);
|
||||
|
||||
final ByteArrayInputStream bis = new ByteArrayInputStream(bos.toByteArray());
|
||||
final ObjectInputStream ois = new ObjectInputStream(bis);
|
||||
final NeuronString in = (NeuronString) ois.readObject();
|
||||
|
||||
for (Neuron nOut : out.getNetwork()) {
|
||||
final Neuron nIn = in.getNetwork().getNeuron(nOut.getIdentifier());
|
||||
|
||||
// Same values.
|
||||
final double[] outF = nOut.getFeatures();
|
||||
final double[] inF = nIn.getFeatures();
|
||||
Assert.assertEquals(outF.length, inF.length);
|
||||
for (int i = 0; i < outF.length; i++) {
|
||||
Assert.assertEquals(outF[i], inF[i], 0d);
|
||||
}
|
||||
|
||||
// Same neighbours.
|
||||
final Collection<Neuron> outNeighbours = out.getNetwork().getNeighbours(nOut);
|
||||
final Collection<Neuron> inNeighbours = in.getNetwork().getNeighbours(nIn);
|
||||
Assert.assertEquals(outNeighbours.size(), inNeighbours.size());
|
||||
for (Neuron oN : outNeighbours) {
|
||||
Assert.assertTrue(inNeighbours.contains(in.getNetwork().getNeuron(oN.getIdentifier())));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,207 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.ml.neuralnet.sofm;
|
||||
|
||||
import java.util.Set;
|
||||
import java.util.HashSet;
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
import java.util.ArrayList;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Future;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
import java.io.PrintWriter;
|
||||
import java.io.FileNotFoundException;
|
||||
import java.io.IOException;
|
||||
import org.junit.Ignore;
|
||||
import org.junit.Test;
|
||||
import org.junit.Assert;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.apache.commons.math3.RetryRunner;
|
||||
import org.apache.commons.math3.Retry;
|
||||
import org.apache.commons.math3.util.FastMath;
|
||||
import org.apache.commons.math3.geometry.euclidean.threed.Vector3D;
|
||||
|
||||
/**
|
||||
* Tests for {@link KohonenTrainingTask}
|
||||
*/
|
||||
@RunWith(RetryRunner.class)
|
||||
public class KohonenTrainingTaskTest {
|
||||
@Test
|
||||
public void testTravellerSalesmanSquareTourSequentialSolver() {
|
||||
// Cities (in optimal travel order).
|
||||
final City[] squareOfCities = new City[] {
|
||||
new City("o0", 0, 0),
|
||||
new City("o1", 1, 0),
|
||||
new City("o2", 2, 0),
|
||||
new City("o3", 3, 0),
|
||||
new City("o4", 3, 1),
|
||||
new City("o5", 3, 2),
|
||||
new City("o6", 3, 3),
|
||||
new City("o7", 2, 3),
|
||||
new City("o8", 1, 3),
|
||||
new City("o9", 0, 3),
|
||||
new City("i3", 1, 2),
|
||||
new City("i2", 2, 2),
|
||||
new City("i1", 2, 1),
|
||||
new City("i0", 1, 1),
|
||||
};
|
||||
|
||||
final TravellingSalesmanSolver solver = new TravellingSalesmanSolver(squareOfCities, 2);
|
||||
// printSummary("before.travel.seq.dat", solver);
|
||||
solver.createSequentialTask(15000).run();
|
||||
// printSummary("after.travel.seq.dat", solver);
|
||||
final City[] result = solver.getCityList();
|
||||
Assert.assertEquals(squareOfCities.length,
|
||||
uniqueCities(result).size());
|
||||
final double ratio = computeTravelDistance(squareOfCities) / computeTravelDistance(result);
|
||||
Assert.assertEquals(1, ratio, 1e-1); // We do not require the optimal travel.
|
||||
}
|
||||
|
||||
// Test can sometimes fail: Run several times.
|
||||
@Test
|
||||
@Retry
|
||||
public void testTravellerSalesmanSquareTourParallelSolver() throws ExecutionException {
|
||||
// Cities (in optimal travel order).
|
||||
final City[] squareOfCities = new City[] {
|
||||
new City("o0", 0, 0),
|
||||
new City("o1", 1, 0),
|
||||
new City("o2", 2, 0),
|
||||
new City("o3", 3, 0),
|
||||
new City("o4", 3, 1),
|
||||
new City("o5", 3, 2),
|
||||
new City("o6", 3, 3),
|
||||
new City("o7", 2, 3),
|
||||
new City("o8", 1, 3),
|
||||
new City("o9", 0, 3),
|
||||
new City("i3", 1, 2),
|
||||
new City("i2", 2, 2),
|
||||
new City("i1", 2, 1),
|
||||
new City("i0", 1, 1),
|
||||
};
|
||||
|
||||
final TravellingSalesmanSolver solver = new TravellingSalesmanSolver(squareOfCities, 2);
|
||||
// printSummary("before.travel.par.dat", solver);
|
||||
|
||||
// Parallel execution.
|
||||
final ExecutorService service = Executors.newCachedThreadPool();
|
||||
final Runnable[] tasks = solver.createParallelTasks(3, 5000);
|
||||
final List<Future<?>> execOutput = new ArrayList<Future<?>>();
|
||||
// Run tasks.
|
||||
for (Runnable r : tasks) {
|
||||
execOutput.add(service.submit(r));
|
||||
}
|
||||
// Wait for completion (ignoring return value).
|
||||
try {
|
||||
for (Future<?> f : execOutput) {
|
||||
f.get();
|
||||
}
|
||||
} catch (InterruptedException ignored) {}
|
||||
// Terminate all threads.
|
||||
service.shutdown();
|
||||
|
||||
// printSummary("after.travel.par.dat", solver);
|
||||
final City[] result = solver.getCityList();
|
||||
Assert.assertEquals(squareOfCities.length,
|
||||
uniqueCities(result).size());
|
||||
final double ratio = computeTravelDistance(squareOfCities) / computeTravelDistance(result);
|
||||
Assert.assertEquals(1, ratio, 1e-1); // We do not require the optimal travel.
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a map of the travel suggested by the solver.
|
||||
*
|
||||
* @param solver Solver.
|
||||
* @return a 4-columns table: {@code <x (neuron)> <y (neuron)> <x (city)> <y (city)>}.
|
||||
*/
|
||||
private String travelCoordinatesTable(TravellingSalesmanSolver solver) {
|
||||
final StringBuilder s = new StringBuilder();
|
||||
for (double[] c : solver.getCoordinatesList()) {
|
||||
s.append(c[0]).append(" ").append(c[1]).append(" ");
|
||||
final City city = solver.getClosestCity(c[0], c[1]);
|
||||
final double[] cityCoord = city.getCoordinates();
|
||||
s.append(cityCoord[0]).append(" ").append(cityCoord[1]).append(" ");
|
||||
s.append(" # ").append(city.getName()).append("\n");
|
||||
}
|
||||
return s.toString();
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute the distance covered by the salesman, including
|
||||
* the trip back (from the last to first city).
|
||||
*
|
||||
* @param cityList List of cities visited during the travel.
|
||||
* @return the total distance.
|
||||
*/
|
||||
private Collection<City> uniqueCities(City[] cityList) {
|
||||
final Set<City> unique = new HashSet<City>();
|
||||
for (City c : cityList) {
|
||||
unique.add(c);
|
||||
}
|
||||
return unique;
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute the distance covered by the salesman, including
|
||||
* the trip back (from the last to first city).
|
||||
*
|
||||
* @param cityList List of cities visited during the travel.
|
||||
* @return the total distance.
|
||||
*/
|
||||
private double computeTravelDistance(City[] cityList) {
|
||||
double dist = 0;
|
||||
for (int i = 0; i < cityList.length; i++) {
|
||||
final double[] currentCoord = cityList[i].getCoordinates();
|
||||
final double[] nextCoord = cityList[(i + 1) % cityList.length].getCoordinates();
|
||||
|
||||
final double xDiff = currentCoord[0] - nextCoord[0];
|
||||
final double yDiff = currentCoord[1] - nextCoord[1];
|
||||
|
||||
dist += FastMath.sqrt(xDiff * xDiff + yDiff * yDiff);
|
||||
}
|
||||
|
||||
return dist;
|
||||
}
|
||||
|
||||
/**
|
||||
* Prints a summary of the current state of the solver to the
|
||||
* given filename.
|
||||
*
|
||||
* @param filename File.
|
||||
* @param solver Solver.
|
||||
*/
|
||||
private void printSummary(String filename,
|
||||
TravellingSalesmanSolver solver) {
|
||||
PrintWriter out = null;
|
||||
try {
|
||||
out = new PrintWriter(filename);
|
||||
out.println(travelCoordinatesTable(solver));
|
||||
|
||||
final City[] result = solver.getCityList();
|
||||
out.println("# Number of unique cities: " + uniqueCities(result).size());
|
||||
out.println("# Travel distance: " + computeTravelDistance(result));
|
||||
} catch (Exception e) {
|
||||
// Do nothing.
|
||||
} finally {
|
||||
if (out != null) {
|
||||
out.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,92 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.ml.neuralnet.sofm;
|
||||
|
||||
import org.apache.commons.math3.ml.neuralnet.Neuron;
|
||||
import org.apache.commons.math3.ml.neuralnet.Network;
|
||||
import org.apache.commons.math3.ml.neuralnet.MapUtils;
|
||||
import org.apache.commons.math3.ml.neuralnet.UpdateAction;
|
||||
import org.apache.commons.math3.ml.neuralnet.OffsetFeatureInitializer;
|
||||
import org.apache.commons.math3.ml.neuralnet.FeatureInitializer;
|
||||
import org.apache.commons.math3.ml.neuralnet.FeatureInitializerFactory;
|
||||
import org.apache.commons.math3.ml.distance.DistanceMeasure;
|
||||
import org.apache.commons.math3.ml.distance.EuclideanDistance;
|
||||
import org.apache.commons.math3.ml.neuralnet.oned.NeuronString;
|
||||
import org.junit.Test;
|
||||
import org.junit.Assert;
|
||||
|
||||
/**
|
||||
* Tests for {@link KohonenUpdateAction} class.
|
||||
*/
|
||||
public class KohonenUpdateActionTest {
|
||||
/*
|
||||
* Test assumes that the network is
|
||||
*
|
||||
* 0-----1-----2
|
||||
*/
|
||||
@Test
|
||||
public void testUpdate() {
|
||||
final FeatureInitializer init
|
||||
= new OffsetFeatureInitializer(FeatureInitializerFactory.uniform(0, 0.1));
|
||||
final FeatureInitializer[] initArray = { init };
|
||||
|
||||
final int netSize = 3;
|
||||
final Network net = new NeuronString(netSize, false, initArray).getNetwork();
|
||||
final DistanceMeasure dist = new EuclideanDistance();
|
||||
final LearningFactorFunction learning
|
||||
= LearningFactorFunctionFactory.exponentialDecay(1, 0.1, 100);
|
||||
final NeighbourhoodSizeFunction neighbourhood
|
||||
= NeighbourhoodSizeFunctionFactory.exponentialDecay(3, 1, 100);
|
||||
final UpdateAction update = new KohonenUpdateAction(dist, learning, neighbourhood);
|
||||
|
||||
// The following test ensures that, after one "update",
|
||||
// 1. when the initial learning rate equal to 1, the best matching
|
||||
// neuron's features are mapped to the input's features,
|
||||
// 2. when the initial neighbourhood is larger than the network's size,
|
||||
// all neuron's features get closer to the input's features.
|
||||
|
||||
final double[] features = new double[] { 0.3 };
|
||||
final double[] distancesBefore = new double[netSize];
|
||||
int count = 0;
|
||||
for (Neuron n : net) {
|
||||
distancesBefore[count++] = dist.compute(n.getFeatures(), features);
|
||||
}
|
||||
final Neuron bestBefore = MapUtils.findBest(features, net, dist);
|
||||
|
||||
// Initial distance from the best match is larger than zero.
|
||||
Assert.assertTrue(dist.compute(bestBefore.getFeatures(), features) >= 0.2 * 0.2);
|
||||
|
||||
update.update(net, features);
|
||||
|
||||
final double[] distancesAfter = new double[netSize];
|
||||
count = 0;
|
||||
for (Neuron n : net) {
|
||||
distancesAfter[count++] = dist.compute(n.getFeatures(), features);
|
||||
}
|
||||
final Neuron bestAfter = MapUtils.findBest(features, net, dist);
|
||||
|
||||
Assert.assertEquals(bestBefore, bestAfter);
|
||||
// Distance is now zero.
|
||||
Assert.assertEquals(0, dist.compute(bestAfter.getFeatures(), features), 0d);
|
||||
|
||||
for (int i = 0; i < netSize; i++) {
|
||||
// All distances have decreased.
|
||||
Assert.assertTrue(distancesAfter[i] < distancesBefore[i]);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,94 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.ml.neuralnet.sofm;
|
||||
|
||||
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
|
||||
import org.apache.commons.math3.exception.OutOfRangeException;
|
||||
import org.apache.commons.math3.exception.NumberIsTooLargeException;
|
||||
import org.junit.Test;
|
||||
import org.junit.Assert;
|
||||
|
||||
/**
|
||||
* Tests for {@link LearningFactorFunctionFactory} class.
|
||||
*/
|
||||
public class LearningFactorFunctionFactoryTest {
|
||||
@Test(expected=OutOfRangeException.class)
|
||||
public void testExponentialDecayPrecondition0() {
|
||||
LearningFactorFunctionFactory.exponentialDecay(0d, 0d, 2);
|
||||
}
|
||||
@Test(expected=OutOfRangeException.class)
|
||||
public void testExponentialDecayPrecondition1() {
|
||||
LearningFactorFunctionFactory.exponentialDecay(1 + 1e-10, 0d, 2);
|
||||
}
|
||||
@Test(expected=NotStrictlyPositiveException.class)
|
||||
public void testExponentialDecayPrecondition2() {
|
||||
LearningFactorFunctionFactory.exponentialDecay(1d, 0d, 2);
|
||||
}
|
||||
@Test(expected=NumberIsTooLargeException.class)
|
||||
public void testExponentialDecayPrecondition3() {
|
||||
LearningFactorFunctionFactory.exponentialDecay(1d, 1d, 100);
|
||||
}
|
||||
@Test(expected=NotStrictlyPositiveException.class)
|
||||
public void testExponentialDecayPrecondition4() {
|
||||
LearningFactorFunctionFactory.exponentialDecay(1d, 0.2, 0);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testExponentialDecayTrivial() {
|
||||
final int n = 65;
|
||||
final double init = 0.5;
|
||||
final double valueAtN = 0.1;
|
||||
final LearningFactorFunction f
|
||||
= LearningFactorFunctionFactory.exponentialDecay(init, valueAtN, n);
|
||||
|
||||
Assert.assertEquals(init, f.value(0), 0d);
|
||||
Assert.assertEquals(valueAtN, f.value(n), 0d);
|
||||
Assert.assertEquals(0, f.value(Long.MAX_VALUE), 0d);
|
||||
}
|
||||
|
||||
@Test(expected=OutOfRangeException.class)
|
||||
public void testQuasiSigmoidDecayPrecondition0() {
|
||||
LearningFactorFunctionFactory.quasiSigmoidDecay(0d, -1d, 2);
|
||||
}
|
||||
@Test(expected=OutOfRangeException.class)
|
||||
public void testQuasiSigmoidDecayPrecondition1() {
|
||||
LearningFactorFunctionFactory.quasiSigmoidDecay(1 + 1e-10, -1d, 2);
|
||||
}
|
||||
@Test(expected=NumberIsTooLargeException.class)
|
||||
public void testQuasiSigmoidDecayPrecondition3() {
|
||||
LearningFactorFunctionFactory.quasiSigmoidDecay(1d, 0d, 100);
|
||||
}
|
||||
@Test(expected=NotStrictlyPositiveException.class)
|
||||
public void testQuasiSigmoidDecayPrecondition4() {
|
||||
LearningFactorFunctionFactory.quasiSigmoidDecay(1d, -1d, 0);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testQuasiSigmoidDecayTrivial() {
|
||||
final int n = 65;
|
||||
final double init = 0.5;
|
||||
final double slope = -1e-1;
|
||||
final LearningFactorFunction f
|
||||
= LearningFactorFunctionFactory.quasiSigmoidDecay(init, slope, n);
|
||||
|
||||
Assert.assertEquals(init, f.value(0), 0d);
|
||||
// Very approximate derivative.
|
||||
Assert.assertEquals(slope, f.value(n) - f.value(n - 1), 1e-2);
|
||||
Assert.assertEquals(0, f.value(Long.MAX_VALUE), 0d);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,83 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.ml.neuralnet.sofm;
|
||||
|
||||
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
|
||||
import org.apache.commons.math3.exception.NumberIsTooLargeException;
|
||||
import org.junit.Test;
|
||||
import org.junit.Assert;
|
||||
|
||||
/**
|
||||
* Tests for {@link NeighbourhoodSizeFunctionFactory} class.
|
||||
*/
|
||||
public class NeighbourhoodSizeFunctionFactoryTest {
|
||||
@Test(expected=NotStrictlyPositiveException.class)
|
||||
public void testExponentialDecayPrecondition1() {
|
||||
NeighbourhoodSizeFunctionFactory.exponentialDecay(0, 0, 2);
|
||||
}
|
||||
@Test(expected=NotStrictlyPositiveException.class)
|
||||
public void testExponentialDecayPrecondition2() {
|
||||
NeighbourhoodSizeFunctionFactory.exponentialDecay(1, 0, 2);
|
||||
}
|
||||
@Test(expected=NumberIsTooLargeException.class)
|
||||
public void testExponentialDecayPrecondition3() {
|
||||
NeighbourhoodSizeFunctionFactory.exponentialDecay(1, 1, 100);
|
||||
}
|
||||
@Test(expected=NotStrictlyPositiveException.class)
|
||||
public void testExponentialDecayPrecondition4() {
|
||||
NeighbourhoodSizeFunctionFactory.exponentialDecay(2, 1, 0);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testExponentialDecayTrivial() {
|
||||
final int n = 65;
|
||||
final int init = 4;
|
||||
final int valueAtN = 3;
|
||||
final NeighbourhoodSizeFunction f
|
||||
= NeighbourhoodSizeFunctionFactory.exponentialDecay(init, valueAtN, n);
|
||||
|
||||
Assert.assertEquals(init, f.value(0));
|
||||
Assert.assertEquals(valueAtN, f.value(n));
|
||||
Assert.assertEquals(0, f.value(Long.MAX_VALUE));
|
||||
}
|
||||
|
||||
@Test(expected=NotStrictlyPositiveException.class)
|
||||
public void testQuasiSigmoidDecayPrecondition1() {
|
||||
NeighbourhoodSizeFunctionFactory.quasiSigmoidDecay(0d, -1d, 2);
|
||||
}
|
||||
@Test(expected=NumberIsTooLargeException.class)
|
||||
public void testQuasiSigmoidDecayPrecondition3() {
|
||||
NeighbourhoodSizeFunctionFactory.quasiSigmoidDecay(1d, 0d, 100);
|
||||
}
|
||||
@Test(expected=NotStrictlyPositiveException.class)
|
||||
public void testQuasiSigmoidDecayPrecondition4() {
|
||||
NeighbourhoodSizeFunctionFactory.quasiSigmoidDecay(1d, -1d, 0);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testQuasiSigmoidDecayTrivial() {
|
||||
final int n = 65;
|
||||
final double init = 4;
|
||||
final double slope = -1e-1;
|
||||
final NeighbourhoodSizeFunction f
|
||||
= NeighbourhoodSizeFunctionFactory.quasiSigmoidDecay(init, slope, n);
|
||||
|
||||
Assert.assertEquals(init, f.value(0), 0d);
|
||||
Assert.assertEquals(0, f.value(Long.MAX_VALUE), 0d);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,380 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.ml.neuralnet.sofm;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Set;
|
||||
import java.util.HashSet;
|
||||
import java.util.Collection;
|
||||
import java.util.Iterator;
|
||||
import org.apache.commons.math3.ml.neuralnet.Neuron;
|
||||
import org.apache.commons.math3.ml.neuralnet.Network;
|
||||
import org.apache.commons.math3.ml.neuralnet.FeatureInitializer;
|
||||
import org.apache.commons.math3.ml.neuralnet.FeatureInitializerFactory;
|
||||
import org.apache.commons.math3.ml.distance.DistanceMeasure;
|
||||
import org.apache.commons.math3.ml.distance.EuclideanDistance;
|
||||
import org.apache.commons.math3.ml.neuralnet.oned.NeuronString;
|
||||
import org.apache.commons.math3.random.RandomGenerator;
|
||||
import org.apache.commons.math3.random.Well44497b;
|
||||
import org.apache.commons.math3.exception.MathUnsupportedOperationException;
|
||||
import org.apache.commons.math3.util.FastMath;
|
||||
import org.apache.commons.math3.analysis.UnivariateFunction;
|
||||
import org.apache.commons.math3.analysis.FunctionUtils;
|
||||
import org.apache.commons.math3.analysis.function.HarmonicOscillator;
|
||||
import org.apache.commons.math3.analysis.function.Constant;
|
||||
import org.apache.commons.math3.distribution.RealDistribution;
|
||||
import org.apache.commons.math3.distribution.UniformRealDistribution;
|
||||
|
||||
/**
|
||||
* Solves the "Travelling Salesman's Problem" (i.e. trying to find the
|
||||
* sequence of cities that minimizes the travel distance) using a 1D
|
||||
* SOFM.
|
||||
*/
|
||||
public class TravellingSalesmanSolver {
|
||||
private static final long FIRST_NEURON_ID = 0;
|
||||
/** RNG. */
|
||||
private final RandomGenerator random = new Well44497b();
|
||||
/** Set of cities. */
|
||||
private final Set<City> cities = new HashSet<City>();
|
||||
/** SOFM. */
|
||||
private final Network net;
|
||||
/** Distance function. */
|
||||
private final DistanceMeasure distance = new EuclideanDistance();
|
||||
/** Total number of neurons. */
|
||||
private final int numberOfNeurons;
|
||||
|
||||
/**
|
||||
* @param cityList List of cities to visit in a single travel.
|
||||
* @param numNeuronsPerCity Number of neurons per city.
|
||||
*/
|
||||
public TravellingSalesmanSolver(City[] cityList,
|
||||
double numNeuronsPerCity) {
|
||||
final double[] xRange = {Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY};
|
||||
final double[] yRange = {Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY};
|
||||
|
||||
// Make sure that each city will appear only once in the list.
|
||||
for (City city : cityList) {
|
||||
cities.add(city);
|
||||
}
|
||||
|
||||
// Total number of neurons.
|
||||
numberOfNeurons = (int) numNeuronsPerCity * cities.size();
|
||||
|
||||
// Create a network with circle topology.
|
||||
net = new NeuronString(numberOfNeurons, true, makeInitializers()).getNetwork();
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates training tasks.
|
||||
*
|
||||
* @param numTasks Number of tasks to create.
|
||||
* @param numSamplesPerTask Number of training samples per task.
|
||||
* @return the created tasks.
|
||||
*/
|
||||
public Runnable[] createParallelTasks(int numTasks,
|
||||
long numSamplesPerTask) {
|
||||
final Runnable[] tasks = new Runnable[numTasks];
|
||||
final LearningFactorFunction learning
|
||||
= LearningFactorFunctionFactory.exponentialDecay(2e-1,
|
||||
5e-2,
|
||||
numSamplesPerTask / 2);
|
||||
final NeighbourhoodSizeFunction neighbourhood
|
||||
= NeighbourhoodSizeFunctionFactory.exponentialDecay(0.5 * numberOfNeurons,
|
||||
0.1 * numberOfNeurons,
|
||||
numSamplesPerTask / 2);
|
||||
|
||||
for (int i = 0; i < numTasks; i++) {
|
||||
final KohonenUpdateAction action = new KohonenUpdateAction(distance,
|
||||
learning,
|
||||
neighbourhood);
|
||||
tasks[i] = new KohonenTrainingTask(net,
|
||||
createRandomIterator(numSamplesPerTask),
|
||||
action);
|
||||
}
|
||||
|
||||
return tasks;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a training task.
|
||||
*
|
||||
* @param numSamples Number of training samples.
|
||||
* @return the created task.
|
||||
*/
|
||||
public Runnable createSequentialTask(long numSamples) {
|
||||
return createParallelTasks(1, numSamples)[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an iterator that will present a series of city's coordinates in
|
||||
* a random order.
|
||||
*
|
||||
* @param numSamples Number of samples.
|
||||
* @return the iterator.
|
||||
*/
|
||||
private Iterator<double[]> createRandomIterator(final long numSamples) {
|
||||
final List<City> cityList = new ArrayList<City>();
|
||||
cityList.addAll(cities);
|
||||
|
||||
return new Iterator<double[]>() {
|
||||
/** Number of samples. */
|
||||
private long n = 0;
|
||||
/** {@inheritDoc} */
|
||||
public boolean hasNext() {
|
||||
return n < numSamples;
|
||||
}
|
||||
/** {@inheritDoc} */
|
||||
public double[] next() {
|
||||
++n;
|
||||
return cityList.get(random.nextInt(cityList.size())).getCoordinates();
|
||||
}
|
||||
/** {@inheritDoc} */
|
||||
public void remove() {
|
||||
throw new MathUnsupportedOperationException();
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* @return the list of linked neurons (i.e. the one-dimensional
|
||||
* SOFM).
|
||||
*/
|
||||
private List<Neuron> getNeuronList() {
|
||||
// Sequence of coordinates.
|
||||
final List<Neuron> list = new ArrayList<Neuron>();
|
||||
|
||||
// First neuron.
|
||||
Neuron current = net.getNeuron(FIRST_NEURON_ID);
|
||||
while (true) {
|
||||
list.add(current);
|
||||
final Collection<Neuron> neighbours
|
||||
= net.getNeighbours(current, list);
|
||||
|
||||
final Iterator<Neuron> iter = neighbours.iterator();
|
||||
if (!iter.hasNext()) {
|
||||
// All neurons have been visited.
|
||||
break;
|
||||
}
|
||||
|
||||
current = iter.next();
|
||||
}
|
||||
|
||||
return list;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return the list of features (coordinates) of linked neurons.
|
||||
*/
|
||||
public List<double[]> getCoordinatesList() {
|
||||
// Sequence of coordinates.
|
||||
final List<double[]> coordinatesList = new ArrayList<double[]>();
|
||||
|
||||
for (Neuron n : getNeuronList()) {
|
||||
coordinatesList.add(n.getFeatures());
|
||||
}
|
||||
|
||||
return coordinatesList;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the travel proposed by the solver.
|
||||
* Note: cities can be missing or duplicated.
|
||||
*
|
||||
* @return the list of cities in travel order.
|
||||
*/
|
||||
public City[] getCityList() {
|
||||
final List<double[]> coord = getCoordinatesList();
|
||||
final City[] cityList = new City[coord.size()];
|
||||
for (int i = 0; i < cityList.length; i++) {
|
||||
final double[] c = coord.get(i);
|
||||
cityList[i] = getClosestCity(c[0], c[1]);
|
||||
}
|
||||
return cityList;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param x x-coordinate.
|
||||
* @param y y-coordinate.
|
||||
* @return the city whose coordinates are closest to {@code (x, y)}.
|
||||
*/
|
||||
public City getClosestCity(double x,
|
||||
double y) {
|
||||
City closest = null;
|
||||
double min = Double.POSITIVE_INFINITY;
|
||||
for (City c : cities) {
|
||||
final double d = c.distance(x, y);
|
||||
if (d < min) {
|
||||
min = d;
|
||||
closest = c;
|
||||
}
|
||||
}
|
||||
return closest;
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the barycentre of all city locations.
|
||||
*
|
||||
* @param cities City list.
|
||||
* @return the barycentre.
|
||||
*/
|
||||
private static double[] barycentre(Set<City> cities) {
|
||||
double xB = 0;
|
||||
double yB = 0;
|
||||
|
||||
int count = 0;
|
||||
for (City c : cities) {
|
||||
final double[] coord = c.getCoordinates();
|
||||
xB += coord[0];
|
||||
yB += coord[1];
|
||||
|
||||
++count;
|
||||
}
|
||||
|
||||
return new double[] { xB / count, yB / count };
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the largest distance between the point at coordinates
|
||||
* {@code (x, y)} and any of the cities.
|
||||
*
|
||||
* @param x x-coodinate.
|
||||
* @param y y-coodinate.
|
||||
* @param cities City list.
|
||||
* @return the largest distance.
|
||||
*/
|
||||
private static double largestDistance(double x,
|
||||
double y,
|
||||
Set<City> cities) {
|
||||
double maxDist = 0;
|
||||
for (City c : cities) {
|
||||
final double dist = c.distance(x, y);
|
||||
if (dist > maxDist) {
|
||||
maxDist = dist;
|
||||
}
|
||||
}
|
||||
|
||||
return maxDist;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates the features' initializers: an approximate circle around the
|
||||
* barycentre of the cities.
|
||||
*
|
||||
* @return an array containing the two initializers.
|
||||
*/
|
||||
private FeatureInitializer[] makeInitializers() {
|
||||
// Barycentre.
|
||||
final double[] centre = barycentre(cities);
|
||||
// Largest distance from centre.
|
||||
final double radius = 0.5 * largestDistance(centre[0], centre[1], cities);
|
||||
|
||||
final double omega = 2 * Math.PI / numberOfNeurons;
|
||||
final UnivariateFunction h1 = new HarmonicOscillator(radius, omega, 0);
|
||||
final UnivariateFunction h2 = new HarmonicOscillator(radius, omega, 0.5 * Math.PI);
|
||||
|
||||
final UnivariateFunction f1 = FunctionUtils.add(h1, new Constant(centre[0]));
|
||||
final UnivariateFunction f2 = FunctionUtils.add(h2, new Constant(centre[1]));
|
||||
|
||||
final RealDistribution u = new UniformRealDistribution(-0.05 * radius, 0.05 * radius);
|
||||
|
||||
return new FeatureInitializer[] {
|
||||
FeatureInitializerFactory.randomize(u, FeatureInitializerFactory.function(f1, 0, 1)),
|
||||
FeatureInitializerFactory.randomize(u, FeatureInitializerFactory.function(f2, 0, 1))
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A city, represented by a name and two-dimensional coordinates.
|
||||
*/
|
||||
class City {
|
||||
/** Identifier. */
|
||||
final String name;
|
||||
/** x-coordinate. */
|
||||
final double x;
|
||||
/** y-coordinate. */
|
||||
final double y;
|
||||
|
||||
/**
|
||||
* @param name Name.
|
||||
* @param x Cartesian x-coordinate.
|
||||
* @param y Cartesian y-coordinate.
|
||||
*/
|
||||
public City(String name,
|
||||
double x,
|
||||
double y) {
|
||||
this.name = name;
|
||||
this.x = x;
|
||||
this.y = y;
|
||||
}
|
||||
|
||||
/**
|
||||
* @retun the name.
|
||||
*/
|
||||
public String getName() {
|
||||
return name;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return the (x, y) coordinates.
|
||||
*/
|
||||
public double[] getCoordinates() {
|
||||
return new double[] { x, y };
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the distance between this city and
|
||||
* the given point.
|
||||
*
|
||||
* @param x x-coodinate.
|
||||
* @param y y-coodinate.
|
||||
* @return the distance between {@code (x, y)} and this
|
||||
* city.
|
||||
*/
|
||||
public double distance(double x,
|
||||
double y) {
|
||||
final double xDiff = this.x - x;
|
||||
final double yDiff = this.y - y;
|
||||
|
||||
return FastMath.sqrt(xDiff * xDiff + yDiff * yDiff);
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
public boolean equals(Object o) {
|
||||
if (o instanceof City) {
|
||||
final City other = (City) o;
|
||||
return x == other.x &&
|
||||
y == other.y;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
public int hashCode() {
|
||||
int result = 17;
|
||||
|
||||
final long c1 = Double.doubleToLongBits(x);
|
||||
result = 31 * result + (int) (c1 ^ (c1 >>> 32));
|
||||
|
||||
final long c2 = Double.doubleToLongBits(y);
|
||||
result = 31 * result + (int) (c2 ^ (c2 >>> 32));
|
||||
|
||||
return result;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,57 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.ml.neuralnet.sofm.util;
|
||||
|
||||
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
|
||||
import org.apache.commons.math3.exception.NumberIsTooLargeException;
|
||||
import org.junit.Test;
|
||||
import org.junit.Assert;
|
||||
|
||||
/**
|
||||
* Tests for {@link ExponentialDecayFunction} class
|
||||
*/
|
||||
public class ExponentialDecayFunctionTest {
|
||||
@Test(expected=NotStrictlyPositiveException.class)
|
||||
public void testPrecondition1() {
|
||||
new ExponentialDecayFunction(0d, 0d, 2);
|
||||
}
|
||||
@Test(expected=NotStrictlyPositiveException.class)
|
||||
public void testPrecondition2() {
|
||||
new ExponentialDecayFunction(1d, 0d, 2);
|
||||
}
|
||||
@Test(expected=NumberIsTooLargeException.class)
|
||||
public void testPrecondition3() {
|
||||
new ExponentialDecayFunction(1d, 1d, 100);
|
||||
}
|
||||
@Test(expected=NotStrictlyPositiveException.class)
|
||||
public void testPrecondition4() {
|
||||
new ExponentialDecayFunction(1d, 0.2, 0);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTrivial() {
|
||||
final int n = 65;
|
||||
final double init = 4;
|
||||
final double valueAtN = 3;
|
||||
final ExponentialDecayFunction f = new ExponentialDecayFunction(init, valueAtN, n);
|
||||
|
||||
Assert.assertEquals(init, f.value(0), 0d);
|
||||
Assert.assertEquals(valueAtN, f.value(n), 0d);
|
||||
Assert.assertEquals(0, f.value(Long.MAX_VALUE), 0d);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,54 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.ml.neuralnet.sofm.util;
|
||||
|
||||
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
|
||||
import org.apache.commons.math3.exception.NumberIsTooLargeException;
|
||||
import org.junit.Test;
|
||||
import org.junit.Assert;
|
||||
|
||||
/**
|
||||
* Tests for {@link QuasiSigmoidDecayFunction} class
|
||||
*/
|
||||
public class QuasiSigmoidDecayFunctionTest {
|
||||
@Test(expected=NotStrictlyPositiveException.class)
|
||||
public void testPrecondition1() {
|
||||
new QuasiSigmoidDecayFunction(0d, -1d, 2);
|
||||
}
|
||||
@Test(expected=NumberIsTooLargeException.class)
|
||||
public void testPrecondition3() {
|
||||
new QuasiSigmoidDecayFunction(1d, 0d, 100);
|
||||
}
|
||||
@Test(expected=NotStrictlyPositiveException.class)
|
||||
public void testPrecondition4() {
|
||||
new QuasiSigmoidDecayFunction(1d, -1d, 0);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTrivial() {
|
||||
final int n = 65;
|
||||
final double init = 4;
|
||||
final double slope = -1e-1;
|
||||
final QuasiSigmoidDecayFunction f = new QuasiSigmoidDecayFunction(init, slope, n);
|
||||
|
||||
Assert.assertEquals(init, f.value(0), 0d);
|
||||
// Very approximate derivative.
|
||||
Assert.assertEquals(slope, f.value(n + 1) - f.value(n), 1e-4);
|
||||
Assert.assertEquals(0, f.value(Long.MAX_VALUE), 0d);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,685 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.ml.neuralnet.twod;
|
||||
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.ObjectOutputStream;
|
||||
import java.io.ObjectInputStream;
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.Collection;
|
||||
import org.junit.Test;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Ignore;
|
||||
import org.apache.commons.math3.ml.neuralnet.FeatureInitializer;
|
||||
import org.apache.commons.math3.ml.neuralnet.FeatureInitializerFactory;
|
||||
import org.apache.commons.math3.ml.neuralnet.Network;
|
||||
import org.apache.commons.math3.ml.neuralnet.Neuron;
|
||||
import org.apache.commons.math3.ml.neuralnet.SquareNeighbourhood;
|
||||
import org.apache.commons.math3.exception.NumberIsTooSmallException;
|
||||
|
||||
/**
|
||||
* Tests for {@link NeuronSquareMesh2D} and {@link Network} functionality for
|
||||
* a two-dimensional network.
|
||||
*/
|
||||
public class NeuronSquareMesh2DTest {
|
||||
final FeatureInitializer init = FeatureInitializerFactory.uniform(0, 2);
|
||||
|
||||
@Test(expected=NumberIsTooSmallException.class)
|
||||
public void testMinimalNetworkSize1() {
|
||||
final FeatureInitializer[] initArray = { init };
|
||||
|
||||
new NeuronSquareMesh2D(1, false,
|
||||
2, false,
|
||||
SquareNeighbourhood.VON_NEUMANN,
|
||||
initArray);
|
||||
}
|
||||
|
||||
@Test(expected=NumberIsTooSmallException.class)
|
||||
public void testMinimalNetworkSize2() {
|
||||
final FeatureInitializer[] initArray = { init };
|
||||
|
||||
new NeuronSquareMesh2D(2, false,
|
||||
0, false,
|
||||
SquareNeighbourhood.VON_NEUMANN,
|
||||
initArray);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGetFeaturesSize() {
|
||||
final FeatureInitializer[] initArray = { init, init, init };
|
||||
|
||||
final Network net = new NeuronSquareMesh2D(2, false,
|
||||
2, false,
|
||||
SquareNeighbourhood.VON_NEUMANN,
|
||||
initArray).getNetwork();
|
||||
Assert.assertEquals(3, net.getFeaturesSize());
|
||||
}
|
||||
|
||||
|
||||
/*
|
||||
* Test assumes that the network is
|
||||
*
|
||||
* 0-----1
|
||||
* | |
|
||||
* | |
|
||||
* 2-----3
|
||||
*/
|
||||
@Test
|
||||
public void test2x2Network() {
|
||||
final FeatureInitializer[] initArray = { init };
|
||||
final Network net = new NeuronSquareMesh2D(2, false,
|
||||
2, false,
|
||||
SquareNeighbourhood.VON_NEUMANN,
|
||||
initArray).getNetwork();
|
||||
Collection<Neuron> neighbours;
|
||||
|
||||
// Neurons 0 and 3.
|
||||
for (long id : new long[] { 0, 3 }) {
|
||||
neighbours = net.getNeighbours(net.getNeuron(id));
|
||||
for (long nId : new long[] { 1, 2 }) {
|
||||
Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
|
||||
}
|
||||
// Ensures that no other neurons is in the neihbourhood set.
|
||||
Assert.assertEquals(2, neighbours.size());
|
||||
}
|
||||
|
||||
// Neurons 1 and 2.
|
||||
for (long id : new long[] { 1, 2 }) {
|
||||
neighbours = net.getNeighbours(net.getNeuron(id));
|
||||
for (long nId : new long[] { 0, 3 }) {
|
||||
Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
|
||||
}
|
||||
// Ensures that no other neurons is in the neihbourhood set.
|
||||
Assert.assertEquals(2, neighbours.size());
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Test assumes that the network is
|
||||
*
|
||||
* 0-----1
|
||||
* | |
|
||||
* | |
|
||||
* 2-----3
|
||||
*/
|
||||
@Test
|
||||
public void test2x2Network2() {
|
||||
final FeatureInitializer[] initArray = { init };
|
||||
final Network net = new NeuronSquareMesh2D(2, false,
|
||||
2, false,
|
||||
SquareNeighbourhood.MOORE,
|
||||
initArray).getNetwork();
|
||||
Collection<Neuron> neighbours;
|
||||
|
||||
// All neurons
|
||||
for (long id : new long[] { 0, 1, 2, 3 }) {
|
||||
neighbours = net.getNeighbours(net.getNeuron(id));
|
||||
for (long nId : new long[] { 0, 1, 2, 3 }) {
|
||||
if (id != nId) {
|
||||
Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Test assumes that the network is
|
||||
*
|
||||
* 0-----1-----2
|
||||
* | | |
|
||||
* | | |
|
||||
* 3-----4-----5
|
||||
*/
|
||||
@Test
|
||||
public void test3x2CylinderNetwork() {
|
||||
final FeatureInitializer[] initArray = { init };
|
||||
final Network net = new NeuronSquareMesh2D(2, false,
|
||||
3, true,
|
||||
SquareNeighbourhood.VON_NEUMANN,
|
||||
initArray).getNetwork();
|
||||
Collection<Neuron> neighbours;
|
||||
|
||||
// Neuron 0.
|
||||
neighbours = net.getNeighbours(net.getNeuron(0));
|
||||
for (long nId : new long[] { 1, 2, 3 }) {
|
||||
Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
|
||||
}
|
||||
// Ensures that no other neurons is in the neihbourhood set.
|
||||
Assert.assertEquals(3, neighbours.size());
|
||||
|
||||
// Neuron 1.
|
||||
neighbours = net.getNeighbours(net.getNeuron(1));
|
||||
for (long nId : new long[] { 0, 2, 4 }) {
|
||||
Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
|
||||
}
|
||||
// Ensures that no other neurons is in the neihbourhood set.
|
||||
Assert.assertEquals(3, neighbours.size());
|
||||
|
||||
// Neuron 2.
|
||||
neighbours = net.getNeighbours(net.getNeuron(2));
|
||||
for (long nId : new long[] { 0, 1, 5 }) {
|
||||
Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
|
||||
}
|
||||
// Ensures that no other neurons is in the neihbourhood set.
|
||||
Assert.assertEquals(3, neighbours.size());
|
||||
|
||||
// Neuron 3.
|
||||
neighbours = net.getNeighbours(net.getNeuron(3));
|
||||
for (long nId : new long[] { 0, 4, 5 }) {
|
||||
Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
|
||||
}
|
||||
// Ensures that no other neurons is in the neihbourhood set.
|
||||
Assert.assertEquals(3, neighbours.size());
|
||||
|
||||
// Neuron 4.
|
||||
neighbours = net.getNeighbours(net.getNeuron(4));
|
||||
for (long nId : new long[] { 1, 3, 5 }) {
|
||||
Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
|
||||
}
|
||||
// Ensures that no other neurons is in the neihbourhood set.
|
||||
Assert.assertEquals(3, neighbours.size());
|
||||
|
||||
// Neuron 5.
|
||||
neighbours = net.getNeighbours(net.getNeuron(5));
|
||||
for (long nId : new long[] { 2, 3, 4 }) {
|
||||
Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
|
||||
}
|
||||
// Ensures that no other neurons is in the neihbourhood set.
|
||||
Assert.assertEquals(3, neighbours.size());
|
||||
}
|
||||
|
||||
/*
|
||||
* Test assumes that the network is
|
||||
*
|
||||
* 0-----1-----2
|
||||
* | | |
|
||||
* | | |
|
||||
* 3-----4-----5
|
||||
*/
|
||||
@Test
|
||||
public void test3x2CylinderNetwork2() {
|
||||
final FeatureInitializer[] initArray = { init };
|
||||
final Network net = new NeuronSquareMesh2D(2, false,
|
||||
3, true,
|
||||
SquareNeighbourhood.MOORE,
|
||||
initArray).getNetwork();
|
||||
Collection<Neuron> neighbours;
|
||||
|
||||
// All neurons.
|
||||
for (long id : new long[] { 0, 1, 2, 3, 4, 5 }) {
|
||||
neighbours = net.getNeighbours(net.getNeuron(id));
|
||||
for (long nId : new long[] { 0, 1, 2, 3, 4, 5 }) {
|
||||
if (id != nId) {
|
||||
Assert.assertTrue("id=" + id + " nId=" + nId,
|
||||
neighbours.contains(net.getNeuron(nId)));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Test assumes that the network is
|
||||
*
|
||||
* 0-----1-----2
|
||||
* | | |
|
||||
* | | |
|
||||
* 3-----4-----5
|
||||
* | | |
|
||||
* | | |
|
||||
* 6-----7-----8
|
||||
*/
|
||||
@Test
|
||||
public void test3x3TorusNetwork() {
|
||||
final FeatureInitializer[] initArray = { init };
|
||||
final Network net = new NeuronSquareMesh2D(3, true,
|
||||
3, true,
|
||||
SquareNeighbourhood.VON_NEUMANN,
|
||||
initArray).getNetwork();
|
||||
Collection<Neuron> neighbours;
|
||||
|
||||
// Neuron 0.
|
||||
neighbours = net.getNeighbours(net.getNeuron(0));
|
||||
for (long nId : new long[] { 1, 2, 3, 6 }) {
|
||||
Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
|
||||
}
|
||||
// Ensures that no other neurons is in the neihbourhood set.
|
||||
Assert.assertEquals(4, neighbours.size());
|
||||
|
||||
// Neuron 1.
|
||||
neighbours = net.getNeighbours(net.getNeuron(1));
|
||||
for (long nId : new long[] { 0, 2, 4, 7 }) {
|
||||
Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
|
||||
}
|
||||
// Ensures that no other neurons is in the neihbourhood set.
|
||||
Assert.assertEquals(4, neighbours.size());
|
||||
|
||||
// Neuron 2.
|
||||
neighbours = net.getNeighbours(net.getNeuron(2));
|
||||
for (long nId : new long[] { 0, 1, 5, 8 }) {
|
||||
Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
|
||||
}
|
||||
// Ensures that no other neurons is in the neihbourhood set.
|
||||
Assert.assertEquals(4, neighbours.size());
|
||||
|
||||
// Neuron 3.
|
||||
neighbours = net.getNeighbours(net.getNeuron(3));
|
||||
for (long nId : new long[] { 0, 4, 5, 6 }) {
|
||||
Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
|
||||
}
|
||||
// Ensures that no other neurons is in the neihbourhood set.
|
||||
Assert.assertEquals(4, neighbours.size());
|
||||
|
||||
// Neuron 4.
|
||||
neighbours = net.getNeighbours(net.getNeuron(4));
|
||||
for (long nId : new long[] { 1, 3, 5, 7 }) {
|
||||
Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
|
||||
}
|
||||
// Ensures that no other neurons is in the neihbourhood set.
|
||||
Assert.assertEquals(4, neighbours.size());
|
||||
|
||||
// Neuron 5.
|
||||
neighbours = net.getNeighbours(net.getNeuron(5));
|
||||
for (long nId : new long[] { 2, 3, 4, 8 }) {
|
||||
Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
|
||||
}
|
||||
// Ensures that no other neurons is in the neihbourhood set.
|
||||
Assert.assertEquals(4, neighbours.size());
|
||||
|
||||
// Neuron 6.
|
||||
neighbours = net.getNeighbours(net.getNeuron(6));
|
||||
for (long nId : new long[] { 0, 3, 7, 8 }) {
|
||||
Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
|
||||
}
|
||||
// Ensures that no other neurons is in the neihbourhood set.
|
||||
Assert.assertEquals(4, neighbours.size());
|
||||
|
||||
// Neuron 7.
|
||||
neighbours = net.getNeighbours(net.getNeuron(7));
|
||||
for (long nId : new long[] { 1, 4, 6, 8 }) {
|
||||
Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
|
||||
}
|
||||
// Ensures that no other neurons is in the neihbourhood set.
|
||||
Assert.assertEquals(4, neighbours.size());
|
||||
|
||||
// Neuron 8.
|
||||
neighbours = net.getNeighbours(net.getNeuron(8));
|
||||
for (long nId : new long[] { 2, 5, 6, 7 }) {
|
||||
Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
|
||||
}
|
||||
// Ensures that no other neurons is in the neihbourhood set.
|
||||
Assert.assertEquals(4, neighbours.size());
|
||||
}
|
||||
|
||||
/*
|
||||
* Test assumes that the network is
|
||||
*
|
||||
* 0-----1-----2
|
||||
* | | |
|
||||
* | | |
|
||||
* 3-----4-----5
|
||||
* | | |
|
||||
* | | |
|
||||
* 6-----7-----8
|
||||
*/
|
||||
@Test
|
||||
public void test3x3TorusNetwork2() {
|
||||
final FeatureInitializer[] initArray = { init };
|
||||
final Network net = new NeuronSquareMesh2D(3, true,
|
||||
3, true,
|
||||
SquareNeighbourhood.MOORE,
|
||||
initArray).getNetwork();
|
||||
Collection<Neuron> neighbours;
|
||||
|
||||
// All neurons.
|
||||
for (long id : new long[] { 0, 1, 2, 3, 4, 5, 6, 7, 8 }) {
|
||||
neighbours = net.getNeighbours(net.getNeuron(id));
|
||||
for (long nId : new long[] { 0, 1, 2, 3, 4, 5, 6, 7, 8 }) {
|
||||
if (id != nId) {
|
||||
Assert.assertTrue("id=" + id + " nId=" + nId,
|
||||
neighbours.contains(net.getNeuron(nId)));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Test assumes that the network is
|
||||
*
|
||||
* 0-----1-----2
|
||||
* | | |
|
||||
* | | |
|
||||
* 3-----4-----5
|
||||
* | | |
|
||||
* | | |
|
||||
* 6-----7-----8
|
||||
*/
|
||||
@Test
|
||||
public void test3x3CylinderNetwork() {
|
||||
final FeatureInitializer[] initArray = { init };
|
||||
final Network net = new NeuronSquareMesh2D(3, false,
|
||||
3, true,
|
||||
SquareNeighbourhood.MOORE,
|
||||
initArray).getNetwork();
|
||||
Collection<Neuron> neighbours;
|
||||
|
||||
// Neuron 0.
|
||||
neighbours = net.getNeighbours(net.getNeuron(0));
|
||||
for (long nId : new long[] { 1, 2, 3, 4, 5}) {
|
||||
Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
|
||||
}
|
||||
// Ensures that no other neurons is in the neihbourhood set.
|
||||
Assert.assertEquals(5, neighbours.size());
|
||||
|
||||
// Neuron 1.
|
||||
neighbours = net.getNeighbours(net.getNeuron(1));
|
||||
for (long nId : new long[] { 0, 2, 3, 4, 5 }) {
|
||||
Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
|
||||
}
|
||||
// Ensures that no other neurons is in the neihbourhood set.
|
||||
Assert.assertEquals(5, neighbours.size());
|
||||
|
||||
// Neuron 2.
|
||||
neighbours = net.getNeighbours(net.getNeuron(2));
|
||||
for (long nId : new long[] { 0, 1, 3, 4, 5 }) {
|
||||
Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
|
||||
}
|
||||
// Ensures that no other neurons is in the neihbourhood set.
|
||||
Assert.assertEquals(5, neighbours.size());
|
||||
|
||||
// Neuron 3.
|
||||
neighbours = net.getNeighbours(net.getNeuron(3));
|
||||
for (long nId : new long[] { 0, 1, 2, 4, 5, 6, 7, 8 }) {
|
||||
Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
|
||||
}
|
||||
// Ensures that no other neurons is in the neihbourhood set.
|
||||
Assert.assertEquals(8, neighbours.size());
|
||||
|
||||
// Neuron 4.
|
||||
neighbours = net.getNeighbours(net.getNeuron(4));
|
||||
for (long nId : new long[] { 0, 1, 2, 3, 5, 6, 7, 8 }) {
|
||||
Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
|
||||
}
|
||||
// Ensures that no other neurons is in the neihbourhood set.
|
||||
Assert.assertEquals(8, neighbours.size());
|
||||
|
||||
// Neuron 5.
|
||||
neighbours = net.getNeighbours(net.getNeuron(5));
|
||||
for (long nId : new long[] { 0, 1, 2, 3, 4, 6, 7, 8 }) {
|
||||
Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
|
||||
}
|
||||
// Ensures that no other neurons is in the neihbourhood set.
|
||||
Assert.assertEquals(8, neighbours.size());
|
||||
|
||||
// Neuron 6.
|
||||
neighbours = net.getNeighbours(net.getNeuron(6));
|
||||
for (long nId : new long[] { 3, 4, 5, 7, 8 }) {
|
||||
Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
|
||||
}
|
||||
// Ensures that no other neurons is in the neihbourhood set.
|
||||
Assert.assertEquals(5, neighbours.size());
|
||||
|
||||
// Neuron 7.
|
||||
neighbours = net.getNeighbours(net.getNeuron(7));
|
||||
for (long nId : new long[] { 3, 4, 5, 6, 8 }) {
|
||||
Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
|
||||
}
|
||||
// Ensures that no other neurons is in the neihbourhood set.
|
||||
Assert.assertEquals(5, neighbours.size());
|
||||
|
||||
// Neuron 8.
|
||||
neighbours = net.getNeighbours(net.getNeuron(8));
|
||||
for (long nId : new long[] { 3, 4, 5, 6, 7 }) {
|
||||
Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
|
||||
}
|
||||
// Ensures that no other neurons is in the neihbourhood set.
|
||||
Assert.assertEquals(5, neighbours.size());
|
||||
}
|
||||
|
||||
/*
|
||||
* Test assumes that the network is
|
||||
*
|
||||
* 0-----1-----2
|
||||
* | | |
|
||||
* | | |
|
||||
* 3-----4-----5
|
||||
* | | |
|
||||
* | | |
|
||||
* 6-----7-----8
|
||||
*/
|
||||
@Test
|
||||
public void test3x3CylinderNetwork2() {
|
||||
final FeatureInitializer[] initArray = { init };
|
||||
final Network net = new NeuronSquareMesh2D(3, false,
|
||||
3, false,
|
||||
SquareNeighbourhood.MOORE,
|
||||
initArray).getNetwork();
|
||||
Collection<Neuron> neighbours;
|
||||
|
||||
// Neuron 0.
|
||||
neighbours = net.getNeighbours(net.getNeuron(0));
|
||||
for (long nId : new long[] { 1, 3, 4}) {
|
||||
Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
|
||||
}
|
||||
// Ensures that no other neurons is in the neihbourhood set.
|
||||
Assert.assertEquals(3, neighbours.size());
|
||||
|
||||
// Neuron 1.
|
||||
neighbours = net.getNeighbours(net.getNeuron(1));
|
||||
for (long nId : new long[] { 0, 2, 3, 4, 5 }) {
|
||||
Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
|
||||
}
|
||||
// Ensures that no other neurons is in the neihbourhood set.
|
||||
Assert.assertEquals(5, neighbours.size());
|
||||
|
||||
// Neuron 2.
|
||||
neighbours = net.getNeighbours(net.getNeuron(2));
|
||||
for (long nId : new long[] { 1, 4, 5 }) {
|
||||
Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
|
||||
}
|
||||
// Ensures that no other neurons is in the neihbourhood set.
|
||||
Assert.assertEquals(3, neighbours.size());
|
||||
|
||||
// Neuron 3.
|
||||
neighbours = net.getNeighbours(net.getNeuron(3));
|
||||
for (long nId : new long[] { 0, 1, 4, 6, 7 }) {
|
||||
Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
|
||||
}
|
||||
// Ensures that no other neurons is in the neihbourhood set.
|
||||
Assert.assertEquals(5, neighbours.size());
|
||||
|
||||
// Neuron 4.
|
||||
neighbours = net.getNeighbours(net.getNeuron(4));
|
||||
for (long nId : new long[] { 0, 1, 2, 3, 5, 6, 7, 8 }) {
|
||||
Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
|
||||
}
|
||||
// Ensures that no other neurons is in the neihbourhood set.
|
||||
Assert.assertEquals(8, neighbours.size());
|
||||
|
||||
// Neuron 5.
|
||||
neighbours = net.getNeighbours(net.getNeuron(5));
|
||||
for (long nId : new long[] { 1, 2, 4, 7, 8 }) {
|
||||
Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
|
||||
}
|
||||
// Ensures that no other neurons is in the neihbourhood set.
|
||||
Assert.assertEquals(5, neighbours.size());
|
||||
|
||||
// Neuron 6.
|
||||
neighbours = net.getNeighbours(net.getNeuron(6));
|
||||
for (long nId : new long[] { 3, 4, 7 }) {
|
||||
Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
|
||||
}
|
||||
// Ensures that no other neurons is in the neihbourhood set.
|
||||
Assert.assertEquals(3, neighbours.size());
|
||||
|
||||
// Neuron 7.
|
||||
neighbours = net.getNeighbours(net.getNeuron(7));
|
||||
for (long nId : new long[] { 3, 4, 5, 6, 8 }) {
|
||||
Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
|
||||
}
|
||||
// Ensures that no other neurons is in the neihbourhood set.
|
||||
Assert.assertEquals(5, neighbours.size());
|
||||
|
||||
// Neuron 8.
|
||||
neighbours = net.getNeighbours(net.getNeuron(8));
|
||||
for (long nId : new long[] { 4, 5, 7 }) {
|
||||
Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
|
||||
}
|
||||
// Ensures that no other neurons is in the neihbourhood set.
|
||||
Assert.assertEquals(3, neighbours.size());
|
||||
}
|
||||
|
||||
/*
|
||||
* Test assumes that the network is
|
||||
*
|
||||
* 0-----1-----2-----3-----4
|
||||
* | | | | |
|
||||
* | | | | |
|
||||
* 5-----6-----7-----8-----9
|
||||
* | | | | |
|
||||
* | | | | |
|
||||
* 10----11----12----13---14
|
||||
* | | | | |
|
||||
* | | | | |
|
||||
* 15----16----17----18---19
|
||||
* | | | | |
|
||||
* | | | | |
|
||||
* 20----21----22----23---24
|
||||
*/
|
||||
@Test
|
||||
public void testConcentricNeighbourhood() {
|
||||
final FeatureInitializer[] initArray = { init };
|
||||
final Network net = new NeuronSquareMesh2D(5, true,
|
||||
5, true,
|
||||
SquareNeighbourhood.VON_NEUMANN,
|
||||
initArray).getNetwork();
|
||||
|
||||
Collection<Neuron> neighbours;
|
||||
Collection<Neuron> exclude = new HashSet<Neuron>();
|
||||
|
||||
// Level-1 neighbourhood.
|
||||
neighbours = net.getNeighbours(net.getNeuron(12));
|
||||
for (long nId : new long[] { 7, 11, 13, 17 }) {
|
||||
Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
|
||||
}
|
||||
// Ensures that no other neurons is in the neihbourhood set.
|
||||
Assert.assertEquals(4, neighbours.size());
|
||||
|
||||
// 1. Add the neuron to the "exclude" list.
|
||||
exclude.add(net.getNeuron(12));
|
||||
// 2. Add all neurons from level-1 neighbourhood.
|
||||
exclude.addAll(neighbours);
|
||||
// 3. Retrieve level-2 neighbourhood.
|
||||
neighbours = net.getNeighbours(neighbours, exclude);
|
||||
for (long nId : new long[] { 6, 8, 16, 18, 2, 10, 14, 22 }) {
|
||||
Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
|
||||
}
|
||||
// Ensures that no other neurons is in the neihbourhood set.
|
||||
Assert.assertEquals(8, neighbours.size());
|
||||
}
|
||||
|
||||
/*
|
||||
* Test assumes that the network is
|
||||
*
|
||||
* 0-----1-----2-----3-----4
|
||||
* | | | | |
|
||||
* | | | | |
|
||||
* 5-----6-----7-----8-----9
|
||||
* | | | | |
|
||||
* | | | | |
|
||||
* 10----11----12----13---14
|
||||
* | | | | |
|
||||
* | | | | |
|
||||
* 15----16----17----18---19
|
||||
* | | | | |
|
||||
* | | | | |
|
||||
* 20----21----22----23---24
|
||||
*/
|
||||
@Test
|
||||
public void testConcentricNeighbourhood2() {
|
||||
final FeatureInitializer[] initArray = { init };
|
||||
final Network net = new NeuronSquareMesh2D(5, true,
|
||||
5, true,
|
||||
SquareNeighbourhood.MOORE,
|
||||
initArray).getNetwork();
|
||||
|
||||
Collection<Neuron> neighbours;
|
||||
Collection<Neuron> exclude = new HashSet<Neuron>();
|
||||
|
||||
// Level-1 neighbourhood.
|
||||
neighbours = net.getNeighbours(net.getNeuron(8));
|
||||
for (long nId : new long[] { 2, 3, 4, 7, 9, 12, 13, 14 }) {
|
||||
Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
|
||||
}
|
||||
// Ensures that no other neurons is in the neihbourhood set.
|
||||
Assert.assertEquals(8, neighbours.size());
|
||||
|
||||
// 1. Add the neuron to the "exclude" list.
|
||||
exclude.add(net.getNeuron(8));
|
||||
// 2. Add all neurons from level-1 neighbourhood.
|
||||
exclude.addAll(neighbours);
|
||||
// 3. Retrieve level-2 neighbourhood.
|
||||
neighbours = net.getNeighbours(neighbours, exclude);
|
||||
for (long nId : new long[] { 1, 6, 11, 16, 17, 18, 19, 15, 10, 5, 0, 20, 24, 23, 22, 21 }) {
|
||||
Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
|
||||
}
|
||||
// Ensures that no other neurons is in the neihbourhood set.
|
||||
Assert.assertEquals(16, neighbours.size());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSerialize()
|
||||
throws IOException,
|
||||
ClassNotFoundException {
|
||||
final FeatureInitializer[] initArray = { init };
|
||||
final NeuronSquareMesh2D out = new NeuronSquareMesh2D(4, false,
|
||||
3, true,
|
||||
SquareNeighbourhood.VON_NEUMANN,
|
||||
initArray);
|
||||
|
||||
final ByteArrayOutputStream bos = new ByteArrayOutputStream();
|
||||
final ObjectOutputStream oos = new ObjectOutputStream(bos);
|
||||
oos.writeObject(out);
|
||||
|
||||
final ByteArrayInputStream bis = new ByteArrayInputStream(bos.toByteArray());
|
||||
final ObjectInputStream ois = new ObjectInputStream(bis);
|
||||
final NeuronSquareMesh2D in = (NeuronSquareMesh2D) ois.readObject();
|
||||
|
||||
for (Neuron nOut : out.getNetwork()) {
|
||||
final Neuron nIn = in.getNetwork().getNeuron(nOut.getIdentifier());
|
||||
|
||||
// Same values.
|
||||
final double[] outF = nOut.getFeatures();
|
||||
final double[] inF = nIn.getFeatures();
|
||||
Assert.assertEquals(outF.length, inF.length);
|
||||
for (int i = 0; i < outF.length; i++) {
|
||||
Assert.assertEquals(outF[i], inF[i], 0d);
|
||||
}
|
||||
|
||||
// Same neighbours.
|
||||
final Collection<Neuron> outNeighbours = out.getNetwork().getNeighbours(nOut);
|
||||
final Collection<Neuron> inNeighbours = in.getNetwork().getNeighbours(nIn);
|
||||
Assert.assertEquals(outNeighbours.size(), inNeighbours.size());
|
||||
for (Neuron oN : outNeighbours) {
|
||||
Assert.assertTrue(inNeighbours.contains(in.getNetwork().getNeuron(oN.getIdentifier())));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,110 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.userguide.sofm;
|
||||
|
||||
import java.util.Iterator;
|
||||
import org.apache.commons.math3.geometry.euclidean.threed.Vector3D;
|
||||
import org.apache.commons.math3.geometry.euclidean.threed.Rotation;
|
||||
import org.apache.commons.math3.random.UnitSphereRandomVectorGenerator;
|
||||
import org.apache.commons.math3.distribution.RealDistribution;
|
||||
import org.apache.commons.math3.distribution.UniformRealDistribution;
|
||||
|
||||
/**
|
||||
* Class that creates two intertwined rings.
|
||||
* Each ring is composed of a cloud of points.
|
||||
*/
|
||||
public class ChineseRings {
|
||||
/** Points in the two rings. */
|
||||
private final Vector3D[] points;
|
||||
|
||||
/**
|
||||
* @param orientationRing1 Vector othogonal to the plane containing the
|
||||
* first ring.
|
||||
* @param radiusRing1 Radius of the first ring.
|
||||
* @param halfWidthRing1 Half-width of the first ring.
|
||||
* @param radiusRing2 Radius of the second ring.
|
||||
* @param halfWidthRing2 Half-width of the second ring.
|
||||
* @param numPointsRing1 Number of points in the first ring.
|
||||
* @param numPointsRing2 Number of points in the second ring.
|
||||
*/
|
||||
public ChineseRings(Vector3D orientationRing1,
|
||||
double radiusRing1,
|
||||
double halfWidthRing1,
|
||||
double radiusRing2,
|
||||
double halfWidthRing2,
|
||||
int numPointsRing1,
|
||||
int numPointsRing2) {
|
||||
// First ring (centered at the origin).
|
||||
final Vector3D[] firstRing = new Vector3D[numPointsRing1];
|
||||
// Second ring (centered around the first ring).
|
||||
final Vector3D[] secondRing = new Vector3D[numPointsRing2];
|
||||
|
||||
// Create two rings lying in xy-plane.
|
||||
final UnitSphereRandomVectorGenerator unit
|
||||
= new UnitSphereRandomVectorGenerator(2);
|
||||
|
||||
final RealDistribution radius1
|
||||
= new UniformRealDistribution(radiusRing1 - halfWidthRing1,
|
||||
radiusRing1 + halfWidthRing1);
|
||||
final RealDistribution widthRing1
|
||||
= new UniformRealDistribution(-halfWidthRing1, halfWidthRing1);
|
||||
|
||||
for (int i = 0; i < numPointsRing1; i++) {
|
||||
final double[] v = unit.nextVector();
|
||||
final double r = radius1.sample();
|
||||
// First ring is in the xy-plane, centered at (0, 0, 0).
|
||||
firstRing[i] = new Vector3D(v[0] * r,
|
||||
v[1] * r,
|
||||
widthRing1.sample());
|
||||
}
|
||||
|
||||
final RealDistribution radius2
|
||||
= new UniformRealDistribution(radiusRing2 - halfWidthRing2,
|
||||
radiusRing2 + halfWidthRing2);
|
||||
final RealDistribution widthRing2
|
||||
= new UniformRealDistribution(-halfWidthRing2, halfWidthRing2);
|
||||
|
||||
for (int i = 0; i < numPointsRing2; i++) {
|
||||
final double[] v = unit.nextVector();
|
||||
final double r = radius2.sample();
|
||||
// Second ring is in the xz-plane, centered at (radiusRing1, 0, 0).
|
||||
secondRing[i] = new Vector3D(radiusRing1 + v[0] * r,
|
||||
widthRing2.sample(),
|
||||
v[1] * r);
|
||||
}
|
||||
|
||||
// Move first and second rings into position.
|
||||
final Rotation rot = new Rotation(Vector3D.PLUS_K,
|
||||
orientationRing1.normalize());
|
||||
int count = 0;
|
||||
points = new Vector3D[numPointsRing1 + numPointsRing2];
|
||||
for (int i = 0; i < numPointsRing1; i++) {
|
||||
points[count++] = rot.applyTo(firstRing[i]);
|
||||
}
|
||||
for (int i = 0; i < numPointsRing2; i++) {
|
||||
points[count++] = rot.applyTo(secondRing[i]);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets all the points.
|
||||
*/
|
||||
public Vector3D[] getPoints() {
|
||||
return points.clone();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,335 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math3.userguide.sofm;
|
||||
|
||||
import java.util.Iterator;
|
||||
import java.io.PrintWriter;
|
||||
import java.io.IOException;
|
||||
import org.apache.commons.math3.ml.neuralnet.SquareNeighbourhood;
|
||||
import org.apache.commons.math3.ml.neuralnet.FeatureInitializer;
|
||||
import org.apache.commons.math3.ml.neuralnet.FeatureInitializerFactory;
|
||||
import org.apache.commons.math3.ml.neuralnet.MapUtils;
|
||||
import org.apache.commons.math3.ml.neuralnet.twod.NeuronSquareMesh2D;
|
||||
import org.apache.commons.math3.ml.neuralnet.sofm.LearningFactorFunction;
|
||||
import org.apache.commons.math3.ml.neuralnet.sofm.LearningFactorFunctionFactory;
|
||||
import org.apache.commons.math3.ml.neuralnet.sofm.NeighbourhoodSizeFunction;
|
||||
import org.apache.commons.math3.ml.neuralnet.sofm.NeighbourhoodSizeFunctionFactory;
|
||||
import org.apache.commons.math3.ml.neuralnet.sofm.KohonenUpdateAction;
|
||||
import org.apache.commons.math3.ml.neuralnet.sofm.KohonenTrainingTask;
|
||||
import org.apache.commons.math3.ml.distance.DistanceMeasure;
|
||||
import org.apache.commons.math3.ml.distance.EuclideanDistance;
|
||||
import org.apache.commons.math3.random.RandomGenerator;
|
||||
import org.apache.commons.math3.random.Well19937c;
|
||||
import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
|
||||
import org.apache.commons.math3.geometry.euclidean.threed.Vector3D;
|
||||
import org.apache.commons.math3.util.FastMath;
|
||||
import org.apache.commons.math3.exception.MathUnsupportedOperationException;
|
||||
|
||||
/**
|
||||
* SOFM for categorizing points that belong to each of two intertwined rings.
|
||||
*
|
||||
* The output currently consists in 3 text files:
|
||||
* <ul>
|
||||
* <li>"before.chinese.U.seq.dat": U-matrix of the SOFM before training</li>
|
||||
* <li>"after.chinese.U.seq.dat": U-matrix of the SOFM after training</li>
|
||||
* <li>"after.chinese.hit.seq.dat": Hit histogram after training</li>
|
||||
* <ul>
|
||||
*/
|
||||
public class ChineseRingsClassifier {
|
||||
/** SOFM. */
|
||||
private final NeuronSquareMesh2D sofm;
|
||||
/** Rings. */
|
||||
private final ChineseRings rings;
|
||||
/** Distance function. */
|
||||
private final DistanceMeasure distance = new EuclideanDistance();
|
||||
|
||||
public static void main(String[] args) {
|
||||
final ChineseRings rings = new ChineseRings(new Vector3D(1, 2, 3),
|
||||
25, 2,
|
||||
20, 1,
|
||||
2000, 1500);
|
||||
final ChineseRingsClassifier classifier = new ChineseRingsClassifier(rings, 15, 15);
|
||||
printU("before.chinese.U.seq.dat", classifier);
|
||||
classifier.createSequentialTask(100000).run();
|
||||
printU("after.chinese.U.seq.dat", classifier);
|
||||
printHit("after.chinese.hit.seq.dat", classifier);
|
||||
}
|
||||
|
||||
/**
|
||||
* @param rings Training data.
|
||||
* @param dim1 Number of rows of the SOFM.
|
||||
* @param dim2 Number of columns of the SOFM.
|
||||
*/
|
||||
public ChineseRingsClassifier(ChineseRings rings,
|
||||
int dim1,
|
||||
int dim2) {
|
||||
this.rings = rings;
|
||||
sofm = new NeuronSquareMesh2D(dim1, false,
|
||||
dim2, false,
|
||||
SquareNeighbourhood.MOORE,
|
||||
makeInitializers());
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates training tasks.
|
||||
*
|
||||
* @param numTasks Number of tasks to create.
|
||||
* @param numSamplesPerTask Number of training samples per task.
|
||||
* @return the created tasks.
|
||||
*/
|
||||
public Runnable[] createParallelTasks(int numTasks,
|
||||
long numSamplesPerTask) {
|
||||
final Runnable[] tasks = new Runnable[numTasks];
|
||||
final LearningFactorFunction learning
|
||||
= LearningFactorFunctionFactory.exponentialDecay(1e-1,
|
||||
5e-2,
|
||||
numSamplesPerTask / 2);
|
||||
final double numNeurons = FastMath.sqrt(sofm.getNumberOfRows() * sofm.getNumberOfColumns());
|
||||
final NeighbourhoodSizeFunction neighbourhood
|
||||
= NeighbourhoodSizeFunctionFactory.exponentialDecay(0.5 * numNeurons,
|
||||
0.2 * numNeurons,
|
||||
numSamplesPerTask / 2);
|
||||
|
||||
for (int i = 0; i < numTasks; i++) {
|
||||
final KohonenUpdateAction action = new KohonenUpdateAction(distance,
|
||||
learning,
|
||||
neighbourhood);
|
||||
tasks[i] = new KohonenTrainingTask(sofm.getNetwork(),
|
||||
createRandomIterator(numSamplesPerTask),
|
||||
action);
|
||||
}
|
||||
|
||||
return tasks;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a training task.
|
||||
*
|
||||
* @param numSamples Number of training samples.
|
||||
* @return the created task.
|
||||
*/
|
||||
public Runnable createSequentialTask(long numSamples) {
|
||||
return createParallelTasks(1, numSamples)[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the U-matrix.
|
||||
*
|
||||
* @return the U-matrix of the network.
|
||||
*/
|
||||
public double[][] computeU() {
|
||||
return MapUtils.computeU(sofm, distance);
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the hit histogram.
|
||||
*
|
||||
* @return the histogram.
|
||||
*/
|
||||
public int[][] computeHitHistogram() {
|
||||
return MapUtils.computeHitHistogram(createIterable(),
|
||||
sofm,
|
||||
distance);
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the quantization error.
|
||||
*
|
||||
* @return the quantization error.
|
||||
*/
|
||||
public double computeQuantizationError() {
|
||||
return MapUtils.computeQuantizationError(createIterable(),
|
||||
sofm.getNetwork(),
|
||||
distance);
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the topographic error.
|
||||
*
|
||||
* @return the topographic error.
|
||||
*/
|
||||
public double computeTopographicError() {
|
||||
return MapUtils.computeTopographicError(createIterable(),
|
||||
sofm.getNetwork(),
|
||||
distance);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates the features' initializers.
|
||||
* They are sampled from a uniform distribution around the barycentre of
|
||||
* the rings.
|
||||
*
|
||||
* @return an array containing the initializers for the x, y and
|
||||
* z coordinates of the features array of the neurons.
|
||||
*/
|
||||
private FeatureInitializer[] makeInitializers() {
|
||||
final SummaryStatistics[] centre = new SummaryStatistics[] {
|
||||
new SummaryStatistics(),
|
||||
new SummaryStatistics(),
|
||||
new SummaryStatistics()
|
||||
};
|
||||
for (Vector3D p : rings.getPoints()) {
|
||||
centre[0].addValue(p.getX());
|
||||
centre[1].addValue(p.getY());
|
||||
centre[2].addValue(p.getZ());
|
||||
}
|
||||
|
||||
final double[] mean = new double[] {
|
||||
centre[0].getMean(),
|
||||
centre[1].getMean(),
|
||||
centre[2].getMean()
|
||||
};
|
||||
final double s = 0.1;
|
||||
final double[] dev = new double[] {
|
||||
s * centre[0].getStandardDeviation(),
|
||||
s * centre[1].getStandardDeviation(),
|
||||
s * centre[2].getStandardDeviation()
|
||||
};
|
||||
|
||||
return new FeatureInitializer[] {
|
||||
FeatureInitializerFactory.uniform(mean[0] - dev[0], mean[0] + dev[0]),
|
||||
FeatureInitializerFactory.uniform(mean[1] - dev[1], mean[1] + dev[1]),
|
||||
FeatureInitializerFactory.uniform(mean[2] - dev[2], mean[2] + dev[2])
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an iterable that will present the points coordinates.
|
||||
*
|
||||
* @return the iterable.
|
||||
*/
|
||||
private Iterable<double[]> createIterable() {
|
||||
return new Iterable<double[]>() {
|
||||
public Iterator<double[]> iterator() {
|
||||
return new Iterator<double[]>() {
|
||||
/** Data. */
|
||||
final Vector3D[] points = rings.getPoints();
|
||||
/** Number of samples. */
|
||||
private int n = 0;
|
||||
|
||||
/** {@inheritDoc} */
|
||||
public boolean hasNext() {
|
||||
return n < points.length;
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
public double[] next() {
|
||||
return points[n++].toArray();
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
public void remove() {
|
||||
throw new MathUnsupportedOperationException();
|
||||
}
|
||||
};
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an iterator that will present a series of points coordinates in
|
||||
* a random order.
|
||||
*
|
||||
* @param numSamples Number of samples.
|
||||
* @return the iterator.
|
||||
*/
|
||||
private Iterator<double[]> createRandomIterator(final long numSamples) {
|
||||
return new Iterator<double[]>() {
|
||||
/** Data. */
|
||||
final Vector3D[] points = rings.getPoints();
|
||||
/** RNG. */
|
||||
final RandomGenerator rng = new Well19937c();
|
||||
/** Number of samples. */
|
||||
private long n = 0;
|
||||
|
||||
/** {@inheritDoc} */
|
||||
public boolean hasNext() {
|
||||
return n < numSamples;
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
public double[] next() {
|
||||
++n;
|
||||
return points[rng.nextInt(points.length)].toArray();
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
public void remove() {
|
||||
throw new MathUnsupportedOperationException();
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Prints the U-matrix of the map to the given filename.
|
||||
*
|
||||
* @param filename File.
|
||||
* @param sofm Classifier.
|
||||
*/
|
||||
private static void printU(String filename,
|
||||
ChineseRingsClassifier sofm) {
|
||||
PrintWriter out = null;
|
||||
try {
|
||||
out = new PrintWriter(filename);
|
||||
|
||||
final double[][] uMatrix = sofm.computeU();
|
||||
for (int i = 0; i < uMatrix.length; i++) {
|
||||
for (int j = 0; j < uMatrix[0].length; j++) {
|
||||
out.print(uMatrix[i][j] + " ");
|
||||
}
|
||||
out.println();
|
||||
}
|
||||
out.println("# Quantization error: " + sofm.computeQuantizationError());
|
||||
out.println("# Topographic error: " + sofm.computeTopographicError());
|
||||
} catch (IOException e) {
|
||||
// Do nothing.
|
||||
} finally {
|
||||
if (out != null) {
|
||||
out.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Prints the hit histogram of the map to the given filename.
|
||||
*
|
||||
* @param filename File.
|
||||
* @param sofm Classifier.
|
||||
*/
|
||||
private static void printHit(String filename,
|
||||
ChineseRingsClassifier sofm) {
|
||||
PrintWriter out = null;
|
||||
try {
|
||||
out = new PrintWriter(filename);
|
||||
|
||||
final int[][] histo = sofm.computeHitHistogram();
|
||||
for (int i = 0; i < histo.length; i++) {
|
||||
for (int j = 0; j < histo[0].length; j++) {
|
||||
out.print(histo[i][j] + " ");
|
||||
}
|
||||
out.println();
|
||||
}
|
||||
} catch (IOException e) {
|
||||
// Do nothing.
|
||||
} finally {
|
||||
if (out != null) {
|
||||
out.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue