MATH-1547: Ranking of any number of the best matching units of a neural network.

This commit is contained in:
Gilles Sadowski 2020-06-26 15:32:02 +02:00
parent 24e2c246ce
commit 960ba5322b
4 changed files with 271 additions and 86 deletions

View File

@ -54,6 +54,9 @@ If the output is not quite correct, check for invisible trailing spaces!
</release>
<release version="4.0" date="XXXX-XX-XX" description="">
<action dev="erans" type="update" issue="MATH-1547">
More flexible ranking of SOFM.
</action>
<action dev="erans" type="fix" issue="MATH-1537" due-to="Jin Xu">
Clean-up (typos and unused "import" statements).
</action>

View File

@ -0,0 +1,156 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math4.ml.neuralnet;
import java.util.List;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import org.apache.commons.math4.exception.NotStrictlyPositiveException;
import org.apache.commons.math4.ml.distance.DistanceMeasure;
/**
 * Utility for ranking the units (neurons) of a network.
 *
 * <p>
 * Units are ranked by the distance, as computed by the
 * {@link DistanceMeasure} passed to the constructor, between their
 * features and a given data point: The closest unit comes first.
 * </p>
 *
 * @since 4.0
 */
public class MapRanking {
    /** List corresponding to the map passed to the constructor. */
    private final List<Neuron> map = new ArrayList<>();
    /** Distance function for sorting. */
    private final DistanceMeasure distance;

    /**
     * @param neurons List to be ranked.
     * No defensive copy is performed.
     * The {@link #rank(double[],int) created list of units} will
     * be sorted in increasing order of the {@code distance}.
     * @param distance Distance function.
     */
    public MapRanking(Iterable<Neuron> neurons,
                      DistanceMeasure distance) {
        this.distance = distance;

        for (Neuron n : neurons) {
            map.add(n); // No defensive copy.
        }
    }

    /**
     * Creates a list of all the neurons, ordered according to how well
     * their features correspond to the given {@code features}.
     *
     * @param features Data.
     * @return the list of neurons sorted in increasing order of distance to
     * the given data (best matching unit first). The list is empty if this
     * instance was created with no neurons.
     * @throws org.apache.commons.math4.exception.DimensionMismatchException
     * if the size of the input is not compatible with the neurons features
     * size.
     */
    public List<Neuron> rank(double[] features) {
        if (map.isEmpty()) {
            // Nothing to rank: Delegating would request "max == 0", which
            // is rejected as an invalid argument.
            return new ArrayList<>();
        }

        return rank(features, map.size());
    }

    /**
     * Creates a list of the neurons whose features best correspond to the
     * given {@code features}.
     *
     * @param features Data.
     * @param max Maximum size of the returned list.
     * @return the list of neurons sorted in increasing order of distance to
     * the given data (best matching unit first). Its size is the smaller
     * of {@code max} and the number of neurons. Neurons at the same distance
     * are kept in the order in which they were passed to the constructor.
     * @throws org.apache.commons.math4.exception.DimensionMismatchException
     * if the size of the input is not compatible with the neurons features
     * size.
     * @throws NotStrictlyPositiveException if {@code max <= 0}.
     */
    public List<Neuron> rank(double[] features,
                             int max) {
        if (max <= 0) {
            throw new NotStrictlyPositiveException(max);
        }

        final int m = Math.min(max, map.size());
        // Invariant: "best" is sorted and never holds more than "m" entries.
        final List<PairNeuronDouble> best = new ArrayList<>(m);

        for (final Neuron n : map) {
            final double d = distance.compute(n.getFeatures(), features);
            final PairNeuronDouble candidate = new PairNeuronDouble(n, d);

            if (best.size() == m) {
                // List is full ("m >= 1" here since the map is not empty).
                final int last = m - 1;
                if (PairNeuronDouble.COMPARATOR.compare(candidate, best.get(last)) >= 0) {
                    // Not strictly better than the current worst entry.
                    continue;
                }
                best.remove(last); // Make room by discarding worst entry.
            }

            insertSorted(best, candidate);
        }

        final List<Neuron> result = new ArrayList<>(m);
        for (PairNeuronDouble p : best) {
            result.add(p.getNeuron());
        }

        return result;
    }

    /**
     * Inserts an entry into an already sorted list, keeping it sorted.
     * The new entry is placed after all entries that compare equal to it,
     * so that entries inserted earlier keep precedence.
     *
     * @param sorted List sorted according to
     * {@link PairNeuronDouble#COMPARATOR}.
     * @param p Entry to insert.
     */
    private static void insertSorted(List<PairNeuronDouble> sorted,
                                     PairNeuronDouble p) {
        int pos = Collections.binarySearch(sorted, p, PairNeuronDouble.COMPARATOR);

        if (pos < 0) {
            // No equal entry: Decode the insertion point.
            pos = -(pos + 1);
        } else {
            // An arbitrary equal entry was found: Move past the whole run
            // of equal entries.
            final int size = sorted.size();
            do {
                ++pos;
            } while (pos < size &&
                     PairNeuronDouble.COMPARATOR.compare(sorted.get(pos), p) == 0);
        }

        sorted.add(pos, p);
    }

    /**
     * Helper data structure holding a (Neuron, double) pair.
     */
    private static class PairNeuronDouble {
        /** Comparator: Orders the pairs by increasing value. */
        static final Comparator<PairNeuronDouble> COMPARATOR
            = new Comparator<PairNeuronDouble>() {
            /** {@inheritDoc} */
            @Override
            public int compare(PairNeuronDouble o1,
                               PairNeuronDouble o2) {
                return Double.compare(o1.value, o2.value);
            }
        };

        /** Key. */
        private final Neuron neuron;
        /** Value. */
        private final double value;

        /**
         * @param neuron Neuron.
         * @param value Value.
         */
        PairNeuronDouble(Neuron neuron, double value) {
            this.neuron = neuron;
            this.value = value;
        }

        /** @return the neuron. */
        public Neuron getNeuron() {
            return neuron;
        }
    }
}

View File

@ -17,12 +17,9 @@
package org.apache.commons.math4.ml.neuralnet;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Comparator;
import org.apache.commons.math4.exception.NoDataException;
import org.apache.commons.math4.ml.distance.DistanceMeasure;
@ -56,17 +53,7 @@ public class MapUtils {
public static Neuron findBest(double[] features,
Iterable<Neuron> neurons,
DistanceMeasure distance) {
Neuron best = null;
double min = Double.POSITIVE_INFINITY;
for (final Neuron n : neurons) {
final double d = distance.compute(n.getFeatures(), features);
if (d < min) {
min = d;
best = n;
}
}
return best;
return new MapRanking(neurons, distance).rank(features, 1).get(0);
}
/**
@ -85,27 +72,8 @@ public class MapUtils {
public static Pair<Neuron, Neuron> findBestAndSecondBest(double[] features,
Iterable<Neuron> neurons,
DistanceMeasure distance) {
Neuron[] best = { null, null };
double[] min = { Double.POSITIVE_INFINITY,
Double.POSITIVE_INFINITY };
for (final Neuron n : neurons) {
final double d = distance.compute(n.getFeatures(), features);
if (d < min[0]) {
// Replace second best with old best.
min[1] = min[0];
best[1] = best[0];
// Store current as new best.
min[0] = d;
best[0] = n;
} else if (d < min[1]) {
// Replace old second best with current.
min[1] = d;
best[1] = n;
}
}
return new Pair<>(best[0], best[1]);
final List<Neuron> list = new MapRanking(neurons, distance).rank(features, 2);
return new Pair<>(list.get(0), list.get(1));
}
/**
@ -130,22 +98,7 @@ public class MapUtils {
public static Neuron[] sort(double[] features,
Iterable<Neuron> neurons,
DistanceMeasure distance) {
final List<PairNeuronDouble> list = new ArrayList<>();
for (final Neuron n : neurons) {
final double d = distance.compute(n.getFeatures(), features);
list.add(new PairNeuronDouble(n, d));
}
Collections.sort(list, PairNeuronDouble.COMPARATOR);
final int len = list.size();
final Neuron[] sorted = new Neuron[len];
for (int i = 0; i < len; i++) {
sorted[i] = list.get(i).getNeuron();
}
return sorted;
return new MapRanking(neurons, distance).rank(features).toArray(new Neuron[0]);
}
/**
@ -289,39 +242,4 @@ public class MapUtils {
return ((double) notAdjacentCount) / count;
}
/**
* Helper data structure holding a (Neuron, double) pair.
*/
private static class PairNeuronDouble {
/** Comparator. */
static final Comparator<PairNeuronDouble> COMPARATOR
= new Comparator<PairNeuronDouble>() {
/** {@inheritDoc} */
@Override
public int compare(PairNeuronDouble o1,
PairNeuronDouble o2) {
return Double.compare(o1.value, o2.value);
}
};
/** Key. */
private final Neuron neuron;
/** Value. */
private final double value;
/**
* @param neuron Neuron.
* @param value Value.
*/
PairNeuronDouble(Neuron neuron, double value) {
this.neuron = neuron;
this.value = value;
}
/** @return the neuron. */
public Neuron getNeuron() {
return neuron;
}
}
}

