From 69b5c82140d95c2d584b98c14d15f1066d55f187 Mon Sep 17 00:00:00 2001 From: Gilles Date: Tue, 1 Sep 2015 14:25:26 +0200 Subject: [PATCH] MATH-1264 Sort units according to distance from a given vector. --- src/changes/changes.xml | 4 + .../commons/math3/ml/neuralnet/MapUtils.java | 78 +++++++++++++++++++ .../math3/ml/neuralnet/MapUtilsTest.java | 18 +++++ 3 files changed, 100 insertions(+) diff --git a/src/changes/changes.xml b/src/changes/changes.xml index 29ce88fd1..88594b551 100644 --- a/src/changes/changes.xml +++ b/src/changes/changes.xml @@ -51,6 +51,10 @@ If the output is not quite correct, check for invisible trailing spaces! + + "MapUtils" (package "o.a.c.m.ml.neuralnet"): Method to sort units according to distance + from a given vector. + Accessor (class "o.a.c.m.ml.neuralnet.twod.NeuronSquareMesh2D"). diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/MapUtils.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/MapUtils.java index 9e67982df..e7cf598c2 100644 --- a/src/main/java/org/apache/commons/math3/ml/neuralnet/MapUtils.java +++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/MapUtils.java @@ -19,6 +19,9 @@ package org.apache.commons.math3.ml.neuralnet; import java.util.HashMap; import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.ArrayList; import org.apache.commons.math3.ml.distance.DistanceMeasure; import org.apache.commons.math3.ml.neuralnet.twod.NeuronSquareMesh2D; import org.apache.commons.math3.exception.NoDataException; @@ -103,6 +106,46 @@ public class MapUtils { return new Pair(best[0], best[1]); } + /** + * Creates a list of neurons sorted in increased order of the distance + * to the given {@code features}. + * + * @param features Data. + * @param neurons List of neurons to scan. If it is empty, an empty array + * will be returned. + * @param distance Distance function. + * @return the neurons, sorted in increasing order of distance in data + * space. + * @throws org.apache.commons.math4.exception.DimensionMismatchException + * if the size of the input is not compatible with the neurons features + * size. + * + * @see #findBest(double[],Iterable,DistanceMeasure) + * @see #findBestAndSecondBest(double[],Iterable,DistanceMeasure) + * + * @since 3.6 + */ + public static Neuron[] sort(double[] features, + Iterable neurons, + DistanceMeasure distance) { + final List list = new ArrayList(); + + for (final Neuron n : neurons) { + final double d = distance.compute(n.getFeatures(), features); + list.add(new PairNeuronDouble(n, d)); + } + + Collections.sort(list); + + final int len = list.size(); + final Neuron[] sorted = new Neuron[len]; + + for (int i = 0; i < len; i++) { + sorted[i] = list.get(i).getNeuron(); + } + return sorted; + } + /** * Computes the * U-matrix of a two-dimensional map. @@ -244,4 +287,39 @@ public class MapUtils { return ((double) notAdjacentCount) / count; } + + /** + * Helper data structure holding a (Neuron, double) pair. + */ + private static class PairNeuronDouble implements Comparable { + /** Key */ + private final Neuron neuron; + /** Value */ + private final double value; + + /** + * @param neuron Neuron. + * @param value Value. + */ + public PairNeuronDouble(Neuron neuron, + double value) { + this.neuron = neuron; + this.value = value; + } + + /** @return the neuron. */ + public Neuron getNeuron() { + return neuron; + } + + /** @return the value. */ + public double getValue() { + return value; + } + + /** {@inheritDoc} */ + public int compareTo(PairNeuronDouble other) { + return Double.compare(this.value, other.value); + } + } } diff --git a/src/test/java/org/apache/commons/math3/ml/neuralnet/MapUtilsTest.java b/src/test/java/org/apache/commons/math3/ml/neuralnet/MapUtilsTest.java index 72bf09cd0..b6216c155 100644 --- a/src/test/java/org/apache/commons/math3/ml/neuralnet/MapUtilsTest.java +++ b/src/test/java/org/apache/commons/math3/ml/neuralnet/MapUtilsTest.java @@ -88,4 +88,22 @@ public class MapUtilsTest { Assert.assertEquals(3, allBest.size()); } + + @Test + public void testSort() { + final Set list = new HashSet(); + + for (int i = 0; i < 4; i++) { + list.add(new Neuron(i, new double[] { i - 0.5 })); + } + + final Neuron[] sorted = MapUtils.sort(new double[] { 3.4 }, + list, + new EuclideanDistance()); + + final long[] expected = new long[] { 3, 2, 1, 0 }; + for (int i = 0; i < list.size(); i++) { + Assert.assertEquals(expected[i], sorted[i].getIdentifier()); + } + } }