MATH-1547: Ranking of any number of the best matching units of a neural network.
This commit is contained in:
parent
24e2c246ce
commit
960ba5322b
|
@ -54,6 +54,9 @@ If the output is not quite correct, check for invisible trailing spaces!
|
|||
</release>
|
||||
|
||||
<release version="4.0" date="XXXX-XX-XX" description="">
|
||||
<action dev="erans" type="update" issue="MATH-1547">
|
||||
More flexible ranking of SOFM.
|
||||
</action>
|
||||
<action dev="erans" type="fix" issue="MATH-1537" due-to="Jin Xu">
|
||||
Clean-up (typos and unused "import" statements).
|
||||
</action>
|
||||
|
|
|
@ -0,0 +1,156 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math4.ml.neuralnet;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
|
||||
import org.apache.commons.math4.exception.NotStrictlyPositiveException;
|
||||
import org.apache.commons.math4.ml.distance.DistanceMeasure;
|
||||
|
||||
/**
|
||||
* Utility for ranking the units (neurons) of a network.
|
||||
*
|
||||
* @since 4.0
|
||||
*/
|
||||
public class MapRanking {
|
||||
/** List corresponding to the map passed to the constructor. */
|
||||
private final List<Neuron> map = new ArrayList<>();
|
||||
/** Distance function for sorting. */
|
||||
private final DistanceMeasure distance;
|
||||
|
||||
/**
|
||||
* @param neurons List to be ranked.
|
||||
* No defensive copy is performed.
|
||||
* The {@link #rank(double[],int) created list of units} will
|
||||
* be sorted in increasing order of the {@code distance}.
|
||||
* @param distance Distance function.
|
||||
*/
|
||||
public MapRanking(Iterable<Neuron> neurons,
|
||||
DistanceMeasure distance) {
|
||||
this.distance = distance;
|
||||
|
||||
for (Neuron n : neurons) {
|
||||
map.add(n); // No defensive copy.
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a list of the neurons whose features best correspond to the
|
||||
* given {@code features}.
|
||||
*
|
||||
* @param features Data.
|
||||
* @return the list of neurons sorted in decreasing order of distance to
|
||||
* the given data.
|
||||
* @throws org.apache.commons.math4.exception.DimensionMismatchException
|
||||
* if the size of the input is not compatible with the neurons features
|
||||
* size.
|
||||
*/
|
||||
public List<Neuron> rank(double[] features) {
|
||||
return rank(features, map.size());
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a list of the neurons whose features best correspond to the
|
||||
* given {@code features}.
|
||||
*
|
||||
* @param features Data.
|
||||
* @param max Maximum size of the returned list.
|
||||
* @return the list of neurons sorted in decreasing order of distance to
|
||||
* the given data.
|
||||
* @throws org.apache.commons.math4.exception.DimensionMismatchException
|
||||
* if the size of the input is not compatible with the neurons features
|
||||
* size.
|
||||
* @throws NotStrictlyPositiveException if {@code max <= 0}.
|
||||
*/
|
||||
public List<Neuron> rank(double[] features,
|
||||
int max) {
|
||||
if (max <= 0) {
|
||||
throw new NotStrictlyPositiveException(max);
|
||||
}
|
||||
final int m = max <= map.size() ?
|
||||
max :
|
||||
map.size();
|
||||
final List<PairNeuronDouble> list = new ArrayList<>(m);
|
||||
|
||||
for (final Neuron n : map) {
|
||||
final double d = distance.compute(n.getFeatures(), features);
|
||||
final PairNeuronDouble p = new PairNeuronDouble(n, d);
|
||||
|
||||
if (list.size() < m) {
|
||||
list.add(p);
|
||||
if (list.size() > 1) {
|
||||
// Sort if there is more than 1 element.
|
||||
Collections.sort(list, PairNeuronDouble.COMPARATOR);
|
||||
}
|
||||
} else {
|
||||
final int last = list.size() - 1;
|
||||
if (PairNeuronDouble.COMPARATOR.compare(p, list.get(last)) < 0) {
|
||||
list.set(last, p); // Replace worst entry.
|
||||
if (last > 0) {
|
||||
// Sort if there is more than 1 element.
|
||||
Collections.sort(list, PairNeuronDouble.COMPARATOR);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
final List<Neuron> result = new ArrayList<>(m);
|
||||
for (PairNeuronDouble p : list) {
|
||||
result.add(p.getNeuron());
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper data structure holding a (Neuron, double) pair.
|
||||
*/
|
||||
private static class PairNeuronDouble {
|
||||
/** Comparator. */
|
||||
static final Comparator<PairNeuronDouble> COMPARATOR
|
||||
= new Comparator<PairNeuronDouble>() {
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
public int compare(PairNeuronDouble o1,
|
||||
PairNeuronDouble o2) {
|
||||
return Double.compare(o1.value, o2.value);
|
||||
}
|
||||
};
|
||||
/** Key. */
|
||||
private final Neuron neuron;
|
||||
/** Value. */
|
||||
private final double value;
|
||||
|
||||
/**
|
||||
* @param neuron Neuron.
|
||||
* @param value Value.
|
||||
*/
|
||||
PairNeuronDouble(Neuron neuron, double value) {
|
||||
this.neuron = neuron;
|
||||
this.value = value;
|
||||
}
|
||||
|
||||
/** @return the neuron. */
|
||||
public Neuron getNeuron() {
|
||||
return neuron;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -17,12 +17,9 @@
|
|||
|
||||
package org.apache.commons.math4.ml.neuralnet;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Comparator;
|
||||
|
||||
import org.apache.commons.math4.exception.NoDataException;
|
||||
import org.apache.commons.math4.ml.distance.DistanceMeasure;
|
||||
|
@ -56,17 +53,7 @@ public class MapUtils {
|
|||
public static Neuron findBest(double[] features,
|
||||
Iterable<Neuron> neurons,
|
||||
DistanceMeasure distance) {
|
||||
Neuron best = null;
|
||||
double min = Double.POSITIVE_INFINITY;
|
||||
for (final Neuron n : neurons) {
|
||||
final double d = distance.compute(n.getFeatures(), features);
|
||||
if (d < min) {
|
||||
min = d;
|
||||
best = n;
|
||||
}
|
||||
}
|
||||
|
||||
return best;
|
||||
return new MapRanking(neurons, distance).rank(features, 1).get(0);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -85,27 +72,8 @@ public class MapUtils {
|
|||
public static Pair<Neuron, Neuron> findBestAndSecondBest(double[] features,
|
||||
Iterable<Neuron> neurons,
|
||||
DistanceMeasure distance) {
|
||||
Neuron[] best = { null, null };
|
||||
double[] min = { Double.POSITIVE_INFINITY,
|
||||
Double.POSITIVE_INFINITY };
|
||||
for (final Neuron n : neurons) {
|
||||
final double d = distance.compute(n.getFeatures(), features);
|
||||
if (d < min[0]) {
|
||||
// Replace second best with old best.
|
||||
min[1] = min[0];
|
||||
best[1] = best[0];
|
||||
|
||||
// Store current as new best.
|
||||
min[0] = d;
|
||||
best[0] = n;
|
||||
} else if (d < min[1]) {
|
||||
// Replace old second best with current.
|
||||
min[1] = d;
|
||||
best[1] = n;
|
||||
}
|
||||
}
|
||||
|
||||
return new Pair<>(best[0], best[1]);
|
||||
final List<Neuron> list = new MapRanking(neurons, distance).rank(features, 2);
|
||||
return new Pair<>(list.get(0), list.get(1));
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -130,22 +98,7 @@ public class MapUtils {
|
|||
public static Neuron[] sort(double[] features,
|
||||
Iterable<Neuron> neurons,
|
||||
DistanceMeasure distance) {
|
||||
final List<PairNeuronDouble> list = new ArrayList<>();
|
||||
|
||||
for (final Neuron n : neurons) {
|
||||
final double d = distance.compute(n.getFeatures(), features);
|
||||
list.add(new PairNeuronDouble(n, d));
|
||||
}
|
||||
|
||||
Collections.sort(list, PairNeuronDouble.COMPARATOR);
|
||||
|
||||
final int len = list.size();
|
||||
final Neuron[] sorted = new Neuron[len];
|
||||
|
||||
for (int i = 0; i < len; i++) {
|
||||
sorted[i] = list.get(i).getNeuron();
|
||||
}
|
||||
return sorted;
|
||||
return new MapRanking(neurons, distance).rank(features).toArray(new Neuron[0]);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -289,39 +242,4 @@ public class MapUtils {
|
|||
|
||||
return ((double) notAdjacentCount) / count;
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper data structure holding a (Neuron, double) pair.
|
||||
*/
|
||||
private static class PairNeuronDouble {
|
||||
/** Comparator. */
|
||||
static final Comparator<PairNeuronDouble> COMPARATOR
|
||||
= new Comparator<PairNeuronDouble>() {
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
public int compare(PairNeuronDouble o1,
|
||||
PairNeuronDouble o2) {
|
||||
return Double.compare(o1.value, o2.value);
|
||||
}
|
||||
};
|
||||
/** Key. */
|
||||
private final Neuron neuron;
|
||||
/** Value. */
|
||||
private final double value;
|
||||
|
||||
/**
|
||||
* @param neuron Neuron.
|
||||
* @param value Value.
|
||||
*/
|
||||
PairNeuronDouble(Neuron neuron, double value) {
|
||||
this.neuron = neuron;
|
||||
this.value = value;
|
||||
}
|
||||
|
||||
/** @return the neuron. */
|
||||
public Neuron getNeuron() {
|
||||
return neuron;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,108 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.commons.math4.ml.neuralnet;
|
||||
|
||||
import java.util.Set;
|
||||
import java.util.HashSet;
|
||||
|
||||
import org.apache.commons.math4.exception.NotStrictlyPositiveException;
|
||||
import org.apache.commons.math4.ml.distance.DistanceMeasure;
|
||||
import org.apache.commons.math4.ml.distance.EuclideanDistance;
|
||||
import org.apache.commons.math4.ml.neuralnet.FeatureInitializer;
|
||||
import org.apache.commons.math4.ml.neuralnet.FeatureInitializerFactory;
|
||||
import org.apache.commons.math4.ml.neuralnet.MapUtils;
|
||||
import org.apache.commons.math4.ml.neuralnet.Network;
|
||||
import org.apache.commons.math4.ml.neuralnet.Neuron;
|
||||
import org.apache.commons.math4.ml.neuralnet.oned.NeuronString;
|
||||
import org.junit.Test;
|
||||
import org.junit.Assert;
|
||||
|
||||
/**
|
||||
* Tests for {@link MapRanking} class.
|
||||
*/
|
||||
public class MapRankingTest {
|
||||
/*
|
||||
* Test assumes that the network is
|
||||
*
|
||||
* 0-----1-----2
|
||||
*/
|
||||
@Test
|
||||
public void testFindClosestNeuron() {
|
||||
final FeatureInitializer init
|
||||
= new OffsetFeatureInitializer(FeatureInitializerFactory.uniform(-0.1, 0.1));
|
||||
final FeatureInitializer[] initArray = { init };
|
||||
|
||||
final MapRanking ranking = new MapRanking(new NeuronString(3, false, initArray).getNetwork(),
|
||||
new EuclideanDistance());
|
||||
|
||||
final Set<Neuron> allBest = new HashSet<>();
|
||||
final Set<Neuron> best = new HashSet<>();
|
||||
double[][] features;
|
||||
|
||||
// The following tests ensures that
|
||||
// 1. the same neuron is always selected when the input feature is
|
||||
// in the range of the initializer,
|
||||
// 2. different network's neuron have been selected by inputs features
|
||||
// that belong to different ranges.
|
||||
|
||||
best.clear();
|
||||
features = new double[][] {
|
||||
{ -1 },
|
||||
{ 0.4 },
|
||||
};
|
||||
for (double[] f : features) {
|
||||
best.addAll(ranking.rank(f, 1));
|
||||
}
|
||||
Assert.assertEquals(1, best.size());
|
||||
allBest.addAll(best);
|
||||
|
||||
best.clear();
|
||||
features = new double[][] {
|
||||
{ 0.6 },
|
||||
{ 1.4 },
|
||||
};
|
||||
for (double[] f : features) {
|
||||
best.addAll(ranking.rank(f, 1));
|
||||
}
|
||||
Assert.assertEquals(1, best.size());
|
||||
allBest.addAll(best);
|
||||
|
||||
best.clear();
|
||||
features = new double[][] {
|
||||
{ 1.6 },
|
||||
{ 3 },
|
||||
};
|
||||
for (double[] f : features) {
|
||||
best.addAll(ranking.rank(f, 1));
|
||||
}
|
||||
Assert.assertEquals(1, best.size());
|
||||
allBest.addAll(best);
|
||||
|
||||
Assert.assertEquals(3, allBest.size());
|
||||
}
|
||||
|
||||
@Test(expected=NotStrictlyPositiveException.class)
|
||||
public void testRankPrecondition() {
|
||||
final FeatureInitializer init
|
||||
= new OffsetFeatureInitializer(FeatureInitializerFactory.uniform(-0.1, 0.1));
|
||||
final FeatureInitializer[] initArray = { init };
|
||||
|
||||
new MapRanking(new NeuronString(3, false, initArray).getNetwork(),
|
||||
new EuclideanDistance()).rank(new double[] { -1 }, 0);
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue