MATH-1264

Sort units according to distance from a given vector.
This commit is contained in:
Gilles 2015-09-01 14:25:26 +02:00
parent f189a4c5aa
commit 69b5c82140
3 changed files with 100 additions and 0 deletions

View File

@ -51,6 +51,10 @@ If the output is not quite correct, check for invisible trailing spaces!
</properties> </properties>
<body> <body>
<release version="3.6" date="XXXX-XX-XX" description=""> <release version="3.6" date="XXXX-XX-XX" description="">
<action dev="erans" type="add" issue="MATH-1264">
"MapUtils" (package "o.a.c.m.ml.neuralnet"): Method to sort units according to distance
from a given vector.
</action>
<action dev="erans" type="add" issue="MATH-1263"> <action dev="erans" type="add" issue="MATH-1263">
Accessor (class "o.a.c.m.ml.neuralnet.twod.NeuronSquareMesh2D"). Accessor (class "o.a.c.m.ml.neuralnet.twod.NeuronSquareMesh2D").
</action> </action>

View File

@ -19,6 +19,9 @@ package org.apache.commons.math3.ml.neuralnet;
import java.util.HashMap; import java.util.HashMap;
import java.util.Collection; import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.ArrayList;
import org.apache.commons.math3.ml.distance.DistanceMeasure; import org.apache.commons.math3.ml.distance.DistanceMeasure;
import org.apache.commons.math3.ml.neuralnet.twod.NeuronSquareMesh2D; import org.apache.commons.math3.ml.neuralnet.twod.NeuronSquareMesh2D;
import org.apache.commons.math3.exception.NoDataException; import org.apache.commons.math3.exception.NoDataException;
@ -103,6 +106,46 @@ public class MapUtils {
return new Pair<Neuron, Neuron>(best[0], best[1]); return new Pair<Neuron, Neuron>(best[0], best[1]);
} }
/**
* Creates a list of neurons sorted in increased order of the distance
* to the given {@code features}.
*
* @param features Data.
* @param neurons List of neurons to scan. If it is empty, an empty array
* will be returned.
* @param distance Distance function.
* @return the neurons, sorted in increasing order of distance in data
* space.
* @throws org.apache.commons.math4.exception.DimensionMismatchException
* if the size of the input is not compatible with the neurons features
* size.
*
* @see #findBest(double[],Iterable,DistanceMeasure)
* @see #findBestAndSecondBest(double[],Iterable,DistanceMeasure)
*
* @since 3.6
*/
public static Neuron[] sort(double[] features,
Iterable<Neuron> neurons,
DistanceMeasure distance) {
final List<PairNeuronDouble> list = new ArrayList<PairNeuronDouble>();
for (final Neuron n : neurons) {
final double d = distance.compute(n.getFeatures(), features);
list.add(new PairNeuronDouble(n, d));
}
Collections.sort(list);
final int len = list.size();
final Neuron[] sorted = new Neuron[len];
for (int i = 0; i < len; i++) {
sorted[i] = list.get(i).getNeuron();
}
return sorted;
}
/** /**
* Computes the <a href="http://en.wikipedia.org/wiki/U-Matrix"> * Computes the <a href="http://en.wikipedia.org/wiki/U-Matrix">
* U-matrix</a> of a two-dimensional map. * U-matrix</a> of a two-dimensional map.
@ -244,4 +287,39 @@ public class MapUtils {
return ((double) notAdjacentCount) / count; return ((double) notAdjacentCount) / count;
} }
/**
* Helper data structure holding a (Neuron, double) pair.
*/
private static class PairNeuronDouble implements Comparable<PairNeuronDouble> {
/** Key */
private final Neuron neuron;
/** Value */
private final double value;
/**
* @param neuron Neuron.
* @param value Value.
*/
public PairNeuronDouble(Neuron neuron,
double value) {
this.neuron = neuron;
this.value = value;
}
/** @return the neuron. */
public Neuron getNeuron() {
return neuron;
}
/** @return the value. */
public double getValue() {
return value;
}
/** {@inheritDoc} */
public int compareTo(PairNeuronDouble other) {
return Double.compare(this.value, other.value);
}
}
} }

View File

@ -88,4 +88,22 @@ public class MapUtilsTest {
Assert.assertEquals(3, allBest.size()); Assert.assertEquals(3, allBest.size());
} }
@Test
public void testSort() {
final Set<Neuron> list = new HashSet<Neuron>();
for (int i = 0; i < 4; i++) {
list.add(new Neuron(i, new double[] { i - 0.5 }));
}
final Neuron[] sorted = MapUtils.sort(new double[] { 3.4 },
list,
new EuclideanDistance());
final long[] expected = new long[] { 3, 2, 1, 0 };
for (int i = 0; i < list.size(); i++) {
Assert.assertEquals(expected[i], sorted[i].getIdentifier());
}
}
} }