LUCENE-10444: Support alternate aggregation functions in association facets (#718)

Greg Miller authored 2022-04-06 14:51:06 -07:00; committed by GitHub
parent 9eeef080e5
commit f870edf2fe
17 changed files with 835 additions and 415 deletions
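At a glance: the sum-only association facet classes are replaced by constructors that take an AssociationAggregationFunction. A minimal sketch of the new entry points (assuming a taxoReader, config, and FacetsCollector fc like those in the demo code further down; the "$tags"/"$genre" index fields are the demo's):

Facets tagCounts =
    new TaxonomyFacetIntAssociations(
        "$tags", taxoReader, config, fc, AssociationAggregationFunction.SUM);
Facets genreMax =
    new TaxonomyFacetFloatAssociations(
        "$genre", taxoReader, config, fc, AssociationAggregationFunction.MAX);

SUM reproduces the behavior of the removed TaxonomyFacetSum*Associations classes; MAX is the newly supported alternative.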

View File

@ -76,6 +76,8 @@ New Features
* LUCENE-10456: Implement rewrite and Weight#count for MultiRangeQuery
by merging overlapping ranges. (Jianping Weng)
* LUCENE-10444: Support alternate aggregation functions in association facets. (Greg Miller)
Improvements
---------------------

View File

@ -26,10 +26,11 @@ import org.apache.lucene.facet.FacetResult;
import org.apache.lucene.facet.Facets;
import org.apache.lucene.facet.FacetsCollector;
import org.apache.lucene.facet.FacetsConfig;
import org.apache.lucene.facet.taxonomy.AssociationAggregationFunction;
import org.apache.lucene.facet.taxonomy.FloatAssociationFacetField;
import org.apache.lucene.facet.taxonomy.IntAssociationFacetField;
import org.apache.lucene.facet.taxonomy.TaxonomyFacetSumFloatAssociations;
import org.apache.lucene.facet.taxonomy.TaxonomyFacetSumIntAssociations;
import org.apache.lucene.facet.taxonomy.TaxonomyFacetFloatAssociations;
import org.apache.lucene.facet.taxonomy.TaxonomyFacetIntAssociations;
import org.apache.lucene.facet.taxonomy.TaxonomyReader;
import org.apache.lucene.facet.taxonomy.directory.DirectoryTaxonomyReader;
import org.apache.lucene.facet.taxonomy.directory.DirectoryTaxonomyWriter;
@ -102,8 +103,12 @@ public class AssociationsFacetsExample {
// you'd use a "normal" query:
FacetsCollector.search(searcher, new MatchAllDocsQuery(), 10, fc);
Facets tags = new TaxonomyFacetSumIntAssociations("$tags", taxoReader, config, fc);
Facets genre = new TaxonomyFacetSumFloatAssociations("$genre", taxoReader, config, fc);
Facets tags =
new TaxonomyFacetIntAssociations(
"$tags", taxoReader, config, fc, AssociationAggregationFunction.SUM);
Facets genre =
new TaxonomyFacetFloatAssociations(
"$genre", taxoReader, config, fc, AssociationAggregationFunction.SUM);
// Retrieve results
List<FacetResult> results = new ArrayList<>();
@ -132,7 +137,9 @@ public class AssociationsFacetsExample {
FacetsCollector.search(searcher, q, 10, fc);
// Retrieve results
Facets facets = new TaxonomyFacetSumFloatAssociations("$genre", taxoReader, config, fc);
Facets facets =
new TaxonomyFacetFloatAssociations(
"$genre", taxoReader, config, fc, AssociationAggregationFunction.SUM);
FacetResult result = facets.getTopChildren(10, "genre");
indexReader.close();

View File

@ -31,7 +31,8 @@ import org.apache.lucene.facet.FacetResult;
import org.apache.lucene.facet.Facets;
import org.apache.lucene.facet.FacetsCollector;
import org.apache.lucene.facet.FacetsConfig;
import org.apache.lucene.facet.taxonomy.TaxonomyFacetSumValueSource;
import org.apache.lucene.facet.taxonomy.AssociationAggregationFunction;
import org.apache.lucene.facet.taxonomy.TaxonomyFacetFloatAssociations;
import org.apache.lucene.facet.taxonomy.TaxonomyReader;
import org.apache.lucene.facet.taxonomy.directory.DirectoryTaxonomyReader;
import org.apache.lucene.facet.taxonomy.directory.DirectoryTaxonomyWriter;
@ -105,8 +106,12 @@ public class ExpressionAggregationFacetsExample {
// Retrieve results
Facets facets =
new TaxonomyFacetSumValueSource(
taxoReader, config, fc, expr.getDoubleValuesSource(bindings));
new TaxonomyFacetFloatAssociations(
taxoReader,
config,
fc,
AssociationAggregationFunction.SUM,
expr.getDoubleValuesSource(bindings));
FacetResult result = facets.getTopChildren(10, "A");
indexReader.close();

View File

@ -24,9 +24,8 @@
* <li>Taxonomy-based methods rely on a separate taxonomy index to map hierarchical facet paths to
* global int ordinals for fast counting at search time; these methods can compute counts
* ({@link org.apache.lucene.facet.taxonomy.FastTaxonomyFacetCounts}) or aggregate long or
* double values {@link org.apache.lucene.facet.taxonomy.TaxonomyFacetSumIntAssociations},
* {@link org.apache.lucene.facet.taxonomy.TaxonomyFacetSumFloatAssociations}, {@link
* org.apache.lucene.facet.taxonomy.TaxonomyFacetSumValueSource}. Add {@link
* double values {@link org.apache.lucene.facet.taxonomy.TaxonomyFacetIntAssociations}, {@link
* org.apache.lucene.facet.taxonomy.TaxonomyFacetFloatAssociations}. Add {@link
* org.apache.lucene.facet.FacetField} or {@link
* org.apache.lucene.facet.taxonomy.AssociationFacetField} to your documents at index time to
* use taxonomy-based methods.
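As a reminder of the index-time side mentioned above, association dims are populated roughly like this (a sketch, not from this commit: writer, taxoWriter, and the dim names are assumptions, and config must map each dim to its association index field as in the demo and tests below):

Document doc = new Document();
// The attached association value is what the taxonomy facet classes aggregate at search time.
doc.add(new IntAssociationFacetField(3, "tags", "lucene"));
doc.add(new FloatAssociationFacetField(0.87f, "genre", "computing"));
writer.addDocument(config.build(taxoWriter, doc));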

View File

@ -0,0 +1,66 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.facet.taxonomy;
/**
* Specify aggregation logic used in {@link TaxonomyFacetIntAssociations} and {@link
* TaxonomyFacetFloatAssociations}.
*/
public abstract class AssociationAggregationFunction {
// TODO: Would be nice to add support for MIN as well here, but there are a number of places
// in our facet implementations where we attribute special meaning to 0 and assume that valid
// values are always positive. I think we'd want to break that assumption for MIN to really
// make sense.
/** Sole constructor. */
protected AssociationAggregationFunction() {}
/** Implement aggregation logic for integers */
public abstract int aggregate(int existingVal, int newVal);
/** Implement aggregation logic for floats */
public abstract float aggregate(float existingVal, float newVal);
/** Aggregation that computes the maximum value */
public static final AssociationAggregationFunction MAX =
new AssociationAggregationFunction() {
@Override
public int aggregate(int existingVal, int newVal) {
return Math.max(existingVal, newVal);
}
@Override
public float aggregate(float existingVal, float newVal) {
return Math.max(existingVal, newVal);
}
};
/** Aggregation that computes the sum */
public static final AssociationAggregationFunction SUM =
new AssociationAggregationFunction() {
@Override
public int aggregate(int existingVal, int newVal) {
return existingVal + newVal;
}
@Override
public float aggregate(float existingVal, float newVal) {
return existingVal + newVal;
}
};
}
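Because this is an open extension point, a caller could in principle supply its own function. A hypothetical sketch, not part of this commit, which keeps the non-negative-value assumption noted in the TODO above (a sum that saturates at an arbitrary ceiling of 1000):

static final AssociationAggregationFunction CAPPED_SUM =
    new AssociationAggregationFunction() {
      @Override
      public int aggregate(int existingVal, int newVal) {
        // Sum, but never exceed the (made-up) ceiling:
        return Math.min(existingVal + newVal, 1000);
      }

      @Override
      public float aggregate(float existingVal, float newVal) {
        return Math.min(existingVal + newVal, 1000f);
      }
    };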

View File

@ -53,7 +53,7 @@ public class FastTaxonomyFacetCounts extends IntTaxonomyFacets {
public FastTaxonomyFacetCounts(
String indexFieldName, TaxonomyReader taxoReader, FacetsConfig config, FacetsCollector fc)
throws IOException {
super(indexFieldName, taxoReader, config, fc);
super(indexFieldName, taxoReader, config, AssociationAggregationFunction.SUM, fc);
count(fc.getMatchingDocs());
}
@ -65,7 +65,7 @@ public class FastTaxonomyFacetCounts extends IntTaxonomyFacets {
public FastTaxonomyFacetCounts(
String indexFieldName, IndexReader reader, TaxonomyReader taxoReader, FacetsConfig config)
throws IOException {
super(indexFieldName, taxoReader, config, null);
super(indexFieldName, taxoReader, config, AssociationAggregationFunction.SUM, null);
countAll(reader);
}

View File

@ -22,7 +22,7 @@ import org.apache.lucene.util.BytesRef;
/**
* Add an instance of this to your {@link Document} to add a facet label associated with a float.
* Use {@link TaxonomyFacetSumFloatAssociations} to aggregate float values per facet label at search
* Use {@link TaxonomyFacetFloatAssociations} to aggregate float values per facet label at search
* time.
*
* @lucene.experimental

View File

@ -29,13 +29,21 @@ abstract class FloatTaxonomyFacets extends TaxonomyFacets {
// TODO: also use native hash map for sparse collection, like IntTaxonomyFacets
/** Aggregation function used for combining values. */
final AssociationAggregationFunction aggregationFunction;
/** Per-ordinal value. */
final float[] values;
/** Sole constructor. */
FloatTaxonomyFacets(String indexFieldName, TaxonomyReader taxoReader, FacetsConfig config)
FloatTaxonomyFacets(
String indexFieldName,
TaxonomyReader taxoReader,
AssociationAggregationFunction aggregationFunction,
FacetsConfig config)
throws IOException {
super(indexFieldName, taxoReader, config);
this.aggregationFunction = aggregationFunction;
values = new float[taxoReader.getSize()];
}
@ -49,7 +57,9 @@ abstract class FloatTaxonomyFacets extends TaxonomyFacets {
if (ft.hierarchical && ft.multiValued == false) {
int dimRootOrd = taxoReader.getOrdinal(new FacetLabel(dim));
assert dimRootOrd > 0;
values[dimRootOrd] += rollup(children[dimRootOrd]);
float newValue =
aggregationFunction.aggregate(values[dimRootOrd], rollup(children[dimRootOrd]));
values[dimRootOrd] = newValue;
}
}
}
@ -57,14 +67,14 @@ abstract class FloatTaxonomyFacets extends TaxonomyFacets {
private float rollup(int ord) throws IOException {
int[] children = getChildren();
int[] siblings = getSiblings();
float sum = 0;
float aggregationValue = 0f;
while (ord != TaxonomyReader.INVALID_ORDINAL) {
float childValue = values[ord] + rollup(children[ord]);
float childValue = aggregationFunction.aggregate(values[ord], rollup(children[ord]));
values[ord] = childValue;
sum += childValue;
aggregationValue = aggregationFunction.aggregate(aggregationValue, childValue);
ord = siblings[ord];
}
return sum;
return aggregationValue;
}
@Override
@ -104,13 +114,13 @@ abstract class FloatTaxonomyFacets extends TaxonomyFacets {
int[] siblings = getSiblings();
int ord = children[dimOrd];
float sumValues = 0;
float aggregatedValue = 0;
int childCount = 0;
TopOrdAndFloatQueue.OrdAndValue reuse = null;
while (ord != TaxonomyReader.INVALID_ORDINAL) {
if (values[ord] > 0) {
sumValues += values[ord];
aggregatedValue = aggregationFunction.aggregate(aggregatedValue, values[ord]);
childCount++;
if (values[ord] > bottomValue) {
if (reuse == null) {
@ -128,16 +138,16 @@ abstract class FloatTaxonomyFacets extends TaxonomyFacets {
ord = siblings[ord];
}
if (sumValues == 0) {
if (aggregatedValue == 0) {
return null;
}
if (dimConfig.multiValued) {
if (dimConfig.requireDimCount) {
sumValues = values[dimOrd];
aggregatedValue = values[dimOrd];
} else {
// Our sum'd count is not correct, in general:
sumValues = -1;
aggregatedValue = -1;
}
} else {
// Our sum'd dim count is accurate, so we keep it
@ -158,6 +168,6 @@ abstract class FloatTaxonomyFacets extends TaxonomyFacets {
labelValues[i] = new LabelAndValue(bulkPath[i].components[cp.length], values[i]);
}
return new FacetResult(dim, path, sumValues, labelValues, childCount);
return new FacetResult(dim, path, aggregatedValue, labelValues, childCount);
}
}
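The reworked rollup() above folds each child's (already rolled-up) value into its parent with the configured function instead of a hard-coded sum. A standalone illustration of that fold, with made-up values:

float sum = 0f;
float max = 0f;
for (float v : new float[] {10f, 40f, 45f, 50f}) {
  sum = AssociationAggregationFunction.SUM.aggregate(sum, v); // 145.0 after the loop
  max = AssociationAggregationFunction.MAX.aggregate(max, v); // 50.0 after the loop
}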

View File

@ -23,7 +23,7 @@ import org.apache.lucene.util.BytesRef;
/**
* Add an instance of this to your {@link Document} to add a facet label associated with an int. Use
* {@link TaxonomyFacetSumIntAssociations} to aggregate int values per facet label at search time.
* {@link TaxonomyFacetIntAssociations} to aggregate int values per facet label at search time.
*
* @lucene.experimental
*/

View File

@ -31,6 +31,9 @@ import org.apache.lucene.facet.TopOrdAndIntQueue;
/** Base class for all taxonomy-based facets that aggregate to a per-ords int[]. */
abstract class IntTaxonomyFacets extends TaxonomyFacets {
/** Aggregation function used for combining values. */
final AssociationAggregationFunction aggregationFunction;
/** Dense ordinal values. */
final int[] values;
@ -39,9 +42,14 @@ abstract class IntTaxonomyFacets extends TaxonomyFacets {
/** Sole constructor. */
IntTaxonomyFacets(
String indexFieldName, TaxonomyReader taxoReader, FacetsConfig config, FacetsCollector fc)
String indexFieldName,
TaxonomyReader taxoReader,
FacetsConfig config,
AssociationAggregationFunction aggregationFunction,
FacetsCollector fc)
throws IOException {
super(indexFieldName, taxoReader, config);
this.aggregationFunction = aggregationFunction;
if (useHashTable(fc, taxoReader)) {
sparseValues = new IntIntHashMap();
@ -52,12 +60,12 @@ abstract class IntTaxonomyFacets extends TaxonomyFacets {
}
}
/** Increment the count for this ordinal by {@code amount}. */
void increment(int ordinal, int amount) {
/** Set the count for this ordinal to {@code newValue}. */
void setValue(int ordinal, int newValue) {
if (sparseValues != null) {
sparseValues.addTo(ordinal, amount);
sparseValues.put(ordinal, newValue);
} else {
values[ordinal] += amount;
values[ordinal] = newValue;
}
}
@ -86,7 +94,9 @@ abstract class IntTaxonomyFacets extends TaxonomyFacets {
// lazy init
children = getChildren();
}
increment(dimRootOrd, rollup(children[dimRootOrd]));
int currentValue = getValue(dimRootOrd);
int newValue = aggregationFunction.aggregate(currentValue, rollup(children[dimRootOrd]));
setValue(dimRootOrd, newValue);
}
}
}
@ -95,13 +105,15 @@ abstract class IntTaxonomyFacets extends TaxonomyFacets {
private int rollup(int ord) throws IOException {
int[] children = getChildren();
int[] siblings = getSiblings();
int sum = 0;
int aggregatedValue = 0;
while (ord != TaxonomyReader.INVALID_ORDINAL) {
increment(ord, rollup(children[ord]));
sum += getValue(ord);
int currentValue = getValue(ord);
int newValue = aggregationFunction.aggregate(currentValue, rollup(children[ord]));
setValue(ord, newValue);
aggregatedValue = aggregationFunction.aggregate(aggregatedValue, getValue(ord));
ord = siblings[ord];
}
return sum;
return aggregatedValue;
}
/** Return true if a sparse hash table should be used for counting, instead of a dense int[]. */
@ -161,7 +173,7 @@ abstract class IntTaxonomyFacets extends TaxonomyFacets {
int bottomValue = 0;
int totValue = 0;
int aggregatedValue = 0;
int childCount = 0;
TopOrdAndIntQueue.OrdAndValue reuse = null;
@ -171,17 +183,17 @@ abstract class IntTaxonomyFacets extends TaxonomyFacets {
if (sparseValues != null) {
for (IntIntCursor c : sparseValues) {
int count = c.value;
int value = c.value;
int ord = c.key;
if (parents[ord] == dimOrd && count > 0) {
totValue += count;
if (parents[ord] == dimOrd && value > 0) {
aggregatedValue = aggregationFunction.aggregate(aggregatedValue, value);
childCount++;
if (count > bottomValue) {
if (value > bottomValue) {
if (reuse == null) {
reuse = new TopOrdAndIntQueue.OrdAndValue();
}
reuse.ord = ord;
reuse.value = count;
reuse.value = value;
reuse = q.insertWithOverflow(reuse);
if (q.size() == topN) {
bottomValue = q.top().value;
@ -196,7 +208,7 @@ abstract class IntTaxonomyFacets extends TaxonomyFacets {
while (ord != TaxonomyReader.INVALID_ORDINAL) {
int value = values[ord];
if (value > 0) {
totValue += value;
aggregatedValue = aggregationFunction.aggregate(aggregatedValue, value);
childCount++;
if (value > bottomValue) {
if (reuse == null) {
@ -215,16 +227,16 @@ abstract class IntTaxonomyFacets extends TaxonomyFacets {
}
}
if (totValue == 0) {
if (aggregatedValue == 0) {
return null;
}
if (dimConfig.multiValued) {
if (dimConfig.requireDimCount) {
totValue = getValue(dimOrd);
aggregatedValue = getValue(dimOrd);
} else {
// Our sum'd value is not correct, in general:
totValue = -1;
aggregatedValue = -1;
}
} else {
// Our sum'd dim value is accurate, so we keep it
@ -245,6 +257,6 @@ abstract class IntTaxonomyFacets extends TaxonomyFacets {
labelValues[i] = new LabelAndValue(bulkPath[i].components[cp.length], values[i]);
}
return new FacetResult(dim, path, totValue, labelValues, childCount);
return new FacetResult(dim, path, aggregatedValue, labelValues, childCount);
}
}

View File

@ -0,0 +1,187 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.facet.taxonomy;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import org.apache.lucene.facet.FacetField;
import org.apache.lucene.facet.FacetsCollector;
import org.apache.lucene.facet.FacetsCollector.MatchingDocs;
import org.apache.lucene.facet.FacetsConfig;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.SortedNumericDocValues;
import org.apache.lucene.search.ConjunctionUtils;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.DoubleValues;
import org.apache.lucene.search.DoubleValuesSource;
import org.apache.lucene.util.BitUtil;
import org.apache.lucene.util.BytesRef;
/**
* Aggregates float values associated with facet fields. Supports two different approaches:
*
* <ol>
* <li>Fields can be indexed with {@link FloatAssociationFacetField}, associating weights with
* facet values at indexing time.
* <li>Fields can be indexed with {@link FacetField} and a {@link DoubleValuesSource} can
* dynamically supply a weight from each doc. With this approach, the document's weight gets
* contributed to each facet value associated with the doc.
* </ol>
*
* Aggregation logic is supplied by the provided {@link AssociationAggregationFunction}.
*
* @lucene.experimental
*/
public class TaxonomyFacetFloatAssociations extends FloatTaxonomyFacets {
/** Create {@code TaxonomyFacetFloatAssociations} against the default index field. */
public TaxonomyFacetFloatAssociations(
TaxonomyReader taxoReader,
FacetsConfig config,
FacetsCollector fc,
AssociationAggregationFunction aggregationFunction)
throws IOException {
this(FacetsConfig.DEFAULT_INDEX_FIELD_NAME, taxoReader, config, fc, aggregationFunction);
}
/**
* Create {@code TaxonomyFacetFloatAssociations} against the default index field. Sources values
* from the provided {@code valuesSource}.
*/
public TaxonomyFacetFloatAssociations(
TaxonomyReader taxoReader,
FacetsConfig config,
FacetsCollector fc,
AssociationAggregationFunction aggregationFunction,
DoubleValuesSource valuesSource)
throws IOException {
this(
FacetsConfig.DEFAULT_INDEX_FIELD_NAME,
taxoReader,
config,
fc,
aggregationFunction,
valuesSource);
}
/** Create {@code TaxonomyFacetFloatAssociations} against the specified index field. */
public TaxonomyFacetFloatAssociations(
String indexFieldName,
TaxonomyReader taxoReader,
FacetsConfig config,
FacetsCollector fc,
AssociationAggregationFunction aggregationFunction)
throws IOException {
super(indexFieldName, taxoReader, aggregationFunction, config);
aggregateValues(aggregationFunction, fc.getMatchingDocs());
}
/**
* Create {@code TaxonomyFacetFloatAssociations} against the specified index field. Sources values
* from the provided {@code valuesSource}.
*/
public TaxonomyFacetFloatAssociations(
String indexFieldName,
TaxonomyReader taxoReader,
FacetsConfig config,
FacetsCollector fc,
AssociationAggregationFunction aggregationFunction,
DoubleValuesSource valuesSource)
throws IOException {
super(indexFieldName, taxoReader, aggregationFunction, config);
aggregateValues(aggregationFunction, fc.getMatchingDocs(), fc.getKeepScores(), valuesSource);
}
private static DoubleValues scores(MatchingDocs hits) {
return new DoubleValues() {
int index = -1;
@Override
public double doubleValue() throws IOException {
return hits.scores[index];
}
@Override
public boolean advanceExact(int doc) throws IOException {
index = doc;
return true;
}
};
}
/** Aggregate using the provided {@code DoubleValuesSource}. */
private void aggregateValues(
AssociationAggregationFunction aggregationFunction,
List<MatchingDocs> matchingDocs,
boolean keepScores,
DoubleValuesSource valueSource)
throws IOException {
for (MatchingDocs hits : matchingDocs) {
SortedNumericDocValues ordinalValues =
DocValues.getSortedNumeric(hits.context.reader(), indexFieldName);
DoubleValues scores = keepScores ? scores(hits) : null;
DoubleValues functionValues = valueSource.getValues(hits.context, scores);
DocIdSetIterator it =
ConjunctionUtils.intersectIterators(List.of(hits.bits.iterator(), ordinalValues));
for (int doc = it.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = it.nextDoc()) {
if (functionValues.advanceExact(doc)) {
float value = (float) functionValues.doubleValue();
int ordinalCount = ordinalValues.docValueCount();
for (int i = 0; i < ordinalCount; i++) {
int ord = (int) ordinalValues.nextValue();
float newValue = aggregationFunction.aggregate(values[ord], value);
values[ord] = newValue;
}
}
}
}
// Hierarchical dimensions are supported when using a value source, so we need to rollup:
rollup();
}
/** Aggregate from indexed association values. */
private void aggregateValues(
AssociationAggregationFunction aggregationFunction, List<MatchingDocs> matchingDocs)
throws IOException {
for (MatchingDocs hits : matchingDocs) {
BinaryDocValues dv = DocValues.getBinary(hits.context.reader(), indexFieldName);
DocIdSetIterator it =
ConjunctionUtils.intersectIterators(Arrays.asList(hits.bits.iterator(), dv));
for (int doc = it.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = it.nextDoc()) {
final BytesRef bytesRef = dv.binaryValue();
byte[] bytes = bytesRef.bytes;
int end = bytesRef.offset + bytesRef.length;
int offset = bytesRef.offset;
while (offset < end) {
int ord = (int) BitUtil.VH_BE_INT.get(bytes, offset);
offset += 4;
float value = (float) BitUtil.VH_BE_FLOAT.get(bytes, offset);
offset += 4;
float newValue = aggregationFunction.aggregate(values[ord], value);
values[ord] = newValue;
}
}
}
}
}
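A usage sketch for the value-source flavor described in the class javadoc above (approach 2): facets indexed with plain FacetField, per-doc weights pulled from a numeric doc-values field at search time. The "price" field and "category" dim are illustrative assumptions, not part of this commit:

FacetsCollector fc = searcher.search(new MatchAllDocsQuery(), new FacetsCollectorManager());
Facets priceMax =
    new TaxonomyFacetFloatAssociations(
        taxoReader,
        config,
        fc,
        AssociationAggregationFunction.MAX,
        DoubleValuesSource.fromLongField("price"));
FacetResult result = priceMax.getTopChildren(10, "category");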

View File

@ -0,0 +1,87 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.facet.taxonomy;
import java.io.IOException;
import java.util.List;
import org.apache.lucene.facet.FacetsCollector;
import org.apache.lucene.facet.FacetsCollector.MatchingDocs;
import org.apache.lucene.facet.FacetsConfig;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.search.ConjunctionUtils;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.BitUtil;
import org.apache.lucene.util.BytesRef;
/**
* Aggregates int values previously indexed with {@link IntAssociationFacetField}, assuming the
* default encoding. The aggregation function is defined by a provided {@link
* AssociationAggregationFunction}.
*
* @lucene.experimental
*/
public class TaxonomyFacetIntAssociations extends IntTaxonomyFacets {
/** Create {@code TaxonomyFacetIntAssociations} against the default index field. */
public TaxonomyFacetIntAssociations(
TaxonomyReader taxoReader,
FacetsConfig config,
FacetsCollector fc,
AssociationAggregationFunction aggregationFunction)
throws IOException {
this(FacetsConfig.DEFAULT_INDEX_FIELD_NAME, taxoReader, config, fc, aggregationFunction);
}
/** Create {@code TaxonomyFacetIntAssociations} against the specified index field. */
public TaxonomyFacetIntAssociations(
String indexFieldName,
TaxonomyReader taxoReader,
FacetsConfig config,
FacetsCollector fc,
AssociationAggregationFunction aggregationFunction)
throws IOException {
super(indexFieldName, taxoReader, config, aggregationFunction, fc);
aggregateValues(aggregationFunction, fc.getMatchingDocs());
}
private void aggregateValues(
AssociationAggregationFunction aggregationFunction, List<MatchingDocs> matchingDocs)
throws IOException {
for (MatchingDocs hits : matchingDocs) {
BinaryDocValues dv = DocValues.getBinary(hits.context.reader(), indexFieldName);
DocIdSetIterator it = ConjunctionUtils.intersectIterators(List.of(hits.bits.iterator(), dv));
for (int doc = it.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = it.nextDoc()) {
final BytesRef bytesRef = dv.binaryValue();
byte[] bytes = bytesRef.bytes;
int end = bytesRef.offset + bytesRef.length;
int offset = bytesRef.offset;
while (offset < end) {
int ord = (int) BitUtil.VH_BE_INT.get(bytes, offset);
offset += 4;
int value = (int) BitUtil.VH_BE_INT.get(bytes, offset);
offset += 4;
// TODO: Can we optimize the null check in setValue? See LUCENE-10373.
int currentValue = getValue(ord);
int newValue = aggregationFunction.aggregate(currentValue, value);
setValue(ord, newValue);
}
}
}
}
}
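And the int-association counterpart, reading back one label's aggregated value (dim and index-field names are borrowed from the test further down; fc is a FacetsCollector populated as usual):

Facets intMax =
    new TaxonomyFacetIntAssociations(
        "$facets.int", taxoReader, config, fc, AssociationAggregationFunction.MAX);
// Largest int association seen for label "a" under dim "int_random" across the matching docs:
Number maxForA = intMax.getSpecificValue("int_random", "a");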

View File

@ -1,82 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.facet.taxonomy;
import java.io.IOException;
import java.util.List;
import org.apache.lucene.facet.FacetsCollector;
import org.apache.lucene.facet.FacetsCollector.MatchingDocs;
import org.apache.lucene.facet.FacetsConfig;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.BitUtil;
import org.apache.lucene.util.BytesRef;
/**
* Aggregates sum of float values previously indexed with {@link FloatAssociationFacetField}, assuming
* the default encoding.
*
* @lucene.experimental
*/
public class TaxonomyFacetSumFloatAssociations extends FloatTaxonomyFacets {
/** Create {@code TaxonomyFacetSumFloatAssociations} against the default index field. */
public TaxonomyFacetSumFloatAssociations(
TaxonomyReader taxoReader, FacetsConfig config, FacetsCollector fc) throws IOException {
this(FacetsConfig.DEFAULT_INDEX_FIELD_NAME, taxoReader, config, fc);
}
/** Create {@code TaxonomyFacetSumFloatAssociations} against the specified index field. */
public TaxonomyFacetSumFloatAssociations(
String indexFieldName, TaxonomyReader taxoReader, FacetsConfig config, FacetsCollector fc)
throws IOException {
super(indexFieldName, taxoReader, config);
sumValues(fc.getMatchingDocs());
}
private final void sumValues(List<MatchingDocs> matchingDocs) throws IOException {
// System.out.println("count matchingDocs=" + matchingDocs + " facetsField=" + facetsFieldName);
for (MatchingDocs hits : matchingDocs) {
BinaryDocValues dv = hits.context.reader().getBinaryDocValues(indexFieldName);
if (dv == null) { // this reader does not have DocValues for the requested category list
continue;
}
DocIdSetIterator docs = hits.bits.iterator();
int doc;
while ((doc = docs.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
if (dv.docID() < doc) {
dv.advance(doc);
}
if (dv.docID() == doc) {
final BytesRef bytesRef = dv.binaryValue();
byte[] bytes = bytesRef.bytes;
int end = bytesRef.offset + bytesRef.length;
int offset = bytesRef.offset;
while (offset < end) {
int ord = (int) BitUtil.VH_BE_INT.get(bytes, offset);
offset += 4;
float value = (float) BitUtil.VH_BE_FLOAT.get(bytes, offset);
offset += 4;
values[ord] += value;
}
}
}
}
}
}

View File

@ -1,83 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.facet.taxonomy;
import java.io.IOException;
import java.util.List;
import org.apache.lucene.facet.FacetsCollector;
import org.apache.lucene.facet.FacetsCollector.MatchingDocs;
import org.apache.lucene.facet.FacetsConfig;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.BitUtil;
import org.apache.lucene.util.BytesRef;
/**
* Aggregates sum of int values previously indexed with {@link IntAssociationFacetField}, assuming
* the default encoding.
*
* @lucene.experimental
*/
public class TaxonomyFacetSumIntAssociations extends IntTaxonomyFacets {
/** Create {@code TaxonomyFacetSumIntAssociations} against the default index field. */
public TaxonomyFacetSumIntAssociations(
TaxonomyReader taxoReader, FacetsConfig config, FacetsCollector fc) throws IOException {
this(FacetsConfig.DEFAULT_INDEX_FIELD_NAME, taxoReader, config, fc);
}
/** Create {@code TaxonomyFacetSumIntAssociations} against the specified index field. */
public TaxonomyFacetSumIntAssociations(
String indexFieldName, TaxonomyReader taxoReader, FacetsConfig config, FacetsCollector fc)
throws IOException {
super(indexFieldName, taxoReader, config, fc);
sumValues(fc.getMatchingDocs());
}
private final void sumValues(List<MatchingDocs> matchingDocs) throws IOException {
// System.out.println("count matchingDocs=" + matchingDocs + " facetsField=" + facetsFieldName);
for (MatchingDocs hits : matchingDocs) {
BinaryDocValues dv = hits.context.reader().getBinaryDocValues(indexFieldName);
if (dv == null) { // this reader does not have DocValues for the requested category list
continue;
}
DocIdSetIterator docs = hits.bits.iterator();
int doc;
while ((doc = docs.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
if (dv.docID() < doc) {
dv.advance(doc);
}
if (dv.docID() == doc) {
final BytesRef bytesRef = dv.binaryValue();
byte[] bytes = bytesRef.bytes;
int end = bytesRef.offset + bytesRef.length;
int offset = bytesRef.offset;
while (offset < end) {
int ord = (int) BitUtil.VH_BE_INT.get(bytes, offset);
offset += 4;
int value = (int) BitUtil.VH_BE_INT.get(bytes, offset);
offset += 4;
// TODO: Can we optimize the null check in increment? See LUCENE-10373.
increment(ord, value);
}
}
}
}
}
}

View File

@ -1,109 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.facet.taxonomy;
import java.io.IOException;
import java.util.List;
import org.apache.lucene.facet.FacetsCollector;
import org.apache.lucene.facet.FacetsCollector.MatchingDocs;
import org.apache.lucene.facet.FacetsConfig;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.SortedNumericDocValues;
import org.apache.lucene.search.ConjunctionUtils;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.DoubleValues;
import org.apache.lucene.search.DoubleValuesSource;
/**
* Aggregates sum of values from {@link DoubleValues#doubleValue()}, for each facet label.
*
* @lucene.experimental
*/
public class TaxonomyFacetSumValueSource extends FloatTaxonomyFacets {
/**
* Aggregates double facet values from the provided {@link DoubleValuesSource}, pulling ordinals
* from the default indexed facet field {@link FacetsConfig#DEFAULT_INDEX_FIELD_NAME}.
*/
public TaxonomyFacetSumValueSource(
TaxonomyReader taxoReader,
FacetsConfig config,
FacetsCollector fc,
DoubleValuesSource valueSource)
throws IOException {
this(FacetsConfig.DEFAULT_INDEX_FIELD_NAME, taxoReader, config, fc, valueSource);
}
/**
* Aggregates double facet values from the provided {@link DoubleValuesSource}, pulling ordinals
* from the specified indexed facet field.
*/
public TaxonomyFacetSumValueSource(
String indexField,
TaxonomyReader taxoReader,
FacetsConfig config,
FacetsCollector fc,
DoubleValuesSource valueSource)
throws IOException {
super(indexField, taxoReader, config);
sumValues(fc.getMatchingDocs(), fc.getKeepScores(), valueSource);
}
private static DoubleValues scores(MatchingDocs hits) {
return new DoubleValues() {
int index = -1;
@Override
public double doubleValue() throws IOException {
return hits.scores[index];
}
@Override
public boolean advanceExact(int doc) throws IOException {
index = doc;
return true;
}
};
}
private void sumValues(
List<MatchingDocs> matchingDocs, boolean keepScores, DoubleValuesSource valueSource)
throws IOException {
for (MatchingDocs hits : matchingDocs) {
SortedNumericDocValues ordinalValues =
DocValues.getSortedNumeric(hits.context.reader(), indexFieldName);
DoubleValues scores = keepScores ? scores(hits) : null;
DoubleValues functionValues = valueSource.getValues(hits.context, scores);
DocIdSetIterator it =
ConjunctionUtils.intersectIterators(List.of(hits.bits.iterator(), ordinalValues));
for (int doc = it.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = it.nextDoc()) {
if (functionValues.advanceExact(doc)) {
float value = (float) functionValues.doubleValue();
int ordinalCount = ordinalValues.docValueCount();
for (int i = 0; i < ordinalCount; i++) {
values[(int) ordinalValues.nextValue()] += value;
}
}
}
}
rollup();
}
}

View File

@ -16,8 +16,14 @@
*/
package org.apache.lucene.facet.taxonomy;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.lucene.document.Document;
import org.apache.lucene.facet.DrillDownQuery;
import org.apache.lucene.facet.FacetResult;
import org.apache.lucene.facet.FacetTestCase;
import org.apache.lucene.facet.Facets;
import org.apache.lucene.facet.FacetsCollector;
@ -44,6 +50,11 @@ public class TestTaxonomyFacetAssociations extends FacetTestCase {
private static FacetsConfig config;
private static Map<String, List<Integer>> randomIntValues;
private static Map<String, List<Float>> randomFloatValues;
private static Map<String, List<Integer>> randomIntSingleValued;
private static Map<String, List<Float>> randomFloatSingleValued;
@BeforeClass
public static void beforeClass() throws Exception {
dir = newDirectory();
@ -56,8 +67,14 @@ public class TestTaxonomyFacetAssociations extends FacetTestCase {
config = new FacetsConfig();
config.setIndexFieldName("int", "$facets.int");
config.setMultiValued("int", true);
config.setIndexFieldName("int_random", "$facets.int");
config.setMultiValued("int_random", true);
config.setIndexFieldName("int_single_valued", "$facets.int");
config.setIndexFieldName("float", "$facets.float");
config.setMultiValued("float", true);
config.setIndexFieldName("float_random", "$facets.float");
config.setMultiValued("float_random", true);
config.setIndexFieldName("float_single_valued", "$facets.float");
RandomIndexWriter writer = new RandomIndexWriter(random(), dir);
@ -77,6 +94,49 @@ public class TestTaxonomyFacetAssociations extends FacetTestCase {
writer.addDocument(config.build(taxoWriter, doc));
}
// Also index random content for more random testing:
String[] paths = new String[] {"a", "b", "c"};
int count = random().nextInt(1000);
randomIntValues = new HashMap<>();
randomFloatValues = new HashMap<>();
randomIntSingleValued = new HashMap<>();
randomFloatSingleValued = new HashMap<>();
for (int i = 0; i < count; i++) {
Document doc = new Document();
if (random().nextInt(10) >= 2) { // occasionally don't add any fields
// Add up to five ordinals + values for each doc. Note that duplicates are totally fine:
for (int j = 0; j < 5; j++) {
String path = paths[random().nextInt(3)];
if (random().nextBoolean()) { // maybe index an int association with the dim
int nextInt = atLeast(1);
randomIntValues.computeIfAbsent(path, k -> new ArrayList<>()).add(nextInt);
doc.add(new IntAssociationFacetField(nextInt, "int_random", path));
}
if (random().nextBoolean()) { // maybe index a float association with the dim
float nextFloat = random().nextFloat() * 10000f;
randomFloatValues.computeIfAbsent(path, k -> new ArrayList<>()).add(nextFloat);
doc.add(new FloatAssociationFacetField(nextFloat, "float_random", path));
}
}
// Also, (maybe) add to the single-valued association fields:
String path = paths[random().nextInt(3)];
if (random().nextBoolean()) {
int nextInt = atLeast(1);
randomIntSingleValued.computeIfAbsent(path, k -> new ArrayList<>()).add(nextInt);
doc.add(new IntAssociationFacetField(nextInt, "int_single_valued", path));
}
if (random().nextBoolean()) {
float nextFloat = random().nextFloat() * 10000f;
randomFloatSingleValued.computeIfAbsent(path, k -> new ArrayList<>()).add(nextFloat);
doc.add(new FloatAssociationFacetField(nextFloat, "float_single_valued", path));
}
}
writer.addDocument(config.build(taxoWriter, doc));
}
taxoWriter.close();
reader = writer.getReader();
writer.close();
@ -100,7 +160,9 @@ public class TestTaxonomyFacetAssociations extends FacetTestCase {
IndexSearcher searcher = newSearcher(reader);
FacetsCollector fc = searcher.search(new MatchAllDocsQuery(), new FacetsCollectorManager());
Facets facets = new TaxonomyFacetSumIntAssociations("$facets.int", taxoReader, config, fc);
Facets facets =
new TaxonomyFacetIntAssociations(
"$facets.int", taxoReader, config, fc, AssociationAggregationFunction.SUM);
assertEquals(
"dim=int path=[] value=-1 childCount=2\n a (200)\n b (150)\n",
facets.getTopChildren(10, "int").toString());
@ -110,11 +172,54 @@ public class TestTaxonomyFacetAssociations extends FacetTestCase {
"Wrong count for category 'b'!", 150, facets.getSpecificValue("int", "b").intValue());
}
public void testIntAssociationRandom() throws Exception {
FacetsCollector fc = new FacetsCollector();
IndexSearcher searcher = newSearcher(reader);
searcher.search(new MatchAllDocsQuery(), fc);
Map<String, Integer> expected;
Facets facets;
// SUM:
facets =
new TaxonomyFacetIntAssociations(
"$facets.int", taxoReader, config, fc, AssociationAggregationFunction.SUM);
expected = new HashMap<>();
for (Map.Entry<String, List<Integer>> e : randomIntValues.entrySet()) {
expected.put(e.getKey(), e.getValue().stream().reduce(Integer::sum).orElse(0));
}
validateInts("int_random", expected, AssociationAggregationFunction.SUM, true, facets);
expected = new HashMap<>();
for (Map.Entry<String, List<Integer>> e : randomIntSingleValued.entrySet()) {
expected.put(e.getKey(), e.getValue().stream().reduce(Integer::sum).orElse(0));
}
validateInts("int_single_valued", expected, AssociationAggregationFunction.SUM, false, facets);
// MAX:
facets =
new TaxonomyFacetIntAssociations(
"$facets.int", taxoReader, config, fc, AssociationAggregationFunction.MAX);
expected = new HashMap<>();
for (Map.Entry<String, List<Integer>> e : randomIntValues.entrySet()) {
expected.put(e.getKey(), e.getValue().stream().max(Integer::compareTo).orElse(0));
}
validateInts("int_random", expected, AssociationAggregationFunction.MAX, true, facets);
expected = new HashMap<>();
for (Map.Entry<String, List<Integer>> e : randomIntSingleValued.entrySet()) {
expected.put(e.getKey(), e.getValue().stream().max(Integer::compareTo).orElse(0));
}
validateInts("int_single_valued", expected, AssociationAggregationFunction.MAX, false, facets);
}
public void testFloatSumAssociation() throws Exception {
IndexSearcher searcher = newSearcher(reader);
FacetsCollector fc = searcher.search(new MatchAllDocsQuery(), new FacetsCollectorManager());
Facets facets = new TaxonomyFacetSumFloatAssociations("$facets.float", taxoReader, config, fc);
Facets facets =
new TaxonomyFacetFloatAssociations(
"$facets.float", taxoReader, config, fc, AssociationAggregationFunction.SUM);
assertEquals(
"dim=float path=[] value=-1.0 childCount=2\n a (50.0)\n b (9.999995)\n",
facets.getTopChildren(10, "float").toString());
@ -130,6 +235,49 @@ public class TestTaxonomyFacetAssociations extends FacetTestCase {
0.00001);
}
public void testFloatAssociationRandom() throws Exception {
FacetsCollector fc = new FacetsCollector();
IndexSearcher searcher = newSearcher(reader);
searcher.search(new MatchAllDocsQuery(), fc);
Map<String, Float> expected;
Facets facets;
// SUM:
facets =
new TaxonomyFacetFloatAssociations(
"$facets.float", taxoReader, config, fc, AssociationAggregationFunction.SUM);
expected = new HashMap<>();
for (Map.Entry<String, List<Float>> e : randomFloatValues.entrySet()) {
expected.put(e.getKey(), e.getValue().stream().reduce(Float::sum).orElse(0f));
}
validateFloats("float_random", expected, AssociationAggregationFunction.SUM, true, facets);
expected = new HashMap<>();
for (Map.Entry<String, List<Float>> e : randomFloatSingleValued.entrySet()) {
expected.put(e.getKey(), e.getValue().stream().reduce(Float::sum).orElse(0f));
}
validateFloats(
"float_single_valued", expected, AssociationAggregationFunction.SUM, false, facets);
// MAX:
facets =
new TaxonomyFacetFloatAssociations(
"$facets.float", taxoReader, config, fc, AssociationAggregationFunction.MAX);
expected = new HashMap<>();
for (Map.Entry<String, List<Float>> e : randomFloatValues.entrySet()) {
expected.put(e.getKey(), e.getValue().stream().max(Float::compareTo).orElse(0f));
}
validateFloats("float_random", expected, AssociationAggregationFunction.MAX, true, facets);
expected = new HashMap<>();
for (Map.Entry<String, List<Float>> e : randomFloatSingleValued.entrySet()) {
expected.put(e.getKey(), e.getValue().stream().max(Float::compareTo).orElse(0f));
}
validateFloats(
"float_single_valued", expected, AssociationAggregationFunction.MAX, false, facets);
}
/**
* Make sure we can test both int and float assocs in one index, as long as we send each to a
* different field.
@ -138,7 +286,9 @@ public class TestTaxonomyFacetAssociations extends FacetTestCase {
IndexSearcher searcher = newSearcher(reader);
FacetsCollector fc = searcher.search(new MatchAllDocsQuery(), new FacetsCollectorManager());
Facets facets = new TaxonomyFacetSumFloatAssociations("$facets.float", taxoReader, config, fc);
Facets facets =
new TaxonomyFacetFloatAssociations(
"$facets.float", taxoReader, config, fc, AssociationAggregationFunction.SUM);
assertEquals(
"Wrong count for category 'a'!",
50f,
@ -150,7 +300,9 @@ public class TestTaxonomyFacetAssociations extends FacetTestCase {
facets.getSpecificValue("float", "b").floatValue(),
0.00001);
facets = new TaxonomyFacetSumIntAssociations("$facets.int", taxoReader, config, fc);
facets =
new TaxonomyFacetIntAssociations(
"$facets.int", taxoReader, config, fc, AssociationAggregationFunction.SUM);
assertEquals(
"Wrong count for category 'a'!", 200, facets.getSpecificValue("int", "a").intValue());
assertEquals(
@ -160,7 +312,9 @@ public class TestTaxonomyFacetAssociations extends FacetTestCase {
public void testWrongIndexFieldName() throws Exception {
IndexSearcher searcher = newSearcher(reader);
FacetsCollector fc = searcher.search(new MatchAllDocsQuery(), new FacetsCollectorManager());
Facets facets = new TaxonomyFacetSumFloatAssociations(taxoReader, config, fc);
Facets facets =
new TaxonomyFacetFloatAssociations(
"wrong_field", taxoReader, config, fc, AssociationAggregationFunction.SUM);
expectThrows(
IllegalArgumentException.class,
() -> {
@ -242,7 +396,9 @@ public class TestTaxonomyFacetAssociations extends FacetTestCase {
q.add("int", "b");
FacetsCollector fc = searcher.search(q, new FacetsCollectorManager());
Facets facets = new TaxonomyFacetSumIntAssociations("$facets.int", taxoReader, config, fc);
Facets facets =
new TaxonomyFacetIntAssociations(
"$facets.int", taxoReader, config, fc, AssociationAggregationFunction.SUM);
assertEquals(
"dim=int path=[] value=-1 childCount=2\n b (150)\n a (100)\n",
facets.getTopChildren(10, "int").toString());
@ -251,4 +407,52 @@ public class TestTaxonomyFacetAssociations extends FacetTestCase {
assertEquals(
"Wrong count for category 'b'!", 150, facets.getSpecificValue("int", "b").intValue());
}
private void validateInts(
String dim,
Map<String, Integer> expected,
AssociationAggregationFunction aggregationFunction,
boolean isMultiValued,
Facets facets)
throws IOException {
int aggregatedValue = 0;
for (Map.Entry<String, Integer> e : expected.entrySet()) {
int value = e.getValue();
assertEquals(value, facets.getSpecificValue(dim, e.getKey()).intValue());
aggregatedValue = aggregationFunction.aggregate(aggregatedValue, value);
}
if (isMultiValued) {
aggregatedValue = -1;
}
FacetResult facetResult = facets.getTopChildren(10, dim);
assertEquals(dim, facetResult.dim);
assertEquals(aggregatedValue, facetResult.value.intValue());
assertEquals(expected.size(), facetResult.childCount);
}
private void validateFloats(
String dim,
Map<String, Float> expected,
AssociationAggregationFunction aggregationFunction,
boolean isMultiValued,
Facets facets)
throws IOException {
float aggregatedValue = 0f;
for (Map.Entry<String, Float> e : expected.entrySet()) {
float value = e.getValue();
assertEquals(value, facets.getSpecificValue(dim, e.getKey()).floatValue(), 1);
aggregatedValue = aggregationFunction.aggregate(aggregatedValue, value);
}
if (isMultiValued) {
aggregatedValue = -1;
}
FacetResult facetResult = facets.getTopChildren(10, dim);
assertEquals(dim, facetResult.dim);
assertEquals(aggregatedValue, facetResult.value.floatValue(), 1);
assertEquals(expected.size(), facetResult.childCount);
}
}

View File

@ -16,6 +16,7 @@
*/
package org.apache.lucene.facet.taxonomy;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
@ -55,7 +56,7 @@ import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.IOUtils;
public class TestTaxonomyFacetSumValueSource extends FacetTestCase {
public class TestTaxonomyFacetValueSource extends FacetTestCase {
public void testBasic() throws Exception {
@ -111,20 +112,32 @@ public class TestTaxonomyFacetSumValueSource extends FacetTestCase {
// Facets.search utility methods:
FacetsCollector c = searcher.search(new MatchAllDocsQuery(), new FacetsCollectorManager());
TaxonomyFacetSumValueSource facets =
new TaxonomyFacetSumValueSource(
taxoReader, new FacetsConfig(), c, DoubleValuesSource.fromIntField("num"));
FacetsConfig facetsConfig = new FacetsConfig();
DoubleValuesSource valuesSource = DoubleValuesSource.fromIntField("num");
// Test SUM:
Facets facets =
new TaxonomyFacetFloatAssociations(
taxoReader, facetsConfig, c, AssociationAggregationFunction.SUM, valuesSource);
// Retrieve & verify results:
assertEquals(
"dim=Author path=[] value=145.0 childCount=4\n Lisa (50.0)\n Frank (45.0)\n Susan (40.0)\n Bob (10.0)\n",
facets.getTopChildren(10, "Author").toString());
// Test MAX:
facets =
new TaxonomyFacetFloatAssociations(
taxoReader, facetsConfig, c, AssociationAggregationFunction.MAX, valuesSource);
assertEquals(
"dim=Author path=[] value=45.0 childCount=4\n Frank (45.0)\n Susan (40.0)\n Lisa (30.0)\n Bob (10.0)\n",
facets.getTopChildren(10, "Author").toString());
// test getTopChildren(0, dim)
final Facets f = facets;
expectThrows(
IllegalArgumentException.class,
() -> {
facets.getTopChildren(0, "Author");
f.getTopChildren(0, "Author");
});
taxoReader.close();
@ -182,9 +195,13 @@ public class TestTaxonomyFacetSumValueSource extends FacetTestCase {
FacetsCollector c = searcher.search(new MatchAllDocsQuery(), new FacetsCollectorManager());
TaxonomyFacetSumValueSource facets =
new TaxonomyFacetSumValueSource(
taxoReader, new FacetsConfig(), c, DoubleValuesSource.fromIntField("num"));
Facets facets =
new TaxonomyFacetFloatAssociations(
taxoReader,
new FacetsConfig(),
c,
AssociationAggregationFunction.SUM,
DoubleValuesSource.fromIntField("num"));
// Ask for top 10 labels for any dims that have counts:
List<FacetResult> results = facets.getAllDims(10);
@ -262,9 +279,13 @@ public class TestTaxonomyFacetSumValueSource extends FacetTestCase {
FacetsCollector c = searcher.search(new MatchAllDocsQuery(), new FacetsCollectorManager());
TaxonomyFacetSumValueSource facets =
new TaxonomyFacetSumValueSource(
taxoReader, config, c, DoubleValuesSource.fromIntField("num"));
Facets facets =
new TaxonomyFacetFloatAssociations(
taxoReader,
config,
c,
AssociationAggregationFunction.SUM,
DoubleValuesSource.fromIntField("num"));
// Ask for top 10 labels for any dims that have counts:
List<FacetResult> results = facets.getAllDims(10);
@ -288,7 +309,7 @@ public class TestTaxonomyFacetSumValueSource extends FacetTestCase {
IOUtils.close(searcher.getIndexReader(), taxoReader, dir, taxoDir);
}
public void testSumScoreAggregator() throws Exception {
public void testScoreAggregator() throws Exception {
Directory indexDir = newDirectory();
Directory taxoDir = newDirectory();
@ -314,12 +335,20 @@ public class TestTaxonomyFacetSumValueSource extends FacetTestCase {
TopDocs td = FacetsCollector.search(newSearcher(r), csq, 10, fc);
// Test SUM:
Facets facets =
new TaxonomyFacetSumValueSource(taxoReader, config, fc, DoubleValuesSource.SCORES);
new TaxonomyFacetFloatAssociations(
taxoReader, config, fc, AssociationAggregationFunction.SUM, DoubleValuesSource.SCORES);
int expected = (int) (csq.getBoost() * td.totalHits.value);
assertEquals(expected, facets.getSpecificValue("dim", "a").intValue());
// Test MAX:
facets =
new TaxonomyFacetFloatAssociations(
taxoReader, config, fc, AssociationAggregationFunction.MAX, DoubleValuesSource.SCORES);
expected = (int) csq.getBoost();
assertEquals(expected, facets.getSpecificValue("dim", "a").intValue());
iw.close();
IOUtils.close(taxoWriter, taxoReader, taxoDir, r, indexDir);
}
@ -343,12 +372,31 @@ public class TestTaxonomyFacetSumValueSource extends FacetTestCase {
FacetsCollector sfc =
newSearcher(r).search(new MatchAllDocsQuery(), new FacetsCollectorManager());
// Test SUM:
Facets facets =
new TaxonomyFacetSumValueSource(
taxoReader, config, sfc, DoubleValuesSource.fromLongField("price"));
new TaxonomyFacetFloatAssociations(
taxoReader,
config,
sfc,
AssociationAggregationFunction.SUM,
DoubleValuesSource.fromLongField("price"));
assertEquals(
"dim=a path=[] value=10.0 childCount=2\n 1 (6.0)\n 0 (4.0)\n",
facets.getTopChildren(10, "a").toString());
// Test MAX:
facets =
new TaxonomyFacetFloatAssociations(
taxoReader,
config,
sfc,
AssociationAggregationFunction.MAX,
DoubleValuesSource.fromLongField("price"));
assertEquals(
"dim=a path=[] value=4.0 childCount=2\n 1 (4.0)\n 0 (3.0)\n",
facets.getTopChildren(10, "a").toString());
iw.close();
IOUtils.close(taxoWriter, taxoReader, taxoDir, r, indexDir);
}
@ -381,12 +429,23 @@ public class TestTaxonomyFacetSumValueSource extends FacetTestCase {
// categories easier
Query q = new FunctionQuery(new LongFieldSource("price"));
FacetsCollector.search(newSearcher(r), q, 10, fc);
Facets facets =
new TaxonomyFacetSumValueSource(taxoReader, config, fc, DoubleValuesSource.SCORES);
// Test SUM:
Facets facets =
new TaxonomyFacetFloatAssociations(
taxoReader, config, fc, AssociationAggregationFunction.SUM, DoubleValuesSource.SCORES);
assertEquals(
"dim=a path=[] value=10.0 childCount=2\n 1 (6.0)\n 0 (4.0)\n",
facets.getTopChildren(10, "a").toString());
// Test MAX:
facets =
new TaxonomyFacetFloatAssociations(
taxoReader, config, fc, AssociationAggregationFunction.MAX, DoubleValuesSource.SCORES);
assertEquals(
"dim=a path=[] value=4.0 childCount=2\n 1 (4.0)\n 0 (3.0)\n",
facets.getTopChildren(10, "a").toString());
iw.close();
IOUtils.close(taxoWriter, taxoReader, taxoDir, r, indexDir);
}
@ -413,13 +472,31 @@ public class TestTaxonomyFacetSumValueSource extends FacetTestCase {
FacetsCollector sfc =
newSearcher(r).search(new MatchAllDocsQuery(), new FacetsCollectorManager());
Facets facets =
new TaxonomyFacetSumValueSource(
taxoReader, config, sfc, DoubleValuesSource.fromLongField("price"));
// Test SUM:
Facets facets =
new TaxonomyFacetFloatAssociations(
taxoReader,
config,
sfc,
AssociationAggregationFunction.SUM,
DoubleValuesSource.fromLongField("price"));
assertEquals(
"dim=a path=[] value=10.0 childCount=2\n 1 (6.0)\n 0 (4.0)\n",
facets.getTopChildren(10, "a").toString());
// Test MAX:
facets =
new TaxonomyFacetFloatAssociations(
taxoReader,
config,
sfc,
AssociationAggregationFunction.MAX,
DoubleValuesSource.fromLongField("price"));
assertEquals(
"dim=a path=[] value=4.0 childCount=2\n 1 (4.0)\n 0 (3.0)\n",
facets.getTopChildren(10, "a").toString());
iw.close();
IOUtils.close(taxoWriter, taxoReader, taxoDir, r, indexDir);
}
@ -449,7 +526,13 @@ public class TestTaxonomyFacetSumValueSource extends FacetTestCase {
Facets facets1 = getTaxonomyFacetCounts(taxoReader, config, fc);
Facets facets2 =
new TaxonomyFacetSumValueSource("$b", taxoReader, config, fc, DoubleValuesSource.SCORES);
new TaxonomyFacetFloatAssociations(
"$b",
taxoReader,
config,
fc,
AssociationAggregationFunction.SUM,
DoubleValuesSource.SCORES);
assertEquals(r.maxDoc(), facets1.getTopChildren(10, "a").value.intValue());
assertEquals(r.maxDoc(), facets2.getTopChildren(10, "b").value.doubleValue(), 1E-10);
@ -495,75 +578,107 @@ public class TestTaxonomyFacetSumValueSource extends FacetTestCase {
}
FacetsCollector fc = new FacetsCollector();
FacetsCollector.search(searcher, new TermQuery(new Term("content", searchToken)), 10, fc);
Facets facets =
new TaxonomyFacetSumValueSource(
tr, config, fc, DoubleValuesSource.fromFloatField("value"));
// Slow, yet hopefully bug-free, faceting:
@SuppressWarnings({"rawtypes", "unchecked"})
Map<String, Float>[] expectedValues = new HashMap[numDims];
for (int i = 0; i < numDims; i++) {
expectedValues[i] = new HashMap<>();
}
for (TestDoc doc : testDocs) {
if (doc.content.equals(searchToken)) {
for (int j = 0; j < numDims; j++) {
if (doc.dims[j] != null) {
Float v = expectedValues[j].get(doc.dims[j]);
if (v == null) {
expectedValues[j].put(doc.dims[j], doc.value);
} else {
expectedValues[j].put(doc.dims[j], v + doc.value);
}
}
}
}
}
List<FacetResult> expected = new ArrayList<>();
for (int i = 0; i < numDims; i++) {
List<LabelAndValue> labelValues = new ArrayList<>();
double totValue = 0;
for (Map.Entry<String, Float> ent : expectedValues[i].entrySet()) {
if (ent.getValue() > 0) {
labelValues.add(new LabelAndValue(ent.getKey(), ent.getValue()));
totValue += ent.getValue();
}
}
sortLabelValues(labelValues);
if (totValue > 0) {
expected.add(
new FacetResult(
"dim" + i,
new String[0],
totValue,
labelValues.toArray(new LabelAndValue[labelValues.size()]),
labelValues.size()));
}
}
// Sort by highest value, tie break by value:
sortFacetResults(expected);
List<FacetResult> actual = facets.getAllDims(10);
// test default implementation of getTopDims
if (actual.size() > 0) {
List<FacetResult> topDimsResults1 = facets.getTopDims(1, 10);
assertEquals(actual.get(0), topDimsResults1.get(0));
}
// Messy: fixup ties
sortTies(actual);
if (VERBOSE) {
System.out.println("expected=\n" + expected.toString());
System.out.println("actual=\n" + actual.toString());
}
assertFloatValuesEquals(expected, actual);
checkResults(
numDims,
testDocs,
searchToken,
tr,
config,
fc,
DoubleValuesSource.fromFloatField("value"),
AssociationAggregationFunction.SUM);
checkResults(
numDims,
testDocs,
searchToken,
tr,
config,
fc,
DoubleValuesSource.fromFloatField("value"),
AssociationAggregationFunction.MAX);
}
w.close();
IOUtils.close(tw, searcher.getIndexReader(), tr, indexDir, taxoDir);
}
private void checkResults(
int numDims,
List<TestDoc> testDocs,
String searchToken,
TaxonomyReader taxoReader,
FacetsConfig facetsConfig,
FacetsCollector facetsCollector,
DoubleValuesSource valuesSource,
AssociationAggregationFunction aggregationFunction)
throws IOException {
// Slow, yet hopefully bug-free, faceting:
@SuppressWarnings({"rawtypes", "unchecked"})
Map<String, Float>[] expectedValues = new HashMap[numDims];
for (int i = 0; i < numDims; i++) {
expectedValues[i] = new HashMap<>();
}
for (TestDoc doc : testDocs) {
if (doc.content.equals(searchToken)) {
for (int j = 0; j < numDims; j++) {
if (doc.dims[j] != null) {
Float v = expectedValues[j].get(doc.dims[j]);
if (v == null) {
expectedValues[j].put(doc.dims[j], doc.value);
} else {
float newValue = aggregationFunction.aggregate(v, doc.value);
expectedValues[j].put(doc.dims[j], newValue);
}
}
}
}
}
List<FacetResult> expected = new ArrayList<>();
for (int i = 0; i < numDims; i++) {
List<LabelAndValue> labelValues = new ArrayList<>();
float aggregatedValue = 0;
for (Map.Entry<String, Float> ent : expectedValues[i].entrySet()) {
labelValues.add(new LabelAndValue(ent.getKey(), ent.getValue()));
aggregatedValue = aggregationFunction.aggregate(aggregatedValue, ent.getValue());
}
sortLabelValues(labelValues);
if (aggregatedValue > 0) {
expected.add(
new FacetResult(
"dim" + i,
new String[0],
aggregatedValue,
labelValues.toArray(new LabelAndValue[labelValues.size()]),
labelValues.size()));
}
}
// Sort by highest value, tie break by value:
sortFacetResults(expected);
Facets facets =
new TaxonomyFacetFloatAssociations(
taxoReader, facetsConfig, facetsCollector, aggregationFunction, valuesSource);
List<FacetResult> actual = facets.getAllDims(10);
// test default implementation of getTopDims
if (actual.size() > 0) {
List<FacetResult> topDimsResults1 = facets.getTopDims(1, 10);
assertEquals(actual.get(0), topDimsResults1.get(0));
}
// Messy: fixup ties
sortTies(actual);
if (VERBOSE) {
System.out.println("expected=\n" + expected.toString());
System.out.println("actual=\n" + actual.toString());
}
assertFloatValuesEquals(expected, actual);
}
}