MATH-1264
Sort units according to distance from a given vector.
This commit is contained in:
parent
f189a4c5aa
commit
69b5c82140
|
@ -51,6 +51,10 @@ If the output is not quite correct, check for invisible trailing spaces!
|
|||
</properties>
|
||||
<body>
|
||||
<release version="3.6" date="XXXX-XX-XX" description="">
|
||||
<action dev="erans" type="add" issue="MATH-1264">
|
||||
"MapUtils" (package "o.a.c.m.ml.neuralnet"): Method to sort units according to distance
|
||||
from a given vector.
|
||||
</action>
|
||||
<action dev="erans" type="add" issue="MATH-1263">
|
||||
Accessor (class "o.a.c.m.ml.neuralnet.twod.NeuronSquareMesh2D").
|
||||
</action>
|
||||
|
|
|
@ -19,6 +19,9 @@ package org.apache.commons.math3.ml.neuralnet;
|
|||
|
||||
import java.util.HashMap;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.ArrayList;
|
||||
import org.apache.commons.math3.ml.distance.DistanceMeasure;
|
||||
import org.apache.commons.math3.ml.neuralnet.twod.NeuronSquareMesh2D;
|
||||
import org.apache.commons.math3.exception.NoDataException;
|
||||
|
@ -103,6 +106,46 @@ public class MapUtils {
|
|||
return new Pair<Neuron, Neuron>(best[0], best[1]);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a list of neurons sorted in increased order of the distance
|
||||
* to the given {@code features}.
|
||||
*
|
||||
* @param features Data.
|
||||
* @param neurons List of neurons to scan. If it is empty, an empty array
|
||||
* will be returned.
|
||||
* @param distance Distance function.
|
||||
* @return the neurons, sorted in increasing order of distance in data
|
||||
* space.
|
||||
* @throws org.apache.commons.math4.exception.DimensionMismatchException
|
||||
* if the size of the input is not compatible with the neurons features
|
||||
* size.
|
||||
*
|
||||
* @see #findBest(double[],Iterable,DistanceMeasure)
|
||||
* @see #findBestAndSecondBest(double[],Iterable,DistanceMeasure)
|
||||
*
|
||||
* @since 3.6
|
||||
*/
|
||||
public static Neuron[] sort(double[] features,
|
||||
Iterable<Neuron> neurons,
|
||||
DistanceMeasure distance) {
|
||||
final List<PairNeuronDouble> list = new ArrayList<PairNeuronDouble>();
|
||||
|
||||
for (final Neuron n : neurons) {
|
||||
final double d = distance.compute(n.getFeatures(), features);
|
||||
list.add(new PairNeuronDouble(n, d));
|
||||
}
|
||||
|
||||
Collections.sort(list);
|
||||
|
||||
final int len = list.size();
|
||||
final Neuron[] sorted = new Neuron[len];
|
||||
|
||||
for (int i = 0; i < len; i++) {
|
||||
sorted[i] = list.get(i).getNeuron();
|
||||
}
|
||||
return sorted;
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the <a href="http://en.wikipedia.org/wiki/U-Matrix">
|
||||
* U-matrix</a> of a two-dimensional map.
|
||||
|
@ -244,4 +287,39 @@ public class MapUtils {
|
|||
|
||||
return ((double) notAdjacentCount) / count;
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper data structure holding a (Neuron, double) pair.
|
||||
*/
|
||||
private static class PairNeuronDouble implements Comparable<PairNeuronDouble> {
|
||||
/** Key */
|
||||
private final Neuron neuron;
|
||||
/** Value */
|
||||
private final double value;
|
||||
|
||||
/**
|
||||
* @param neuron Neuron.
|
||||
* @param value Value.
|
||||
*/
|
||||
public PairNeuronDouble(Neuron neuron,
|
||||
double value) {
|
||||
this.neuron = neuron;
|
||||
this.value = value;
|
||||
}
|
||||
|
||||
/** @return the neuron. */
|
||||
public Neuron getNeuron() {
|
||||
return neuron;
|
||||
}
|
||||
|
||||
/** @return the value. */
|
||||
public double getValue() {
|
||||
return value;
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
public int compareTo(PairNeuronDouble other) {
|
||||
return Double.compare(this.value, other.value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -88,4 +88,22 @@ public class MapUtilsTest {
|
|||
|
||||
Assert.assertEquals(3, allBest.size());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSort() {
|
||||
final Set<Neuron> list = new HashSet<Neuron>();
|
||||
|
||||
for (int i = 0; i < 4; i++) {
|
||||
list.add(new Neuron(i, new double[] { i - 0.5 }));
|
||||
}
|
||||
|
||||
final Neuron[] sorted = MapUtils.sort(new double[] { 3.4 },
|
||||
list,
|
||||
new EuclideanDistance());
|
||||
|
||||
final long[] expected = new long[] { 3, 2, 1, 0 };
|
||||
for (int i = 0; i < list.size(); i++) {
|
||||
Assert.assertEquals(expected[i], sorted[i].getIdentifier());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue