MATH-1278
Deep copy of "Neuron", "Network" and "NeuronSquareMesh2D".
This commit is contained in:
parent
78b9d819a7
commit
f13693fdc3
|
@ -28,6 +28,7 @@ import java.util.Collection;
|
|||
import java.util.Iterator;
|
||||
import java.util.Comparator;
|
||||
import java.util.Collections;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.atomic.AtomicLong;
|
||||
import org.apache.commons.math3.exception.DimensionMismatchException;
|
||||
|
@ -134,6 +135,29 @@ public class Network
|
|||
this.featureSize = featureSize;
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs a deep copy of this instance.
|
||||
* Upon return, the copied and original instances will be independent:
|
||||
* Updating one will not affect the other.
|
||||
*
|
||||
* @return a new instance with the same state as this instance.
|
||||
*/
|
||||
public synchronized Network copy() {
|
||||
final Network copy = new Network(nextId.get(),
|
||||
featureSize);
|
||||
|
||||
|
||||
for (Map.Entry<Long, Neuron> e : neuronMap.entrySet()) {
|
||||
copy.neuronMap.put(e.getKey(), e.getValue().copy());
|
||||
}
|
||||
|
||||
for (Map.Entry<Long, Set<Long>> e : linkMap.entrySet()) {
|
||||
copy.linkMap.put(e.getKey(), new HashSet<Long>(e.getValue()));
|
||||
}
|
||||
|
||||
return copy;
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*/
|
||||
|
|
|
@ -66,6 +66,22 @@ public class Neuron implements Serializable {
|
|||
this.features = new AtomicReference<double[]>(features.clone());
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs a deep copy of this instance.
|
||||
* Upon return, the copied and original instances will be independent:
|
||||
* Updating one will not affect the other.
|
||||
*
|
||||
* @return a new instance with the same state as this instance.
|
||||
*/
|
||||
public synchronized Neuron copy() {
|
||||
final Neuron copy = new Neuron(getIdentifier(),
|
||||
getFeatures());
|
||||
copy.numberOfAttemptedUpdates.set(numberOfAttemptedUpdates.get());
|
||||
copy.numberOfSuccessfulUpdates.set(numberOfSuccessfulUpdates.get());
|
||||
|
||||
return copy;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the neuron's identifier.
|
||||
*
|
||||
|
|
|
@ -197,6 +197,54 @@ public class NeuronSquareMesh2D
|
|||
createLinks();
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructor with restricted access, solely used for making a
|
||||
* {@link #copy() deep copy}.
|
||||
*
|
||||
* @param wrapRowDim Whether to wrap the first dimension (i.e the first
|
||||
* and last neurons will be linked together).
|
||||
* @param wrapColDim Whether to wrap the second dimension (i.e the first
|
||||
* and last neurons will be linked together).
|
||||
* @param neighbourhoodType Neighbourhood type.
|
||||
* @param net Underlying network.
|
||||
* @param idGrid Neuron identifiers.
|
||||
*/
|
||||
private NeuronSquareMesh2D(boolean wrapRowDim,
|
||||
boolean wrapColDim,
|
||||
SquareNeighbourhood neighbourhoodType,
|
||||
Network net,
|
||||
long[][] idGrid) {
|
||||
numberOfRows = idGrid.length;
|
||||
numberOfColumns = idGrid[0].length;
|
||||
wrapRows = wrapRowDim;
|
||||
wrapColumns = wrapColDim;
|
||||
neighbourhood = neighbourhoodType;
|
||||
network = net;
|
||||
identifiers = idGrid;
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs a deep copy of this instance.
|
||||
* Upon return, the copied and original instances will be independent:
|
||||
* Updating one will not affect the other.
|
||||
*
|
||||
* @return a new instance with the same state as this instance.
|
||||
*/
|
||||
public synchronized NeuronSquareMesh2D copy() {
|
||||
final long[][] idGrid = new long[numberOfRows][numberOfColumns];
|
||||
for (int r = 0; r < numberOfRows; r++) {
|
||||
for (int c = 0; c < numberOfColumns; c++) {
|
||||
idGrid[r][c] = identifiers[r][c];
|
||||
}
|
||||
}
|
||||
|
||||
return new NeuronSquareMesh2D(wrapRows,
|
||||
wrapColumns,
|
||||
neighbourhood,
|
||||
network.copy(),
|
||||
idGrid);
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
public Iterator<Neuron> iterator() {
|
||||
return network.iterator();
|
||||
|
|
|
@ -127,6 +127,47 @@ public class NetworkTest {
|
|||
Assert.assertFalse(isUnspecifiedOrder);
|
||||
}
|
||||
|
||||
/*
|
||||
* Test assumes that the network is
|
||||
*
|
||||
* 0-----1
|
||||
* | |
|
||||
* | |
|
||||
* 2-----3
|
||||
*/
|
||||
@Test
|
||||
public void testCopy() {
|
||||
final FeatureInitializer[] initArray = { init };
|
||||
final Network net = new NeuronSquareMesh2D(2, false,
|
||||
2, false,
|
||||
SquareNeighbourhood.VON_NEUMANN,
|
||||
initArray).getNetwork();
|
||||
|
||||
final Network copy = net.copy();
|
||||
|
||||
final Neuron netNeuron0 = net.getNeuron(0);
|
||||
final Neuron copyNeuron0 = copy.getNeuron(0);
|
||||
final Neuron netNeuron1 = net.getNeuron(1);
|
||||
final Neuron copyNeuron1 = copy.getNeuron(1);
|
||||
Collection<Neuron> netNeighbours;
|
||||
Collection<Neuron> copyNeighbours;
|
||||
|
||||
// Check that both networks have the same connections.
|
||||
netNeighbours = net.getNeighbours(netNeuron0);
|
||||
copyNeighbours = copy.getNeighbours(copyNeuron0);
|
||||
Assert.assertTrue(netNeighbours.contains(netNeuron1));
|
||||
Assert.assertTrue(copyNeighbours.contains(copyNeuron1));
|
||||
|
||||
// Delete neuron 1 from original.
|
||||
net.deleteNeuron(netNeuron1);
|
||||
|
||||
// Check that the networks now differ.
|
||||
netNeighbours = net.getNeighbours(netNeuron0);
|
||||
copyNeighbours = copy.getNeighbours(copyNeuron0);
|
||||
Assert.assertFalse(netNeighbours.contains(netNeuron1));
|
||||
Assert.assertTrue(copyNeighbours.contains(copyNeuron1));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSerialize()
|
||||
throws IOException,
|
||||
|
|
|
@ -85,6 +85,32 @@ public class NeuronTest {
|
|||
Assert.assertEquals(update[0], n.getFeatures()[0], 0d);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCopy() {
|
||||
final Neuron n = new Neuron(1, new double[] { 9.87 });
|
||||
|
||||
// Update original.
|
||||
double[] update = new double[] { n.getFeatures()[0] + 2.34 };
|
||||
n.compareAndSetFeatures(n.getFeatures(), update);
|
||||
|
||||
// Create a copy.
|
||||
final Neuron copy = n.copy();
|
||||
|
||||
// Check that original and copy have the same value.
|
||||
Assert.assertTrue(n.getFeatures()[0] == copy.getFeatures()[0]);
|
||||
Assert.assertEquals(n.getNumberOfAttemptedUpdates(),
|
||||
copy.getNumberOfAttemptedUpdates());
|
||||
|
||||
// Update original.
|
||||
update = new double[] { 1.23 * n.getFeatures()[0] };
|
||||
n.compareAndSetFeatures(n.getFeatures(), update);
|
||||
|
||||
// Check that original and copy differ.
|
||||
Assert.assertFalse(n.getFeatures()[0] == copy.getFeatures()[0]);
|
||||
Assert.assertNotEquals(n.getNumberOfSuccessfulUpdates(),
|
||||
copy.getNumberOfSuccessfulUpdates());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSerialize()
|
||||
throws IOException,
|
||||
|
|
Loading…
Reference in New Issue