SOLR-2665: Added post group faceting.

git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1154676 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Martijn van Groningen 2011-08-07 09:02:33 +00:00
parent 01163c42c4
commit 0bc43983c4
6 changed files with 270 additions and 21 deletions

View File

@ -134,8 +134,8 @@ public abstract class AbstractAllGroupHeadsCollector<GH extends AbstractAllGroup
*/ */
protected class TemporalResult { protected class TemporalResult {
protected GH groupHead; public GH groupHead;
protected boolean stop; public boolean stop;
} }

View File

@ -360,6 +360,10 @@ New Features
* SOLR-2637: Added support for group result parsing in SolrJ. * SOLR-2637: Added support for group result parsing in SolrJ.
(Tao Cheng, Martijn van Groningen) (Tao Cheng, Martijn van Groningen)
* SOLR-2665: Added post group faceting. Facet counts are based on the most
relevant document of each group matching the query. This feature has the
same impact on the StatsComponent. (Martijn van Groningen)
Optimizations Optimizations
---------------------- ----------------------

View File

@ -320,6 +320,7 @@ public class QueryComponent extends SearchComponent
String[] queries = params.getParams(GroupParams.GROUP_QUERY); String[] queries = params.getParams(GroupParams.GROUP_QUERY);
String groupSortStr = params.get(GroupParams.GROUP_SORT); String groupSortStr = params.get(GroupParams.GROUP_SORT);
boolean main = params.getBool(GroupParams.GROUP_MAIN, false); boolean main = params.getBool(GroupParams.GROUP_MAIN, false);
boolean truncateGroups = params.getBool(GroupParams.GROUP_TRUNCATE, false);
String formatStr = params.get(GroupParams.GROUP_FORMAT, Grouping.Format.grouped.name()); String formatStr = params.get(GroupParams.GROUP_FORMAT, Grouping.Format.grouped.name());
Grouping.Format defaultFormat; Grouping.Format defaultFormat;
@ -346,7 +347,8 @@ public class QueryComponent extends SearchComponent
.setLimitDefault(limitDefault) .setLimitDefault(limitDefault)
.setDefaultTotalCount(defaultTotalCount) .setDefaultTotalCount(defaultTotalCount)
.setDocsPerGroupDefault(docsPerGroupDefault) .setDocsPerGroupDefault(docsPerGroupDefault)
.setGroupOffsetDefault(groupOffsetDefault); .setGroupOffsetDefault(groupOffsetDefault)
.setGetGroupedDocSet(truncateGroups);
if (fields != null) { if (fields != null) {
for (String field : fields) { for (String field : fields) {

View File

@ -28,12 +28,16 @@ import org.apache.lucene.queryparser.classic.ParseException;
import org.apache.lucene.search.*; import org.apache.lucene.search.*;
import org.apache.lucene.search.grouping.*; import org.apache.lucene.search.grouping.*;
import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.OpenBitSet;
import org.apache.lucene.util.mutable.MutableValue; import org.apache.lucene.util.mutable.MutableValue;
import org.apache.solr.common.SolrException; import org.apache.solr.common.SolrException;
import org.apache.solr.common.util.NamedList; import org.apache.solr.common.util.NamedList;
import org.apache.solr.common.util.SimpleOrderedMap; import org.apache.solr.common.util.SimpleOrderedMap;
import org.apache.solr.request.SolrQueryRequest; import org.apache.solr.request.SolrQueryRequest;
import org.apache.solr.schema.*; import org.apache.solr.schema.FieldType;
import org.apache.solr.schema.SchemaField;
import org.apache.solr.schema.StrFieldSource;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@ -69,6 +73,7 @@ public class Grouping {
private int maxDoc; private int maxDoc;
private boolean needScores; private boolean needScores;
private boolean getDocSet; private boolean getDocSet;
private boolean getGroupedDocSet;
private boolean getDocList; // doclist needed for debugging or highlighting private boolean getDocList; // doclist needed for debugging or highlighting
private Query query; private Query query;
private DocSet filter; private DocSet filter;
@ -154,7 +159,7 @@ public class Grouping {
Query q = parser.getQuery(); Query q = parser.getQuery();
final Grouping.Command gc; final Grouping.Command gc;
if (q instanceof FunctionQuery) { if (q instanceof FunctionQuery) {
ValueSource valueSource = ((FunctionQuery)q).getValueSource(); ValueSource valueSource = ((FunctionQuery) q).getValueSource();
if (valueSource instanceof StrFieldSource) { if (valueSource instanceof StrFieldSource) {
String field = ((StrFieldSource) valueSource).getField(); String field = ((StrFieldSource) valueSource).getField();
CommandField commandField = new CommandField(); CommandField commandField = new CommandField();
@ -255,6 +260,11 @@ public class Grouping {
return this; return this;
} }
public Grouping setGetGroupedDocSet(boolean getGroupedDocSet) {
this.getGroupedDocSet = getGroupedDocSet;
return this;
}
public List<Command> getCommands() { public List<Command> getCommands() {
return commands; return commands;
} }
@ -296,16 +306,21 @@ public class Grouping {
cmd.prepare(); cmd.prepare();
} }
AbstractAllGroupHeadsCollector<?> allGroupHeadsCollector = null;
List<Collector> collectors = new ArrayList<Collector>(commands.size()); List<Collector> collectors = new ArrayList<Collector>(commands.size());
for (Command cmd : commands) { for (Command cmd : commands) {
Collector collector = cmd.createFirstPassCollector(); Collector collector = cmd.createFirstPassCollector();
if (collector != null) if (collector != null) {
collectors.add(collector); collectors.add(collector);
} }
if (getGroupedDocSet && allGroupHeadsCollector == null) {
collectors.add(allGroupHeadsCollector = cmd.createAllGroupCollector());
}
}
Collector allCollectors = MultiCollector.wrap(collectors.toArray(new Collector[collectors.size()])); Collector allCollectors = MultiCollector.wrap(collectors.toArray(new Collector[collectors.size()]));
DocSetCollector setCollector = null; DocSetCollector setCollector = null;
if (getDocSet) { if (getDocSet && allGroupHeadsCollector == null) {
setCollector = new DocSetDelegateCollector(maxDoc >> 6, maxDoc, allCollectors); setCollector = new DocSetDelegateCollector(maxDoc >> 6, maxDoc, allCollectors);
allCollectors = setCollector; allCollectors = setCollector;
} }
@ -329,7 +344,12 @@ public class Grouping {
searcher.search(query, luceneFilter, allCollectors); searcher.search(query, luceneFilter, allCollectors);
} }
if (getDocSet) { if (getGroupedDocSet && allGroupHeadsCollector != null) {
FixedBitSet fixedBitSet = allGroupHeadsCollector.retrieveGroupHeads(maxDoc);
long[] bits = fixedBitSet.getBits();
OpenBitSet openBitSet = new OpenBitSet(bits, bits.length);
qr.setDocSet(new BitDocSet(openBitSet));
} else if (getDocSet) {
qr.setDocSet(setCollector.getDocSet()); qr.setDocSet(setCollector.getDocSet());
} }
@ -483,6 +503,17 @@ public class Grouping {
return null; return null;
} }
/**
* Returns a collector that is able to return the most relevant document of all groups.
* Returns <code>null</code> if the command doesn't support this type of collector.
*
* @return a collector that is able to return the most relevant document of all groups.
* @throws IOException If I/O related errors occur
*/
public AbstractAllGroupHeadsCollector<?> createAllGroupCollector() throws IOException {
return null;
}
/** /**
* Performs any necessary post actions to prepare the response. * Performs any necessary post actions to prepare the response.
* *
@ -585,7 +616,8 @@ public class Grouping {
} }
} }
int len = docsGathered - offset;int[] docs = ArrayUtils.toPrimitive(ids.toArray(new Integer[ids.size()])); int len = docsGathered - offset;
int[] docs = ArrayUtils.toPrimitive(ids.toArray(new Integer[ids.size()]));
float[] docScores = ArrayUtils.toPrimitive(scores.toArray(new Float[scores.size()])); float[] docScores = ArrayUtils.toPrimitive(scores.toArray(new Float[scores.size()]));
DocSlice docSlice = new DocSlice(offset, len, docs, docScores, getMatches(), maxScore); DocSlice docSlice = new DocSlice(offset, len, docs, docScores, getMatches(), maxScore);
@ -672,6 +704,15 @@ public class Grouping {
} }
} }
/**
* {@inheritDoc}
*/
@Override
public AbstractAllGroupHeadsCollector<?> createAllGroupCollector() throws IOException {
Sort sortWithinGroup = groupSort != null ? groupSort : new Sort();
return TermAllGroupHeadsCollector.create(groupBy, sortWithinGroup);
}
/** /**
* {@inheritDoc} * {@inheritDoc}
*/ */
@ -873,6 +914,12 @@ public class Grouping {
} }
} }
@Override
public AbstractAllGroupHeadsCollector<?> createAllGroupCollector() throws IOException {
Sort sortWithinGroup = groupSort != null ? groupSort : new Sort();
return new FunctionAllGroupHeadsCollector(groupBy, context, sortWithinGroup);
}
/** /**
* {@inheritDoc} * {@inheritDoc}
*/ */
@ -1091,4 +1138,102 @@ public class Grouping {
} }
static class FunctionAllGroupHeadsCollector extends AbstractAllGroupHeadsCollector<FunctionAllGroupHeadsCollector.GroupHead> {
private final ValueSource groupBy;
private final Map vsContext;
private final Map<MutableValue, GroupHead> groups;
private final Sort sortWithinGroup;
private DocValues docValues;
private DocValues.ValueFiller filler;
private MutableValue mval;
private AtomicReaderContext readerContext;
private Scorer scorer;
FunctionAllGroupHeadsCollector(ValueSource groupBy, Map vsContext, Sort sortWithinGroup) {
super(sortWithinGroup.getSort().length);
groups = new HashMap<MutableValue, GroupHead>();
this.sortWithinGroup = sortWithinGroup;
this.groupBy = groupBy;
this.vsContext = vsContext;
final SortField[] sortFields = sortWithinGroup.getSort();
for (int i = 0; i < sortFields.length; i++) {
reversed[i] = sortFields[i].getReverse() ? -1 : 1;
}
}
protected void retrieveGroupHeadAndAddIfNotExist(int doc) throws IOException {
filler.fillValue(doc);
GroupHead groupHead = groups.get(mval);
if (groupHead == null) {
MutableValue groupValue = mval.duplicate();
groupHead = new GroupHead(groupValue, sortWithinGroup, doc);
groups.put(groupValue, groupHead);
temporalResult.stop = true;
} else {
temporalResult.stop = false;
}
this.temporalResult.groupHead = groupHead;
}
protected Collection<GroupHead> getCollectedGroupHeads() {
return groups.values();
}
public void setScorer(Scorer scorer) throws IOException {
this.scorer = scorer;
for (GroupHead groupHead : groups.values()) {
for (FieldComparator comparator : groupHead.comparators) {
comparator.setScorer(scorer);
}
}
}
public void setNextReader(AtomicReaderContext context) throws IOException {
this.readerContext = context;
docValues = groupBy.getValues(vsContext, context);
filler = docValues.getValueFiller();
mval = filler.getValue();
for (GroupHead groupHead : groups.values()) {
for (int i = 0; i < groupHead.comparators.length; i++) {
groupHead.comparators[i] = groupHead.comparators[i].setNextReader(context);
}
}
}
class GroupHead extends AbstractAllGroupHeadsCollector.GroupHead<MutableValue> {
final FieldComparator[] comparators;
private GroupHead(MutableValue groupValue, Sort sort, int doc) throws IOException {
super(groupValue, doc + readerContext.docBase);
final SortField[] sortFields = sort.getSort();
comparators = new FieldComparator[sortFields.length];
for (int i = 0; i < sortFields.length; i++) {
comparators[i] = sortFields[i].getComparator(1, i).setNextReader(readerContext);
comparators[i].setScorer(scorer);
comparators[i].copy(0, doc);
comparators[i].setBottom(0);
}
}
public int compare(int compIDX, int doc) throws IOException {
return comparators[compIDX].compareBottom(doc);
}
public void updateDocHead(int doc) throws IOException {
for (FieldComparator comparator : comparators) {
comparator.copy(0, doc);
comparator.setBottom(0);
}
this.doc = doc + readerContext.docBase;
}
}
}
} }

View File

@ -234,6 +234,53 @@ public class TestGroupingSearch extends SolrTestCaseJ4 {
); );
} }
@Test
public void testGroupingGroupedBasedFaceting() throws Exception {
assertU(add(doc("id", "1", "value1_s1", "1", "value2_i", "1", "value3_s1", "a", "value4_i", "1")));
assertU(add(doc("id", "2", "value1_s1", "1", "value2_i", "2", "value3_s1", "a", "value4_i", "1")));
assertU(commit());
assertU(add(doc("id", "3", "value1_s1", "2", "value2_i", "3", "value3_s1", "b", "value4_i", "2")));
assertU(add(doc("id", "4", "value1_s1", "1", "value2_i", "4", "value3_s1", "a", "value4_i", "1")));
assertU(add(doc("id", "5", "value1_s1", "2", "value2_i", "5", "value3_s1", "b", "value4_i", "2")));
assertU(commit());
// Facet counts based on documents
SolrQueryRequest req = req("q", "*:*", "sort", "value2_i asc", "rows", "1", "group", "true", "group.field",
"value1_s1", "fl", "id", "facet", "true", "facet.field", "value3_s1", "group.truncate", "false");
assertJQ(
req,
"/grouped=={'value1_s1':{'matches':5,'groups':[{'groupValue':'1','doclist':{'numFound':3,'start':0,'docs':[{'id':'1'}]}}]}}",
"/facet_counts=={'facet_queries':{},'facet_fields':{'value3_s1':['a',3,'b',2]},'facet_dates':{},'facet_ranges':{}}"
);
// Facet counts based on groups
req = req("q", "*:*", "sort", "value2_i asc", "rows", "1", "group", "true", "group.field",
"value1_s1", "fl", "id", "facet", "true", "facet.field", "value3_s1", "group.truncate", "true");
assertJQ(
req,
"/grouped=={'value1_s1':{'matches':5,'groups':[{'groupValue':'1','doclist':{'numFound':3,'start':0,'docs':[{'id':'1'}]}}]}}",
"/facet_counts=={'facet_queries':{},'facet_fields':{'value3_s1':['a',1,'b',1]},'facet_dates':{},'facet_ranges':{}}"
);
// Facet counts based on groups and with group.func. This should trigger FunctionAllGroupHeadsCollector
req = req("q", "*:*", "sort", "value2_i asc", "rows", "1", "group", "true", "group.func",
"strdist(1,value1_s1,edit)", "fl", "id", "facet", "true", "facet.field", "value3_s1", "group.truncate", "true");
assertJQ(
req,
"/grouped=={'strdist(1,value1_s1,edit)':{'matches':5,'groups':[{'groupValue':1.0,'doclist':{'numFound':3,'start':0,'docs':[{'id':'1'}]}}]}}",
"/facet_counts=={'facet_queries':{},'facet_fields':{'value3_s1':['a',1,'b',1]},'facet_dates':{},'facet_ranges':{}}"
);
// Facet counts based on groups without sort on an int field.
req = req("q", "*:*", "rows", "1", "group", "true", "group.field", "value4_i", "fl", "id", "facet", "true",
"facet.field", "value3_s1", "group.truncate", "true");
assertJQ(
req,
"/grouped=={'value4_i':{'matches':5,'groups':[{'groupValue':1,'doclist':{'numFound':3,'start':0,'docs':[{'id':'1'}]}}]}}",
"/facet_counts=={'facet_queries':{},'facet_fields':{'value3_s1':['a',1,'b',1]},'facet_dates':{},'facet_ranges':{}}"
);
}
static String f = "foo_i"; static String f = "foo_i";
static String f2 = "foo2_i"; static String f2 = "foo2_i";
@ -474,7 +521,7 @@ public class TestGroupingSearch extends SolrTestCaseJ4 {
types.add(new FldType("id",ONE_ONE, new SVal('A','Z',4,4))); types.add(new FldType("id",ONE_ONE, new SVal('A','Z',4,4)));
types.add(new FldType("score_f",ONE_ONE, new FVal(1,100))); // field used to score types.add(new FldType("score_f",ONE_ONE, new FVal(1,100))); // field used to score
types.add(new FldType("foo_i",ZERO_ONE, new IRange(0,indexSize))); types.add(new FldType("foo_i",ZERO_ONE, new IRange(0,indexSize)));
types.add(new FldType(FOO_STRING_FIELD,ZERO_ONE, new SVal('a','z',1,2))); types.add(new FldType(FOO_STRING_FIELD,ONE_ONE, new SVal('a','z',1,2)));
types.add(new FldType(SMALL_STRING_FIELD,ZERO_ONE, new SVal('a',(char)('c'+indexSize/10),1,1))); types.add(new FldType(SMALL_STRING_FIELD,ZERO_ONE, new SVal('a',(char)('c'+indexSize/10),1,1)));
types.add(new FldType(SMALL_INT_FIELD,ZERO_ONE, new IRange(0,5+indexSize/10))); types.add(new FldType(SMALL_INT_FIELD,ZERO_ONE, new IRange(0,5+indexSize/10)));
@ -567,25 +614,61 @@ public class TestGroupingSearch extends SolrTestCaseJ4 {
for (Grp grp : groups.values()) grp.setMaxDoc(sortComparator); for (Grp grp : groups.values()) grp.setMaxDoc(sortComparator);
} }
List<Grp> sortedGroups = new ArrayList(groups.values()); List<Grp> sortedGroups = new ArrayList<Grp>(groups.values());
Collections.sort(sortedGroups, groupComparator==sortComparator ? createFirstDocComparator(sortComparator) : createMaxDocComparator(sortComparator)); Collections.sort(sortedGroups, groupComparator==sortComparator ? createFirstDocComparator(sortComparator) : createMaxDocComparator(sortComparator));
boolean includeNGroups = random.nextBoolean(); boolean includeNGroups = random.nextBoolean();
Object modelResponse = buildGroupedResult(h.getCore().getSchema(), sortedGroups, start, rows, group_offset, group_limit, includeNGroups); Object modelResponse = buildGroupedResult(h.getCore().getSchema(), sortedGroups, start, rows, group_offset, group_limit, includeNGroups);
boolean truncateGroups = random.nextBoolean();
Map<String, Integer> facetCounts = new TreeMap<String, Integer>();
if (truncateGroups) {
for (Grp grp : sortedGroups) {
Doc doc = grp.docs.get(0);
if (doc.getValues(FOO_STRING_FIELD) == null) {
continue;
}
String key = doc.getFirstValue(FOO_STRING_FIELD).toString();
boolean exists = facetCounts.containsKey(key);
int count = exists ? facetCounts.get(key) : 0;
facetCounts.put(key, ++count);
}
} else {
for (Doc doc : model.values()) {
if (doc.getValues(FOO_STRING_FIELD) == null) {
continue;
}
for (Comparable field : doc.getValues(FOO_STRING_FIELD)) {
String key = field.toString();
boolean exists = facetCounts.containsKey(key);
int count = exists ? facetCounts.get(key) : 0;
facetCounts.put(key, ++count);
}
}
}
List<Comparable> expectedFacetResponse = new ArrayList<Comparable>();
for (Map.Entry<String, Integer> stringIntegerEntry : facetCounts.entrySet()) {
expectedFacetResponse.add(stringIntegerEntry.getKey());
expectedFacetResponse.add(stringIntegerEntry.getValue());
}
int randomPercentage = random.nextInt(101); int randomPercentage = random.nextInt(101);
// TODO: create a random filter too // TODO: create a random filter too
SolrQueryRequest req = req("group","true","wt","json","indent","true", "echoParams","all", "q","{!func}score_f", "group.field",groupField SolrQueryRequest req = req("group","true","wt","json","indent","true", "echoParams","all", "q","{!func}score_f", "group.field",groupField
,sortStr==null ? "nosort":"sort", sortStr ==null ? "": sortStr ,sortStr==null ? "nosort":"sort", sortStr ==null ? "": sortStr
,(groupSortStr==null || groupSortStr==sortStr) ? "noGroupsort":"group.sort", groupSortStr==null ? "": groupSortStr ,(groupSortStr == null || groupSortStr == sortStr) ? "noGroupsort":"group.sort", groupSortStr==null ? "": groupSortStr
,"rows",""+rows, "start",""+start, "group.offset",""+group_offset, "group.limit",""+group_limit, ,"rows",""+rows, "start",""+start, "group.offset",""+group_offset, "group.limit",""+group_limit,
GroupParams.GROUP_CACHE_PERCENTAGE, Integer.toString(randomPercentage), GroupParams.GROUP_TOTAL_COUNT, includeNGroups ? "true" : "false" GroupParams.GROUP_CACHE_PERCENTAGE, Integer.toString(randomPercentage), GroupParams.GROUP_TOTAL_COUNT, includeNGroups ? "true" : "false",
"facet", "true", "facet.sort", "index", "facet.limit", "-1", "facet.field", FOO_STRING_FIELD,
GroupParams.GROUP_TRUNCATE, truncateGroups ? "true" : "false", "facet.mincount", "1", "facet.method", "fcs" // to avoid FC insanity
); );
String strResponse = h.query(req); String strResponse = h.query(req);
Object realResponse = ObjectBuilder.fromJSON(strResponse); Object realResponse = ObjectBuilder.fromJSON(strResponse);
String err = JSONTestUtil.matchObj("/grouped/"+groupField, realResponse, modelResponse); String err = JSONTestUtil.matchObj("/grouped/" + groupField, realResponse, modelResponse);
if (err != null) { if (err != null) {
log.error("GROUPING MISMATCH: " + err log.error("GROUPING MISMATCH: " + err
+ "\n\trequest="+req + "\n\trequest="+req
@ -599,6 +682,20 @@ public class TestGroupingSearch extends SolrTestCaseJ4 {
fail(err); fail(err);
} }
// assert post / pre grouping facets
err = JSONTestUtil.matchObj("/facet_counts/facet_fields/"+FOO_STRING_FIELD, realResponse, expectedFacetResponse);
if (err != null) {
log.error("GROUPING MISMATCH: " + err
+ "\n\trequest="+req
+ "\n\tresult="+strResponse
+ "\n\texpected="+ JSONUtil.toJSON(expectedFacetResponse)
);
// re-execute the request... good for putting a breakpoint here for debugging
h.query(req);
fail(err);
}
} // end query iter } // end query iter
} // end index iter } // end index iter

View File

@ -48,8 +48,9 @@ public interface GroupParams {
// Note: Since you can supply multiple fields to group on, but only have a facets for the whole result. It only makes // Note: Since you can supply multiple fields to group on, but only have a facets for the whole result. It only makes
// sense to me to support these parameters for the first group. // sense to me to support these parameters for the first group.
/** Whether the docSet (for example for faceting) should be based on plain documents (a.k.a UNGROUPED) or on the groups (a.k.a GROUPED). */ /** Whether the docSet (for example for faceting) should be based on plain documents (a.k.a UNGROUPED) or on the groups (a.k.a GROUPED).
public static final String GROUP_COLLAPSE = GROUP + ".collapse"; * The docSet will only the most relevant documents per group. It is if you query for everything with group.limit=1 */
public static final String GROUP_TRUNCATE = GROUP + ".truncate";
/** Whether the group count should be included in the response. */ /** Whether the group count should be included in the response. */
public static final String GROUP_TOTAL_COUNT = GROUP + ".ngroups"; public static final String GROUP_TOTAL_COUNT = GROUP + ".ngroups";