MATH-1278

Deep copy of "Neuron", "Network" and "NeuronSquareMesh2D".
This commit is contained in:
Gilles 2015-09-20 22:02:21 +02:00
parent 2fd6c8fa1e
commit 6c4e1d719f
6 changed files with 159 additions and 0 deletions

View File

@ -54,6 +54,10 @@ If the output is not quite correct, check for invisible trailing spaces!
</release>
<release version="4.0" date="XXXX-XX-XX" description="">
<action dev="erans" type="add" issue="MATH-1278"> <!-- backported to 3.6 -->
Deep copy of "Network" (package "o.a.c.m.ml.neuralnet") to allow evaluation of
of intermediate states during training.
</action>
<action dev="oertl" type="update" issue="MATH-1276"> <!-- backported to 3.6 -->
Improved performance of sampling and inverse cumulative probability calculation
for geometric distributions.

View File

@ -28,6 +28,7 @@ import java.util.Collection;
import java.util.Iterator;
import java.util.Comparator;
import java.util.Collections;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
@ -136,6 +137,29 @@ public class Network
this.featureSize = featureSize;
}
/**
* Performs a deep copy of this instance.
* Upon return, the copied and original instances will be independent:
* Updating one will not affect the other.
*
* @return a new instance with the same state as this instance.
*/
public synchronized Network copy() {
final Network copy = new Network(nextId.get(),
featureSize);
for (Map.Entry<Long, Neuron> e : neuronMap.entrySet()) {
copy.neuronMap.put(e.getKey(), e.getValue().copy());
}
for (Map.Entry<Long, Set<Long>> e : linkMap.entrySet()) {
copy.linkMap.put(e.getKey(), new HashSet<Long>(e.getValue()));
}
return copy;
}
/**
* {@inheritDoc}
*/

View File

@ -66,6 +66,22 @@ public class Neuron implements Serializable {
this.features = new AtomicReference<double[]>(features.clone());
}
/**
* Performs a deep copy of this instance.
* Upon return, the copied and original instances will be independent:
* Updating one will not affect the other.
*
* @return a new instance with the same state as this instance.
*/
public synchronized Neuron copy() {
final Neuron copy = new Neuron(getIdentifier(),
getFeatures());
copy.numberOfAttemptedUpdates.set(numberOfAttemptedUpdates.get());
copy.numberOfSuccessfulUpdates.set(numberOfSuccessfulUpdates.get());
return copy;
}
/**
* Gets the neuron's identifier.
*

View File

@ -198,6 +198,54 @@ public class NeuronSquareMesh2D
createLinks();
}
/**
* Constructor with restricted access, solely used for making a
* {@link #copy() deep copy}.
*
* @param wrapRowDim Whether to wrap the first dimension (i.e the first
* and last neurons will be linked together).
* @param wrapColDim Whether to wrap the second dimension (i.e the first
* and last neurons will be linked together).
* @param neighbourhoodType Neighbourhood type.
* @param net Underlying network.
* @param idGrid Neuron identifiers.
*/
private NeuronSquareMesh2D(boolean wrapRowDim,
boolean wrapColDim,
SquareNeighbourhood neighbourhoodType,
Network net,
long[][] idGrid) {
numberOfRows = idGrid.length;
numberOfColumns = idGrid[0].length;
wrapRows = wrapRowDim;
wrapColumns = wrapColDim;
neighbourhood = neighbourhoodType;
network = net;
identifiers = idGrid;
}
/**
* Performs a deep copy of this instance.
* Upon return, the copied and original instances will be independent:
* Updating one will not affect the other.
*
* @return a new instance with the same state as this instance.
*/
public synchronized NeuronSquareMesh2D copy() {
final long[][] idGrid = new long[numberOfRows][numberOfColumns];
for (int r = 0; r < numberOfRows; r++) {
for (int c = 0; c < numberOfColumns; c++) {
idGrid[r][c] = identifiers[r][c];
}
}
return new NeuronSquareMesh2D(wrapRows,
wrapColumns,
neighbourhood,
network.copy(),
idGrid);
}
/** {@inheritDoc} */
@Override
public Iterator<Neuron> iterator() {

View File

@ -132,6 +132,47 @@ public class NetworkTest {
Assert.assertFalse(isUnspecifiedOrder);
}
/*
* Test assumes that the network is
*
* 0-----1
* | |
* | |
* 2-----3
*/
@Test
public void testCopy() {
final FeatureInitializer[] initArray = { init };
final Network net = new NeuronSquareMesh2D(2, false,
2, false,
SquareNeighbourhood.VON_NEUMANN,
initArray).getNetwork();
final Network copy = net.copy();
final Neuron netNeuron0 = net.getNeuron(0);
final Neuron copyNeuron0 = copy.getNeuron(0);
final Neuron netNeuron1 = net.getNeuron(1);
final Neuron copyNeuron1 = copy.getNeuron(1);
Collection<Neuron> netNeighbours;
Collection<Neuron> copyNeighbours;
// Check that both networks have the same connections.
netNeighbours = net.getNeighbours(netNeuron0);
copyNeighbours = copy.getNeighbours(copyNeuron0);
Assert.assertTrue(netNeighbours.contains(netNeuron1));
Assert.assertTrue(copyNeighbours.contains(copyNeuron1));
// Delete neuron 1 from original.
net.deleteNeuron(netNeuron1);
// Check that the networks now differ.
netNeighbours = net.getNeighbours(netNeuron0);
copyNeighbours = copy.getNeighbours(copyNeuron0);
Assert.assertFalse(netNeighbours.contains(netNeuron1));
Assert.assertTrue(copyNeighbours.contains(copyNeuron1));
}
@Test
public void testSerialize()
throws IOException,

View File

@ -87,6 +87,32 @@ public class NeuronTest {
Assert.assertEquals(update[0], n.getFeatures()[0], 0d);
}
@Test
public void testCopy() {
final Neuron n = new Neuron(1, new double[] { 9.87 });
// Update original.
double[] update = new double[] { n.getFeatures()[0] + 2.34 };
n.compareAndSetFeatures(n.getFeatures(), update);
// Create a copy.
final Neuron copy = n.copy();
// Check that original and copy have the same value.
Assert.assertTrue(n.getFeatures()[0] == copy.getFeatures()[0]);
Assert.assertEquals(n.getNumberOfAttemptedUpdates(),
copy.getNumberOfAttemptedUpdates());
// Update original.
update = new double[] { 1.23 * n.getFeatures()[0] };
n.compareAndSetFeatures(n.getFeatures(), update);
// Check that original and copy differ.
Assert.assertFalse(n.getFeatures()[0] == copy.getFeatures()[0]);
Assert.assertNotEquals(n.getNumberOfSuccessfulUpdates(),
copy.getNumberOfSuccessfulUpdates());
}
@Test
public void testSerialize()
throws IOException,