View File

@ -0,0 +1,108 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math4.ml.neuralnet;
import java.util.Set;
import java.util.HashSet;
import org.apache.commons.math4.exception.NotStrictlyPositiveException;
import org.apache.commons.math4.ml.distance.DistanceMeasure;
import org.apache.commons.math4.ml.distance.EuclideanDistance;
import org.apache.commons.math4.ml.neuralnet.FeatureInitializer;
import org.apache.commons.math4.ml.neuralnet.FeatureInitializerFactory;
import org.apache.commons.math4.ml.neuralnet.MapUtils;
import org.apache.commons.math4.ml.neuralnet.Network;
import org.apache.commons.math4.ml.neuralnet.Neuron;
import org.apache.commons.math4.ml.neuralnet.oned.NeuronString;
import org.junit.Test;
import org.junit.Assert;
/**
 * Unit tests for the {@link MapRanking} class.
 */
public class MapRankingTest {
    /*
     * The network under test is assumed to be the string
     *
     *  0-----1-----2
     */
    @Test
    public void testFindClosestNeuron() {
        final FeatureInitializer[] initializers = {
            new OffsetFeatureInitializer(FeatureInitializerFactory.uniform(-0.1, 0.1))
        };
        final MapRanking ranking
            = new MapRanking(new NeuronString(3, false, initializers).getNetwork(),
                             new EuclideanDistance());

        // Each group holds input features that must all select the same
        // neuron (they lie in the range of the same initializer), while
        // distinct groups must select distinct neurons of the network.
        final double[][][] groups = {
            { { -1 }, { 0.4 } },
            { { 0.6 }, { 1.4 } },
            { { 1.6 }, { 3 } },
        };

        final Set<Neuron> winners = new HashSet<>();
        for (double[][] group : groups) {
            final Set<Neuron> groupWinners = new HashSet<>();
            for (double[] f : group) {
                groupWinners.addAll(ranking.rank(f, 1));
            }

            // Within a group, the same neuron is always selected.
            Assert.assertEquals(1, groupWinners.size());
            winners.addAll(groupWinners);
        }

        // Across the groups, all three neurons have been selected.
        Assert.assertEquals(3, winners.size());
    }

    /**
     * Requesting a ranking of size zero must be rejected.
     */
    @Test(expected=NotStrictlyPositiveException.class)
    public void testRankPrecondition() {
        final FeatureInitializer[] initializers = {
            new OffsetFeatureInitializer(FeatureInitializerFactory.uniform(-0.1, 0.1))
        };
        final MapRanking ranking
            = new MapRanking(new NeuronString(3, false, initializers).getNetwork(),
                             new EuclideanDistance());

        ranking.rank(new double[] { -1 }, 0);
    }
}