LUCENE-3501: random sampler was not random (and so facet SamplingWrapperTest occasionally failed)

git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1181760 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Doron Cohen 2011-10-11 12:54:45 +00:00
parent b438b265aa
commit bd067ee329
6 changed files with 223 additions and 365 deletions

View File

@ -2,12 +2,15 @@ package org.apache.lucene.facet.search;
import java.io.IOException; import java.io.IOException;
import java.util.List; import java.util.List;
import java.util.Random;
import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexReader;
import org.apache.lucene.facet.search.params.FacetSearchParams; import org.apache.lucene.facet.search.params.FacetSearchParams;
import org.apache.lucene.facet.search.results.FacetResult; import org.apache.lucene.facet.search.results.FacetResult;
import org.apache.lucene.facet.search.results.FacetResultNode; import org.apache.lucene.facet.search.results.FacetResultNode;
import org.apache.lucene.facet.search.sampling.RandomSampler;
import org.apache.lucene.facet.search.sampling.RepeatableSampler;
import org.apache.lucene.facet.search.sampling.Sampler; import org.apache.lucene.facet.search.sampling.Sampler;
import org.apache.lucene.facet.search.sampling.SamplingAccumulator; import org.apache.lucene.facet.search.sampling.SamplingAccumulator;
import org.apache.lucene.facet.taxonomy.TaxonomyReader; import org.apache.lucene.facet.taxonomy.TaxonomyReader;
@ -44,7 +47,7 @@ import org.apache.lucene.facet.taxonomy.TaxonomyReader;
*/ */
public final class AdaptiveFacetsAccumulator extends StandardFacetsAccumulator { public final class AdaptiveFacetsAccumulator extends StandardFacetsAccumulator {
private Sampler sampler = new Sampler(); private Sampler sampler = new RandomSampler();
/** /**
* Create an {@link AdaptiveFacetsAccumulator} * Create an {@link AdaptiveFacetsAccumulator}

View File

@ -0,0 +1,71 @@
package org.apache.lucene.facet.search.sampling;
import java.io.IOException;
import java.util.Random;
import org.apache.lucene.facet.search.ScoredDocIDs;
import org.apache.lucene.facet.search.ScoredDocIDsIterator;
import org.apache.lucene.facet.util.ScoredDocIdsUtils;
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Simple random sampler
*/
public class RandomSampler extends Sampler {
private final Random random;
public RandomSampler() {
super();
this.random = new Random();
}
public RandomSampler(SamplingParams params, Random random) throws IllegalArgumentException {
super(params);
this.random = random;
}
@Override
protected SampleResult createSample(ScoredDocIDs docids, int actualSize, int sampleSetSize) throws IOException {
final int[] sample = new int[sampleSetSize];
final int maxStep = (actualSize * 2 ) / sampleSetSize; //floor
int remaining = actualSize;
ScoredDocIDsIterator it = docids.iterator();
int i = 0;
// select sample docs with random skipStep, make sure to leave sufficient #docs for selection after last skip
while (i<sample.length && remaining>(sampleSetSize-maxStep-i)) {
int skipStep = 1 + random.nextInt(maxStep);
// Skip over 'skipStep' documents
for (int j=0; j<skipStep; j++) {
it.next();
-- remaining;
}
sample[i++] = it.getDocID();
}
// Add leftover documents to the sample set
while (i<sample.length) {
it.next();
sample[i++] = it.getDocID();
}
ScoredDocIDs sampleRes = ScoredDocIdsUtils.createScoredDocIDsSubset(docids, sample);
SampleResult res = new SampleResult(sampleRes, sampleSetSize/(double)actualSize);
return res;
}
}

View File

@ -1,25 +1,15 @@
package org.apache.lucene.facet.util; package org.apache.lucene.facet.search.sampling;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.logging.Level; import java.util.logging.Level;
import java.util.logging.Logger; import java.util.logging.Logger;
import org.apache.lucene.analysis.core.KeywordAnalyzer;
import org.apache.lucene.document.Document;
import org.apache.lucene.index.CorruptIndexException;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.LockObtainFailedException;
import org.apache.lucene.store.RAMDirectory;
import org.apache.lucene.util.PriorityQueue; import org.apache.lucene.util.PriorityQueue;
import org.apache.lucene.util.Version;
import org.apache.lucene.facet.search.ScoredDocIDs; import org.apache.lucene.facet.search.ScoredDocIDs;
import org.apache.lucene.facet.search.ScoredDocIDsIterator; import org.apache.lucene.facet.search.ScoredDocIDsIterator;
import org.apache.lucene.facet.util.ScoredDocIdsUtils;
/** /**
* Licensed to the Apache Software Foundation (ASF) under one or more * Licensed to the Apache Software Foundation (ASF) under one or more
@ -40,12 +30,37 @@ import org.apache.lucene.facet.search.ScoredDocIDsIterator;
/** /**
* Take random samples of large collections. * Take random samples of large collections.
*
* @lucene.experimental * @lucene.experimental
*/ */
public class RandomSample { public class RepeatableSampler extends Sampler {
private static final Logger logger = Logger.getLogger(RandomSample.class.getName()); private static final Logger logger = Logger.getLogger(RepeatableSampler.class.getName());
public RepeatableSampler(SamplingParams params) {
super(params);
}
@Override
protected SampleResult createSample(ScoredDocIDs docids, int actualSize,
int sampleSetSize) throws IOException {
int[] sampleSet = null;
try {
sampleSet = repeatableSample(docids, actualSize,
sampleSetSize);
} catch (IOException e) {
if (logger.isLoggable(Level.WARNING)) {
logger.log(Level.WARNING, "sampling failed: "+e.getMessage()+" - falling back to no sampling!", e);
}
return new SampleResult(docids, 1d);
}
ScoredDocIDs sampled = ScoredDocIdsUtils.createScoredDocIDsSubset(docids,
sampleSet);
if (logger.isLoggable(Level.FINEST)) {
logger.finest("******************** " + sampled.size());
}
return new SampleResult(sampled, sampled.size()/(double)docids.size());
}
/** /**
* Returns <code>sampleSize</code> values from the first <code>collectionSize</code> * Returns <code>sampleSize</code> values from the first <code>collectionSize</code>
@ -57,10 +72,10 @@ public class RandomSample {
* @return An array of values chosen from the collection. * @return An array of values chosen from the collection.
* @see Algorithm#TRAVERSAL * @see Algorithm#TRAVERSAL
*/ */
public static int[] repeatableSample(ScoredDocIDs collection, private static int[] repeatableSample(ScoredDocIDs collection,
int collectionSize, int sampleSize) int collectionSize, int sampleSize)
throws IOException { throws IOException {
return RandomSample.repeatableSample(collection, collectionSize, return repeatableSample(collection, collectionSize,
sampleSize, Algorithm.HASHING, Sorted.NO); sampleSize, Algorithm.HASHING, Sorted.NO);
} }
@ -75,7 +90,7 @@ public class RandomSample {
* Sorted.NO to return them in essentially random order. * Sorted.NO to return them in essentially random order.
* @return An array of values chosen from the collection. * @return An array of values chosen from the collection.
*/ */
public static int[] repeatableSample(ScoredDocIDs collection, private static int[] repeatableSample(ScoredDocIDs collection,
int collectionSize, int sampleSize, int collectionSize, int sampleSize,
Algorithm algorithm, Sorted sorted) Algorithm algorithm, Sorted sorted)
throws IOException { throws IOException {
@ -91,16 +106,16 @@ public class RandomSample {
int[] sample = new int[sampleSize]; int[] sample = new int[sampleSize];
long[] times = new long[4]; long[] times = new long[4];
if (algorithm == Algorithm.TRAVERSAL) { if (algorithm == Algorithm.TRAVERSAL) {
RandomSample.sample1(collection, collectionSize, sample, times); sample1(collection, collectionSize, sample, times);
} else if (algorithm == Algorithm.HASHING) { } else if (algorithm == Algorithm.HASHING) {
RandomSample.sample2(collection, collectionSize, sample, times); sample2(collection, collectionSize, sample, times);
} else { } else {
throw new IllegalArgumentException("Invalid algorithm selection"); throw new IllegalArgumentException("Invalid algorithm selection");
} }
if (sorted == Sorted.YES) { if (sorted == Sorted.YES) {
Arrays.sort(sample); Arrays.sort(sample);
} }
if (RandomSample.returnTimings) { if (returnTimings) {
times[3] = System.currentTimeMillis(); times[3] = System.currentTimeMillis();
if (logger.isLoggable(Level.FINEST)) { if (logger.isLoggable(Level.FINEST)) {
logger.finest("Times: " + (times[1] - times[0]) + "ms, " logger.finest("Times: " + (times[1] - times[0]) + "ms, "
@ -133,13 +148,13 @@ public class RandomSample {
private static void sample1(ScoredDocIDs collection, int collectionSize, int[] sample, long[] times) private static void sample1(ScoredDocIDs collection, int collectionSize, int[] sample, long[] times)
throws IOException { throws IOException {
ScoredDocIDsIterator it = collection.iterator(); ScoredDocIDsIterator it = collection.iterator();
if (RandomSample.returnTimings) { if (returnTimings) {
times[0] = System.currentTimeMillis(); times[0] = System.currentTimeMillis();
} }
int sampleSize = sample.length; int sampleSize = sample.length;
int prime = RandomSample.findGoodStepSize(collectionSize, sampleSize); int prime = findGoodStepSize(collectionSize, sampleSize);
int mod = prime % collectionSize; int mod = prime % collectionSize;
if (RandomSample.returnTimings) { if (returnTimings) {
times[1] = System.currentTimeMillis(); times[1] = System.currentTimeMillis();
} }
int sampleCount = 0; int sampleCount = 0;
@ -158,10 +173,10 @@ public class RandomSample {
} }
sample[sampleCount++] = it.getDocID(); sample[sampleCount++] = it.getDocID();
} }
if (RandomSample.returnTimings) { if (returnTimings) {
times[2] = System.currentTimeMillis(); times[2] = System.currentTimeMillis();
} }
} // end RandomSample.sample1() }
/** /**
* Returns a value which will allow the caller to walk * Returns a value which will allow the caller to walk
@ -187,10 +202,10 @@ public class RandomSample {
i = collectionSize / sampleSize; i = collectionSize / sampleSize;
} }
do { do {
i = RandomSample.findNextPrimeAfter(i); i = findNextPrimeAfter(i);
} while (collectionSize % i == 0); } while (collectionSize % i == 0);
return i; return i;
} // end RandomSample.findGoodStepSize() }
/** /**
* Returns the first prime number that is larger than <code>n</code>. * Returns the first prime number that is larger than <code>n</code>.
@ -199,10 +214,10 @@ public class RandomSample {
*/ */
private static int findNextPrimeAfter(int n) { private static int findNextPrimeAfter(int n) {
n += (n % 2 == 0) ? 1 : 2; // next odd n += (n % 2 == 0) ? 1 : 2; // next odd
foundFactor: for (;; n += 2) { foundFactor: for (;; n += 2) { //TODO labels??!!
int sri = (int) (Math.sqrt(n)); int sri = (int) (Math.sqrt(n));
for (int primeIndex = 0; primeIndex < RandomSample.N_PRIMES; primeIndex++) { for (int primeIndex = 0; primeIndex < N_PRIMES; primeIndex++) {
int p = RandomSample.primes[primeIndex]; int p = primes[primeIndex];
if (p > sri) { if (p > sri) {
return n; return n;
} }
@ -210,7 +225,7 @@ public class RandomSample {
continue foundFactor; continue foundFactor;
} }
} }
for (int p = RandomSample.primes[RandomSample.N_PRIMES - 1] + 2;; p += 2) { for (int p = primes[N_PRIMES - 1] + 2;; p += 2) {
if (p > sri) { if (p > sri) {
return n; return n;
} }
@ -219,70 +234,17 @@ public class RandomSample {
} }
} }
} }
} // end RandomSample.findNextPrimeAfter()
/**
* Divides the values in <code>collection</code> into <code>numSubranges</code>
* subranges from <code>minValue</code> to <code>maxValue</code> and returns the
* number of values in each subrange. (For testing the flatness of distribution of
* a sample.)
* @param collection The collection of values to be counted.
* @param range The number of possible values.
* @param numSubranges How many intervals to divide the value range into.
*/
private static int[] countsBySubrange(int[] collection, int range, int numSubranges) {
int[] counts = new int[numSubranges];
Arrays.fill(counts, 0);
int numInSubrange = range / numSubranges;
for (int j = 0; j < collection.length; j++) {
counts[collection[j] / numInSubrange]++;
} }
return counts;
} // end RandomSample.countsBySubrange()
/**
* Factors <code>value</code> into primes.
*/
public static int[] factor(long value) {
ArrayList<Integer> list = new ArrayList<Integer>();
while (value > 1 && value % 2 == 0) {
list.add(2);
value /= 2;
}
long sqrt = Math.round(Math.sqrt(value));
for (int pIndex = 0, lim = RandomSample.primes.length; pIndex < lim; pIndex++) {
int p = RandomSample.primes[pIndex];
if (p >= sqrt) {
break;
}
while (value % p == 0) {
list.add(p);
value /= p;
sqrt = Math.round(Math.sqrt(value));
}
}
if (list.size() == 0 || value > Integer.MAX_VALUE) {
throw new RuntimeException("Prime or too large to factor: "+value);
}
if ((int)value > 1) {
list.add((int)value);
}
int[] factors = new int[list.size()];
for (int j = 0; j < factors.length; j++) {
factors[j] = list.get(j).intValue();
}
return factors;
} // end RandomSample.factor()
/** /**
* The first N_PRIMES primes, after 2. * The first N_PRIMES primes, after 2.
*/ */
private static final int N_PRIMES = 4000; private static final int N_PRIMES = 4000;
private static int[] primes = new int[RandomSample.N_PRIMES]; private static int[] primes = new int[N_PRIMES];
static { static {
RandomSample.primes[0] = 3; primes[0] = 3;
for (int count = 1; count < RandomSample.N_PRIMES; count++) { for (int count = 1; count < N_PRIMES; count++) {
primes[count] = RandomSample.findNextPrimeAfter(primes[count - 1]); primes[count] = findNextPrimeAfter(primes[count - 1]);
} }
} }
@ -307,7 +269,7 @@ public class RandomSample {
*/ */
private static void sample2(ScoredDocIDs collection, int collectionSize, int[] sample, long[] times) private static void sample2(ScoredDocIDs collection, int collectionSize, int[] sample, long[] times)
throws IOException { throws IOException {
if (RandomSample.returnTimings) { if (returnTimings) {
times[0] = System.currentTimeMillis(); times[0] = System.currentTimeMillis();
} }
int sampleSize = sample.length; int sampleSize = sample.length;
@ -320,7 +282,7 @@ public class RandomSample {
while (it.next()) { while (it.next()) {
pq.insertWithReuse((int)(it.getDocID() * PHI_32) & 0x7FFFFFFF); pq.insertWithReuse((int)(it.getDocID() * PHI_32) & 0x7FFFFFFF);
} }
if (RandomSample.returnTimings) { if (returnTimings) {
times[1] = System.currentTimeMillis(); times[1] = System.currentTimeMillis();
} }
/* /*
@ -330,10 +292,10 @@ public class RandomSample {
for (int si = 0; si < sampleSize; si++) { for (int si = 0; si < sampleSize; si++) {
sample[si] = (int)(((IntPriorityQueue.MI)(heap[si+1])).value * PHI_32I) & 0x7FFFFFFF; sample[si] = (int)(((IntPriorityQueue.MI)(heap[si+1])).value * PHI_32I) & 0x7FFFFFFF;
} }
if (RandomSample.returnTimings) { if (returnTimings) {
times[2] = System.currentTimeMillis(); times[2] = System.currentTimeMillis();
} }
} // end RandomSample.sample2() }
/** /**
* A bounded priority queue for Integers, to retain a specified number of * A bounded priority queue for Integers, to retain a specified number of
@ -358,7 +320,7 @@ public class RandomSample {
} }
this.mi.value = intval; this.mi.value = intval;
this.mi = (MI)this.insertWithOverflow(this.mi); this.mi = (MI)this.insertWithOverflow(this.mi);
} // end IntPriorityQueue.insertWithReuse() }
/** /**
* Returns the underlying data structure for faster access. Extracting elements * Returns the underlying data structure for faster access. Extracting elements
@ -386,19 +348,19 @@ public class RandomSample {
private static class MI { private static class MI {
MI() { } MI() { }
public int value; public int value;
} // end class RandomSample.IntPriorityQueue.MI }
/** /**
* The mutable integer instance for reuse after first overflow. * The mutable integer instance for reuse after first overflow.
*/ */
private MI mi; private MI mi;
} // end class RandomSample.IntPriorityQueue }
/** /**
* For specifying which sampling algorithm to use. * For specifying which sampling algorithm to use.
*/ */
public static class Algorithm { private enum Algorithm {
/** /**
* Specifies a methodical traversal algorithm, which is guaranteed to span the collection * Specifies a methodical traversal algorithm, which is guaranteed to span the collection
@ -410,7 +372,7 @@ public class RandomSample {
// TODO (Facet): This one produces a bimodal distribution (very flat around // TODO (Facet): This one produces a bimodal distribution (very flat around
// each peak!) for collection size 10M and sample sizes 10k and 10544. // each peak!) for collection size 10M and sample sizes 10k and 10544.
// Figure out why. // Figure out why.
public static final Algorithm TRAVERSAL = new Algorithm("Traversal"); TRAVERSAL,
/** /**
* Specifies a Fibonacci-style hash algorithm (see Knuth, S&S), which generates a less * Specifies a Fibonacci-style hash algorithm (see Knuth, S&S), which generates a less
@ -418,69 +380,25 @@ public class RandomSample {
* but requires a bounded priority queue the size of the sample, and creates an object * but requires a bounded priority queue the size of the sample, and creates an object
* containing a sampled value and its hash, for every element in the full set. * containing a sampled value and its hash, for every element in the full set.
*/ */
public static final Algorithm HASHING = new Algorithm("Hashing"); HASHING
/**
* Constructs an instance of an algorithm.
* @param name An ID for printing.
*/
private Algorithm(String name) {
this.name = name;
} }
/**
* Prints this algorithm's name.
*/
@Override
public String toString() {
return this.name;
}
/**
* The name of this algorithm, for printing.
*/
private String name;
} // end class RandomSample.Algorithm
/** /**
* For specifying whether to sort the sample. * For specifying whether to sort the sample.
*/ */
public static class Sorted { private enum Sorted {
/** /**
* Specifies sorting the resulting sample before returning. * Sort resulting sample before returning.
*/ */
public static final Sorted YES = new Sorted("sorted"); YES,
/** /**
* Specifies not sorting the resulting sample. *Do not sort the resulting sample.
*/ */
public static final Sorted NO = new Sorted("unsorted"); NO
/**
* Constructs an instance of a "sorted" selector.
* @param name An ID for printing.
*/
private Sorted(String name) {
this.name = name;
} }
/**
* Prints this selector's name.
*/
@Override
public String toString() {
return this.name;
}
/**
* The name of this selector, for printing.
*/
private String name;
} // end class RandomSample.Sorted
/** /**
* Magic number 1: prime closest to phi, in 32 bits. * Magic number 1: prime closest to phi, in 32 bits.
*/ */
@ -496,143 +414,4 @@ public class RandomSample {
*/ */
private static boolean returnTimings = false; private static boolean returnTimings = false;
/** }
* Self-test.
*/
public static void main(String[] args) throws Exception {
RandomSample.returnTimings = true;
/*
* Create an array of sequential integers, from which samples will be taken.
*/
final int COLLECTION_SIZE = 10 * 1000 * 1000;
ScoredDocIDs collection = createAllScoredDocs(COLLECTION_SIZE);
/*
* Factor PHI.
*
int[] factors = RandomSample.factor(PHI_32);
System.out.print("Factors of PHI_32: ");
for (int k : factors) {
System.out.print(k+", ");
}
System.out.println("");
* Verify inverse relationship of PHI & phi.
*
boolean inverseValid = true;
for (int j = 0; j < Integer.MAX_VALUE; j++) {
int k = (int)(j * PHI_32) & 0x7FFFFFFF;
int m = (int)(k * PHI_32I) & 0X7FFFFFFF;
if (j != m) {
System.out.println("Inverse not valid for "+j);
inverseValid = false;
}
}
System.out.println("Inverse valid? "+inverseValid);
*/
/*
* Take samples of various sizes from the full set, verify no duplicates,
* check flatness.
*/
int[] sampleSizes = {
10, 57, 100, 333, 1000, 2154, 10000
};
Algorithm[] algorithms = { Algorithm.HASHING, Algorithm.TRAVERSAL };
for (int sampleSize : sampleSizes) {
for (Algorithm algorithm : algorithms) {
System.out.println("Sample size " + sampleSize
+ ", algorithm " + algorithm + "...");
/*
* Take the sample.
*/
int[] sample = RandomSample.repeatableSample(
collection, COLLECTION_SIZE, sampleSize, algorithm, Sorted.YES);
/*
* Check for duplicates.
*/
boolean noDups = true;
for (int j = 0; j < sampleSize - 1; j++) {
if (sample[j] == sample[j + 1]) {
System.out.println("Duplicate value "
+ sample[j] + " at " + j + ", "
+ (j + 1));
noDups = false;
break;
}
}
if (noDups) {
System.out.println("No duplicates.");
}
if (algorithm == Algorithm.HASHING) {
System.out.print("Hashed sample, up to 100 of "+sampleSize+": ");
int lim = Math.min(100, sampleSize);
for (int k = 0; k < lim; k++) {
System.out.print(sample[k]+", ");
}
System.out.println("");
}
/*
* Check flatness of distribution in sample.
*/
final int N_INTERVALS = 100;
int[] counts = RandomSample.countsBySubrange(sample, COLLECTION_SIZE, N_INTERVALS);
int minCount = Integer.MAX_VALUE;
int maxCount = Integer.MIN_VALUE;
int avgCount = 0;
for (int j = 0; j < N_INTERVALS; j++) {
int count = counts[j];
if (count < minCount) {
minCount = count;
}
if (count > maxCount) {
maxCount = count;
}
avgCount += count;
}
avgCount /= N_INTERVALS;
System.out.println("Min, max, avg: "+minCount+", "+maxCount+", "+avgCount);
if (((double)minCount - avgCount)/avgCount < -0.05 && (minCount - avgCount) < -5) {
System.out.println("Not flat enough.");
} else if (((double)maxCount - avgCount)/avgCount > 0.05 && (maxCount - avgCount) > 5) {
System.out.println("Not flat enough.");
} else {
System.out.println("Flat enough.");
}
if (sampleSize == 10544 && algorithm == Algorithm.TRAVERSAL) {
System.out.print("Counts of interest: ");
for (int j = 0; j < N_INTERVALS; j++) {
System.out.print(counts[j]+", ");
}
System.out.println("");
}
}
}
System.out.println("Last prime is "
+ RandomSample.primes[RandomSample.N_PRIMES - 1]);
}
private static ScoredDocIDs createAllScoredDocs(final int COLLECTION_SIZE)
throws CorruptIndexException, LockObtainFailedException, IOException {
ScoredDocIDs collection;
IndexReader reader = null;
Directory ramDir = new RAMDirectory();
try {
IndexWriter writer = new IndexWriter(ramDir, new IndexWriterConfig(Version.LUCENE_30, new KeywordAnalyzer()));
for (int i = 0; i < COLLECTION_SIZE; i++) {
writer.addDocument(new Document());
}
writer.commit();
writer.close();
reader = IndexReader.open(ramDir);
collection = ScoredDocIdsUtils.createAllDocsScoredDocIDs(reader);
} finally {
if (reader != null) {
reader.close();
}
ramDir.close();
}
return collection;
}
} // end class RandomSample

View File

@ -1,8 +1,6 @@
package org.apache.lucene.facet.search.sampling; package org.apache.lucene.facet.search.sampling;
import java.io.IOException; import java.io.IOException;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexReader;
@ -15,8 +13,6 @@ import org.apache.lucene.facet.search.results.FacetResult;
import org.apache.lucene.facet.search.results.FacetResultNode; import org.apache.lucene.facet.search.results.FacetResultNode;
import org.apache.lucene.facet.search.results.MutableFacetResultNode; import org.apache.lucene.facet.search.results.MutableFacetResultNode;
import org.apache.lucene.facet.taxonomy.TaxonomyReader; import org.apache.lucene.facet.taxonomy.TaxonomyReader;
import org.apache.lucene.facet.util.RandomSample;
import org.apache.lucene.facet.util.ScoredDocIdsUtils;
/** /**
* Licensed to the Apache Software Foundation (ASF) under one or more * Licensed to the Apache Software Foundation (ASF) under one or more
@ -48,11 +44,9 @@ import org.apache.lucene.facet.util.ScoredDocIdsUtils;
* *
* @lucene.experimental * @lucene.experimental
*/ */
public class Sampler { public abstract class Sampler {
private static final Logger logger = Logger.getLogger(Sampler.class.getName()); protected final SamplingParams samplingParams;
private final SamplingParams samplingParams;
/** /**
* Construct with {@link SamplingParams} * Construct with {@link SamplingParams}
@ -103,24 +97,18 @@ public class Sampler {
sampleSetSize = Math.max(sampleSetSize, samplingParams.getMinSampleSize()); sampleSetSize = Math.max(sampleSetSize, samplingParams.getMinSampleSize());
sampleSetSize = Math.min(sampleSetSize, samplingParams.getMaxSampleSize()); sampleSetSize = Math.min(sampleSetSize, samplingParams.getMaxSampleSize());
int[] sampleSet = null; return createSample(docids, actualSize, sampleSetSize);
try {
sampleSet = RandomSample.repeatableSample(docids, actualSize,
sampleSetSize);
} catch (IOException e) {
if (logger.isLoggable(Level.WARNING)) {
logger.log(Level.WARNING, "sampling failed: "+e.getMessage()+" - falling back to no sampling!", e);
}
return new SampleResult(docids, 1d);
} }
ScoredDocIDs sampled = ScoredDocIdsUtils.createScoredDocIDsSubset(docids, /**
sampleSet); * Create and return a sample of the input set
if (logger.isLoggable(Level.FINEST)) { * @param docids input set out of which a sample is to be created
logger.finest("******************** " + sampled.size()); * @param actualSize original size of set, prior to sampling
} * @param sampleSetSize required size of sample set
return new SampleResult(sampled, sampled.size()/(double)docids.size()); * @return sample of the input set in the required size
} */
protected abstract SampleResult createSample(ScoredDocIDs docids, int actualSize,
int sampleSetSize) throws IOException;
/** /**
* Get a fixer of sample facet accumulation results. Default implementation * Get a fixer of sample facet accumulation results. Default implementation

View File

@ -313,7 +313,7 @@ public abstract class FacetTestBase extends LuceneTestCase {
System.err.println("Results are not the same!"); System.err.println("Results are not the same!");
System.err.println("Expected:\n" + expectedResults); System.err.println("Expected:\n" + expectedResults);
System.err.println("Actual" + actualResults); System.err.println("Actual" + actualResults);
fail("Results are not the same!"); throw new NotSameResultError();
} }
} }
@ -325,4 +325,12 @@ public abstract class FacetTestBase extends LuceneTestCase {
} }
return sb.toString().replaceAll("Residue:.*.0", "").replaceAll("Num valid Descendants.*", ""); return sb.toString().replaceAll("Residue:.*.0", "").replaceAll("Num valid Descendants.*", "");
} }
/** Special Error class for ability to ignore only this error and retry... */
public static class NotSameResultError extends Error {
public NotSameResultError() {
super("Results are not the same!");
}
}
} }

View File

@ -2,6 +2,7 @@ package org.apache.lucene.facet.search.sampling;
import java.io.IOException; import java.io.IOException;
import java.util.List; import java.util.List;
import java.util.Random;
import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term; import org.apache.lucene.index.Term;
@ -41,7 +42,7 @@ public abstract class BaseSampleTestTopK extends BaseTestTopK {
protected static final int K = 2; protected static final int K = 2;
/** since there is a chance that this test would fail even if the code is correct, retry the sampling */ /** since there is a chance that this test would fail even if the code is correct, retry the sampling */
protected static final int RETRIES = 4; protected static final int RETRIES = 10;
protected abstract FacetsAccumulator getSamplingAccumulator(Sampler sampler, protected abstract FacetsAccumulator getSamplingAccumulator(Sampler sampler,
TaxonomyReader taxoReader, IndexReader indexReader, TaxonomyReader taxoReader, IndexReader indexReader,
@ -53,9 +54,10 @@ public abstract class BaseSampleTestTopK extends BaseTestTopK {
* is performed. The results are compared to non-sampled ones. * is performed. The results are compared to non-sampled ones.
*/ */
public void testCountUsingSamping() throws Exception, IOException { public void testCountUsingSamping() throws Exception, IOException {
boolean useRandomSampler = random.nextBoolean();
for (int partitionSize : partitionSizes) { for (int partitionSize : partitionSizes) {
try {
initIndex(partitionSize); initIndex(partitionSize);
// Get all of the documents and run the query, then do different // Get all of the documents and run the query, then do different
// facet counts and compare to control // facet counts and compare to control
Query q = new TermQuery(new Term(CONTENT_FIELD, BETA)); // 90% of the docs Query q = new TermQuery(new Term(CONTENT_FIELD, BETA)); // 90% of the docs
@ -68,36 +70,38 @@ public abstract class BaseSampleTestTopK extends BaseTestTopK {
List<FacetResult> expectedResults = fc.getFacetResults(); List<FacetResult> expectedResults = fc.getFacetResults();
// complement with sampling!
final Sampler sampler = createSampler(docCollector.getScoredDocIDs());
FacetSearchParams samplingSearchParams = searchParamsWithRequests(K, partitionSize); FacetSearchParams samplingSearchParams = searchParamsWithRequests(K, partitionSize);
// try several times in case of failure, because the test has a chance to fail
// if the top K facets are not sufficiently common with the sample set
for (int nTrial=0; nTrial<RETRIES; nTrial++) {
try {
// complement with sampling!
final Sampler sampler = createSampler(nTrial, docCollector.getScoredDocIDs(), useRandomSampler);
assertSampling(expectedResults, q, sampler, samplingSearchParams, false); assertSampling(expectedResults, q, sampler, samplingSearchParams, false);
assertSampling(expectedResults, q, sampler, samplingSearchParams, true); assertSampling(expectedResults, q, sampler, samplingSearchParams, true);
break; // succeeded
} catch (NotSameResultError e) {
if (nTrial>=RETRIES-1) {
throw e; // no more retries allowed, must fail
}
}
}
} finally {
closeAll(); closeAll();
} }
} }
}
private void assertSampling(List<FacetResult> expected, Query q, Sampler sampler, FacetSearchParams params, boolean complement) throws Exception { private void assertSampling(List<FacetResult> expected, Query q, Sampler sampler, FacetSearchParams params, boolean complement) throws Exception {
// try several times in case of failure, because the test has a chance to fail FacetsCollector samplingFC = samplingCollector(complement, sampler, params);
// if the top K facets are not sufficiently common with the sample set
for (int n=RETRIES; n>0; n--) {
FacetsCollector samplingFC = samplingCollector(false, sampler, params);
searcher.search(q, samplingFC); searcher.search(q, samplingFC);
List<FacetResult> sampledResults = samplingFC.getFacetResults(); List<FacetResult> sampledResults = samplingFC.getFacetResults();
try {
assertSameResults(expected, sampledResults); assertSameResults(expected, sampledResults);
break; // succeeded
} catch (Exception e) {
if (n<=1) { // otherwise try again
throw e;
}
}
}
} }
private FacetsCollector samplingCollector( private FacetsCollector samplingCollector(
@ -117,14 +121,19 @@ public abstract class BaseSampleTestTopK extends BaseTestTopK {
return samplingFC; return samplingFC;
} }
private Sampler createSampler(ScoredDocIDs scoredDocIDs) { private Sampler createSampler(int nTrial, ScoredDocIDs scoredDocIDs, boolean useRandomSampler) {
SamplingParams samplingParams = new SamplingParams(); SamplingParams samplingParams = new SamplingParams();
samplingParams.setSampleRatio(0.8);
samplingParams.setMinSampleSize(100); final double retryFactor = Math.pow(1.01, nTrial);
samplingParams.setMaxSampleSize(10000); samplingParams.setSampleRatio(0.8 * retryFactor);
samplingParams.setMinSampleSize((int) (100 * retryFactor));
samplingParams.setMaxSampleSize((int) (10000 * retryFactor));
samplingParams.setOversampleFactor(5.0 * retryFactor);
samplingParams.setSampingThreshold(11000); //force sampling samplingParams.setSampingThreshold(11000); //force sampling
samplingParams.setOversampleFactor(5.0); Sampler sampler = useRandomSampler ?
Sampler sampler = new Sampler(samplingParams); new RandomSampler(samplingParams, new Random(random.nextLong())) :
new RepeatableSampler(samplingParams);
assertTrue("must enable sampling for this test!",sampler.shouldSample(scoredDocIDs)); assertTrue("must enable sampling for this test!",sampler.shouldSample(scoredDocIDs));
return sampler; return sampler;
} }