diff --git a/lucene/facet/src/java/org/apache/lucene/facet/taxonomy/FloatTaxonomyFacets.java b/lucene/facet/src/java/org/apache/lucene/facet/taxonomy/FloatTaxonomyFacets.java index 60ec09e464a..342bb77714b 100644 --- a/lucene/facet/src/java/org/apache/lucene/facet/taxonomy/FloatTaxonomyFacets.java +++ b/lucene/facet/src/java/org/apache/lucene/facet/taxonomy/FloatTaxonomyFacets.java @@ -17,12 +17,16 @@ package org.apache.lucene.facet.taxonomy; import java.io.IOException; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; import java.util.Map; import org.apache.lucene.facet.FacetResult; import org.apache.lucene.facet.FacetsConfig; import org.apache.lucene.facet.FacetsConfig.DimConfig; import org.apache.lucene.facet.LabelAndValue; import org.apache.lucene.facet.TopOrdAndFloatQueue; +import org.apache.lucene.util.PriorityQueue; /** Base class for all taxonomy-based facets that aggregate to a per-ords float[]. */ abstract class FloatTaxonomyFacets extends TaxonomyFacets { @@ -35,6 +39,9 @@ abstract class FloatTaxonomyFacets extends TaxonomyFacets { /** Per-ordinal value. */ final float[] values; + /** Pass in emptyPath for getTopDims and getAllDims. */ + private static final String[] emptyPath = new String[0]; + /** Sole constructor. */ FloatTaxonomyFacets( String indexFieldName, @@ -107,6 +114,23 @@ abstract class FloatTaxonomyFacets extends TaxonomyFacets { return null; } + ChildOrdsResult childOrdsResult = getChildOrdsResult(dimConfig, dimOrd, topN); + if (childOrdsResult.aggregatedValue == 0) { + return null; + } + + LabelAndValue[] labelValues = getLabelValues(childOrdsResult.q, cp.length); + return new FacetResult( + dim, path, childOrdsResult.aggregatedValue, labelValues, childOrdsResult.childCount); + } + + /** + * Return ChildOrdsResult that contains results of aggregatedValue, childCount, and the queue for + * the dimension's top children to populate FacetResult in getPathResult. 
+ */ + private ChildOrdsResult getChildOrdsResult(DimConfig dimConfig, int dimOrd, int topN) + throws IOException { + TopOrdAndFloatQueue q = new TopOrdAndFloatQueue(Math.min(taxoReader.getSize(), topN)); float bottomValue = 0; @@ -138,10 +162,6 @@ abstract class FloatTaxonomyFacets extends TaxonomyFacets { ord = siblings[ord]; } - if (aggregatedValue == 0) { - return null; - } - if (dimConfig.multiValued) { if (dimConfig.requireDimCount) { aggregatedValue = values[dimOrd]; @@ -149,10 +169,17 @@ abstract class FloatTaxonomyFacets extends TaxonomyFacets { // Our sum'd count is not correct, in general: aggregatedValue = -1; } - } else { - // Our sum'd dim count is accurate, so we keep it } + return new ChildOrdsResult(aggregatedValue, childCount, q); + } + /** + * Return label and values for top dimensions and children + * + * @param q the queue for the dimension's top children + * @param pathLength the length of a dimension's children paths + */ + private LabelAndValue[] getLabelValues(TopOrdAndFloatQueue q, int pathLength) throws IOException { LabelAndValue[] labelValues = new LabelAndValue[q.size()]; int[] ordinals = new int[labelValues.length]; float[] values = new float[labelValues.length]; @@ -165,9 +192,151 @@ abstract class FloatTaxonomyFacets extends TaxonomyFacets { FacetLabel[] bulkPath = taxoReader.getBulkPath(ordinals); for (int i = 0; i < labelValues.length; i++) { - labelValues[i] = new LabelAndValue(bulkPath[i].components[cp.length], values[i]); + labelValues[i] = new LabelAndValue(bulkPath[i].components[pathLength], values[i]); + } + return labelValues; + } + + /** Return value of a dimension. 
*/ + private float getDimValue( + FacetsConfig.DimConfig dimConfig, + String dim, + int dimOrd, + int topN, + HashMap<String, ChildOrdsResult> dimToChildOrdsResult) + throws IOException { + + // if dimConfig.hierarchical == true || dim is multiValued and dim count has been aggregated at + // indexing time, return dimCount directly + if (dimConfig.hierarchical || (dimConfig.multiValued && dimConfig.requireDimCount)) { + return values[dimOrd]; } - return new FacetResult(dim, path, aggregatedValue, labelValues, childCount); + // if dimCount was not aggregated at indexing time, iterate over childOrds to get dimCount + ChildOrdsResult childOrdsResult = getChildOrdsResult(dimConfig, dimOrd, topN); + + // if no early termination, store dim and childOrdsResult into a hashmap to avoid calling + // getChildOrdsResult again in getTopDims + dimToChildOrdsResult.put(dim, childOrdsResult); + return childOrdsResult.aggregatedValue; + } + + @Override + public List<FacetResult> getTopDims(int topNDims, int topNChildren) throws IOException { + validateTopN(topNDims); + validateTopN(topNChildren); + + // get existing children and siblings ordinal array from TaxonomyFacets + int[] children = getChildren(); + int[] siblings = getSiblings(); + + // Create priority queue to store top dimensions and sort by their aggregated values/hits and + // string values. 
+ PriorityQueue<DimValueResult> pq = + new PriorityQueue<>(topNDims) { + @Override + protected boolean lessThan(DimValueResult a, DimValueResult b) { + if (a.value > b.value) { + return false; + } else if (a.value < b.value) { + return true; + } else { + return a.dim.compareTo(b.dim) > 0; + } + } + }; + + // create hashMap to store the ChildOrdsResult to avoid calling getChildOrdsResult for all dims + HashMap<String, ChildOrdsResult> dimToChildOrdsResult = new HashMap<>(); + + // iterate over children and siblings ordinals for all dims + int ord = children[TaxonomyReader.ROOT_ORDINAL]; + while (ord != TaxonomyReader.INVALID_ORDINAL) { + String dim = taxoReader.getPath(ord).components[0]; + FacetsConfig.DimConfig dimConfig = config.getDimConfig(dim); + if (dimConfig.indexFieldName.equals(indexFieldName)) { + FacetLabel cp = new FacetLabel(dim, emptyPath); + int dimOrd = taxoReader.getOrdinal(cp); + // if dimOrd = -1, we skip this dim, else call getDimValue + if (dimOrd != -1) { + float dimCount = getDimValue(dimConfig, dim, dimOrd, topNChildren, dimToChildOrdsResult); + if (dimCount != 0) { + // use priority queue to store DimValueResult for topNDims + if (pq.size() < topNDims) { + pq.add(new DimValueResult(dim, dimOrd, dimCount)); + } else { + if (dimCount > pq.top().value + || (dimCount == pq.top().value && dim.compareTo(pq.top().dim) < 0)) { + DimValueResult bottomDim = pq.top(); + bottomDim.dim = dim; + bottomDim.dimOrd = dimOrd; // reused entry must carry the new dim's ordinal + bottomDim.value = dimCount; + pq.updateTop(); + } + } + } + } + } + ord = siblings[ord]; + } + + // use fixed-size array to reduce space usage + FacetResult[] results = new FacetResult[pq.size()]; + + while (pq.size() > 0) { + DimValueResult dimValueResult = pq.pop(); + String dim = dimValueResult.dim; + ChildOrdsResult childOrdsResult; + // if the childOrdsResult was stored in the map, avoid calling getChildOrdsResult again + if (dimToChildOrdsResult.containsKey(dim)) { + childOrdsResult = dimToChildOrdsResult.get(dim); + } else { + FacetsConfig.DimConfig dimConfig = 
config.getDimConfig(dim); + childOrdsResult = getChildOrdsResult(dimConfig, dimValueResult.dimOrd, topNChildren); + } + // FacetResult requires String[] path, and path is always empty for getTopDims. + // pathLength is always equal to 1 when FacetLabel is constructed with + // FacetLabel(dim, emptyPath), and therefore, 1 is passed in when calling getLabelValues + FacetResult facetResult = + new FacetResult( + dimValueResult.dim, + emptyPath, + dimValueResult.value, + getLabelValues(childOrdsResult.q, 1), + childOrdsResult.childCount); + results[pq.size()] = facetResult; + } + return Arrays.asList(results); + } + + /** + * Create DimValueResult to store the label, dim ordinal and dim count of a dim in priority queue + */ + private static class DimValueResult { + String dim; + int dimOrd; + float value; + + DimValueResult(String dim, int dimOrd, float value) { + this.dim = dim; + this.dimOrd = dimOrd; + this.value = value; + } + } + + /** + * Create ChildOrdsResult to store dimCount, childCount, and the queue for the dimension's top + * children + */ + private static class ChildOrdsResult { + final float aggregatedValue; + final int childCount; + final TopOrdAndFloatQueue q; + + ChildOrdsResult(float aggregatedValue, int childCount, TopOrdAndFloatQueue q) { + this.aggregatedValue = aggregatedValue; + this.childCount = childCount; + this.q = q; + } } } diff --git a/lucene/facet/src/test/org/apache/lucene/facet/taxonomy/TestTaxonomyFacetAssociations.java b/lucene/facet/src/test/org/apache/lucene/facet/taxonomy/TestTaxonomyFacetAssociations.java index 9cb9963e315..ab4f02639f6 100644 --- a/lucene/facet/src/test/org/apache/lucene/facet/taxonomy/TestTaxonomyFacetAssociations.java +++ b/lucene/facet/src/test/org/apache/lucene/facet/taxonomy/TestTaxonomyFacetAssociations.java @@ -202,6 +202,11 @@ public class TestTaxonomyFacetAssociations extends FacetTestCase { "Wrong count for category 'a'!", 200, facets.getSpecificValue("int", "a").intValue()); assertEquals( "Wrong count 
for category 'b'!", 150, facets.getSpecificValue("int", "b").intValue()); + + // test getAllDims and getTopDims + List topDims = facets.getTopDims(10, 10); + List allDims = facets.getAllDims(10); + assertEquals(topDims, allDims); } public void testIntAssociationRandom() throws Exception { @@ -229,6 +234,11 @@ public class TestTaxonomyFacetAssociations extends FacetTestCase { } validateInts("int_single_valued", expected, AssociationAggregationFunction.SUM, false, facets); + // test getAllDims and getTopDims + List allDims = facets.getAllDims(10); + List topDims = facets.getTopDims(10, 10); + assertEquals(topDims, allDims); + // MAX: facets = new TaxonomyFacetIntAssociations( @@ -243,6 +253,11 @@ public class TestTaxonomyFacetAssociations extends FacetTestCase { expected.put(e.getKey(), e.getValue().stream().max(Integer::compareTo).orElse(0)); } validateInts("int_single_valued", expected, AssociationAggregationFunction.MAX, false, facets); + + // test getAllDims and getTopDims + topDims = facets.getTopDims(10, 10); + allDims = facets.getAllDims(10); + assertEquals(topDims, allDims); } public void testFloatSumAssociation() throws Exception { @@ -265,6 +280,11 @@ public class TestTaxonomyFacetAssociations extends FacetTestCase { 10f, facets.getSpecificValue("float", "b").floatValue(), 0.00001); + + // test getAllDims and getTopDims + List topDims = facets.getTopDims(10, 10); + List allDims = facets.getAllDims(10); + assertEquals(topDims, allDims); } public void testFloatAssociationRandom() throws Exception { @@ -293,6 +313,11 @@ public class TestTaxonomyFacetAssociations extends FacetTestCase { validateFloats( "float_single_valued", expected, AssociationAggregationFunction.SUM, false, facets); + // test getAllDims and getTopDims + List topDims = facets.getTopDims(10, 10); + List allDims = facets.getAllDims(10); + assertEquals(topDims, allDims); + // MAX: facets = new TaxonomyFacetFloatAssociations( @@ -308,6 +333,11 @@ public class TestTaxonomyFacetAssociations extends 
FacetTestCase { } validateFloats( "float_single_valued", expected, AssociationAggregationFunction.MAX, false, facets); + + // test getAllDims and getTopDims + topDims = facets.getTopDims(10, 10); + allDims = facets.getAllDims(10); + assertEquals(topDims, allDims); } /** diff --git a/lucene/facet/src/test/org/apache/lucene/facet/taxonomy/TestTaxonomyFacetValueSource.java b/lucene/facet/src/test/org/apache/lucene/facet/taxonomy/TestTaxonomyFacetValueSource.java index 3e5689e7815..2e454554317 100644 --- a/lucene/facet/src/test/org/apache/lucene/facet/taxonomy/TestTaxonomyFacetValueSource.java +++ b/lucene/facet/src/test/org/apache/lucene/facet/taxonomy/TestTaxonomyFacetValueSource.java @@ -236,8 +236,11 @@ public class TestTaxonomyFacetValueSource extends FacetTestCase { assertEquals(results, allDimsResults); // test getTopDims(0, 1) - List topDimsResults2 = facets.getTopDims(0, 1); - assertEquals(0, topDimsResults2.size()); + expectThrows( + IllegalArgumentException.class, + () -> { + facets.getTopDims(0, 1); + }); // test getTopDims(1, 0) with topNChildren = 0 expectThrows(