MATH-1267

Helper for finding the grid coordinates of a "Neuron" in a "NeuronSquareMesh2D".
This commit is contained in:
Gilles 2015-09-08 14:26:42 +02:00
parent 50c5eae1a6
commit a382b33abb
2 changed files with 174 additions and 0 deletions

View File

@ -0,0 +1,104 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.ml.neuralnet.twod.util;
import java.util.Map;
import java.util.HashMap;
import org.apache.commons.math3.ml.neuralnet.Neuron;
import org.apache.commons.math3.ml.neuralnet.twod.NeuronSquareMesh2D;
import org.apache.commons.math3.exception.MathIllegalStateException;
/**
* Helper class to find the grid coordinates of a neuron.
*/
public class LocationFinder {
/** Identifier to location mapping. */
private final Map<Long, Location> locations = new HashMap<Long, Location>();
/**
* Container holding a (row, column) pair.
*/
public static class Location {
/** Row index. */
private final int row;
/** Column index. */
private final int column;
/**
* @param row Row index.
* @param column Column index.
*/
public Location(int row,
int column) {
this.row = row;
this.column = column;
}
/**
* @return the row index.
*/
public int getRow() {
return row;
}
/**
* @return the column index.
*/
public int getColumn() {
return column;
}
}
/**
* Builds a finder to retrieve the locations of neurons that
* belong to the given {@code map}.
*
* @param map Map.
*
* @throws MathIllegalStateException if the network contains non-unique
* identifiers. This indicates an inconsistent state due to a bug in
* the construction code of the underlying
* {@link org.apache.commons.math3.ml.neuralnet.Network network}.
*/
public LocationFinder(NeuronSquareMesh2D map) {
final int nR = map.getNumberOfRows();
final int nC = map.getNumberOfColumns();
for (int r = 0; r < nR; r++) {
for (int c = 0; c < nC; c++) {
final Long id = map.getNeuron(r, c).getIdentifier();
if (locations.get(id) != null) {
throw new MathIllegalStateException();
}
locations.put(id, new Location(r, c));
}
}
}
/**
* Retrieves a neuron's grid coordinates.
*
* @param n Neuron.
* @return the (row, column) coordinates of {@code n}, or {@code null}
* if no such neuron belongs to the {@link #LocationFinder(NeuronSquareMesh2D)
* map used to build this instance}.
*/
public Location getLocation(Neuron n) {
return locations.get(n.getIdentifier());
}
}

View File

@ -0,0 +1,70 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.ml.neuralnet.twod.util;
import org.apache.commons.math3.ml.neuralnet.Neuron;
import org.apache.commons.math3.ml.neuralnet.Network;
import org.apache.commons.math3.ml.neuralnet.FeatureInitializer;
import org.apache.commons.math3.ml.neuralnet.FeatureInitializerFactory;
import org.apache.commons.math3.ml.neuralnet.SquareNeighbourhood;
import org.apache.commons.math3.ml.neuralnet.twod.NeuronSquareMesh2D;
import org.junit.Assert;
import org.junit.Test;
/**
* Test for {@link LocationFinder}.
*/
public class LocationFinderTest {
final FeatureInitializer init = FeatureInitializerFactory.uniform(0, 2);
/*
* Test assumes that the network is
*
* 0-----1
* | |
* | |
* 2-----3
*/
@Test
public void test2x2Network() {
final FeatureInitializer[] initArray = { init };
final NeuronSquareMesh2D map = new NeuronSquareMesh2D(2, false,
2, false,
SquareNeighbourhood.VON_NEUMANN,
initArray);
final LocationFinder finder = new LocationFinder(map);
final Network net = map.getNetwork();
LocationFinder.Location loc;
loc = finder.getLocation(net.getNeuron(0));
Assert.assertEquals(0, loc.getRow());
Assert.assertEquals(0, loc.getColumn());
loc = finder.getLocation(net.getNeuron(1));
Assert.assertEquals(0, loc.getRow());
Assert.assertEquals(1, loc.getColumn());
loc = finder.getLocation(net.getNeuron(2));
Assert.assertEquals(1, loc.getRow());
Assert.assertEquals(0, loc.getColumn());
loc = finder.getLocation(net.getNeuron(3));
Assert.assertEquals(1, loc.getRow());
Assert.assertEquals(1, loc.getColumn());
}
}