SOLR-5536: Add ValueSource collapse criteria to CollapsingQParserPlugin

git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1554523 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Joel Bernstein 2013-12-31 14:35:48 +00:00
parent 7d7ef806b3
commit 43535fecb8
6 changed files with 264 additions and 12 deletions

View File

@ -134,6 +134,8 @@ New Features
* SOLR-5581: Give ZkCLI the ability to get files. (Gregory Chanan via Mark Miller) * SOLR-5581: Give ZkCLI the ability to get files. (Gregory Chanan via Mark Miller)
* SOLR-5536: Add ValueSource collapse criteria to CollapsingQParsingPlugin (Joel Bernstein)
Bug Fixes Bug Fixes
---------------------- ----------------------

View File

@ -17,7 +17,12 @@
package org.apache.solr.search; package org.apache.solr.search;
import org.apache.lucene.queries.function.FunctionQuery;
import org.apache.lucene.queries.function.FunctionValues;
import org.apache.lucene.queries.function.ValueSource;
import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.BytesRef;
import org.apache.solr.common.params.ModifiableSolrParams;
import org.apache.solr.request.LocalSolrQueryRequest;
import org.apache.solr.schema.TrieFloatField; import org.apache.solr.schema.TrieFloatField;
import org.apache.solr.schema.TrieIntField; import org.apache.solr.schema.TrieIntField;
import org.apache.solr.schema.TrieLongField; import org.apache.solr.schema.TrieLongField;
@ -47,6 +52,7 @@ import java.util.Arrays;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.HashSet; import java.util.HashSet;
import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Iterator; import java.util.Iterator;
@ -242,7 +248,7 @@ public class CollapsingQParserPlugin extends QParserPlugin {
SchemaField schemaField = schema.getField(this.field); SchemaField schemaField = schema.getField(this.field);
SortedDocValues docValues = null; SortedDocValues docValues = null;
FunctionQuery funcQuery = null;
if(schemaField.hasDocValues()) { if(schemaField.hasDocValues()) {
docValues = searcher.getAtomicReader().getSortedDocValues(this.field); docValues = searcher.getAtomicReader().getSortedDocValues(this.field);
} else { } else {
@ -252,11 +258,39 @@ public class CollapsingQParserPlugin extends QParserPlugin {
FieldType fieldType = null; FieldType fieldType = null;
if(this.max != null) { if(this.max != null) {
fieldType = searcher.getSchema().getField(this.max).getType(); if(this.max.indexOf("(") == -1) {
fieldType = searcher.getSchema().getField(this.max).getType();
} else {
LocalSolrQueryRequest request = null;
try {
SolrParams params = new ModifiableSolrParams();
request = new LocalSolrQueryRequest(searcher.getCore(), params);
FunctionQParser functionQParser = new FunctionQParser(this.max, null, null,request);
funcQuery = (FunctionQuery)functionQParser.parse();
} catch (Exception e) {
throw new IOException(e);
} finally {
request.close();
}
}
} }
if(this.min != null) { if(this.min != null) {
fieldType = searcher.getSchema().getField(this.min).getType(); if(this.min.indexOf("(") == -1) {
fieldType = searcher.getSchema().getField(this.min).getType();
} else {
LocalSolrQueryRequest request = null;
try {
SolrParams params = new ModifiableSolrParams();
request = new LocalSolrQueryRequest(searcher.getCore(), params);
FunctionQParser functionQParser = new FunctionQParser(this.min, null, null,request);
funcQuery = (FunctionQuery)functionQParser.parse();
} catch (Exception e) {
throw new IOException(e);
} finally {
request.close();
}
}
} }
int maxDoc = searcher.maxDoc(); int maxDoc = searcher.maxDoc();
@ -274,7 +308,8 @@ public class CollapsingQParserPlugin extends QParserPlugin {
max != null, max != null,
this.needsScores, this.needsScores,
fieldType, fieldType,
boostDocs); boostDocs,
funcQuery);
} else { } else {
return new CollapsingScoreCollector(maxDoc, leafCount, docValues, this.nullPolicy, boostDocs); return new CollapsingScoreCollector(maxDoc, leafCount, docValues, this.nullPolicy, boostDocs);
} }
@ -508,7 +543,8 @@ public class CollapsingQParserPlugin extends QParserPlugin {
boolean max, boolean max,
boolean needsScores, boolean needsScores,
FieldType fieldType, FieldType fieldType,
IntOpenHashSet boostDocs) throws IOException{ IntOpenHashSet boostDocs,
FunctionQuery funcQuery) throws IOException{
this.maxDoc = maxDoc; this.maxDoc = maxDoc;
this.contexts = new AtomicReaderContext[segments]; this.contexts = new AtomicReaderContext[segments];
@ -517,14 +553,18 @@ public class CollapsingQParserPlugin extends QParserPlugin {
this.nullPolicy = nullPolicy; this.nullPolicy = nullPolicy;
this.needsScores = needsScores; this.needsScores = needsScores;
this.boostDocs = boostDocs; this.boostDocs = boostDocs;
if(fieldType instanceof TrieIntField) { if(funcQuery != null) {
this.fieldValueCollapse = new IntValueCollapse(maxDoc, field, nullPolicy, new int[valueCount], max, this.needsScores, boostDocs); this.fieldValueCollapse = new ValueSourceCollapse(maxDoc, field, nullPolicy, new int[valueCount], max, this.needsScores, boostDocs, funcQuery);
} else if(fieldType instanceof TrieLongField) {
this.fieldValueCollapse = new LongValueCollapse(maxDoc, field, nullPolicy, new int[valueCount], max, this.needsScores, boostDocs);
} else if(fieldType instanceof TrieFloatField) {
this.fieldValueCollapse = new FloatValueCollapse(maxDoc, field, nullPolicy, new int[valueCount], max, this.needsScores, boostDocs);
} else { } else {
throw new IOException("min/max must be either TrieInt, TrieLong or TrieFloat."); if(fieldType instanceof TrieIntField) {
this.fieldValueCollapse = new IntValueCollapse(maxDoc, field, nullPolicy, new int[valueCount], max, this.needsScores, boostDocs);
} else if(fieldType instanceof TrieLongField) {
this.fieldValueCollapse = new LongValueCollapse(maxDoc, field, nullPolicy, new int[valueCount], max, this.needsScores, boostDocs);
} else if(fieldType instanceof TrieFloatField) {
this.fieldValueCollapse = new FloatValueCollapse(maxDoc, field, nullPolicy, new int[valueCount], max, this.needsScores, boostDocs);
} else {
throw new IOException("min/max must be either TrieInt, TrieLong or TrieFloat.");
}
} }
} }
@ -877,6 +917,97 @@ public class CollapsingQParserPlugin extends QParserPlugin {
} }
} }
private class ValueSourceCollapse extends FieldValueCollapse {
private FloatCompare comp;
private float nullVal;
private ValueSource valueSource;
private FunctionValues functionValues;
private float[] ordVals;
private Map rcontext = new HashMap();
private CollapseScore collapseScore = new CollapseScore();
private float score;
private boolean cscore;
public ValueSourceCollapse(int maxDoc,
String funcStr,
int nullPolicy,
int[] ords,
boolean max,
boolean needsScores,
IntOpenHashSet boostDocs,
FunctionQuery funcQuery) throws IOException {
super(maxDoc, null, nullPolicy, max, needsScores, boostDocs);
this.valueSource = funcQuery.getValueSource();
this.ords = ords;
this.ordVals = new float[ords.length];
Arrays.fill(ords, -1);
if(max) {
comp = new MaxFloatComp();
Arrays.fill(ordVals, -Float.MAX_VALUE );
} else {
this.nullVal = Float.MAX_VALUE;
comp = new MinFloatComp();
Arrays.fill(ordVals, Float.MAX_VALUE);
}
if(funcStr.indexOf("cscore()") != -1) {
this.cscore = true;
this.rcontext.put("CSCORE",this.collapseScore);
}
if(this.needsScores) {
this.scores = new float[ords.length];
if(nullPolicy == CollapsingPostFilter.NULL_POLICY_EXPAND) {
nullScores = new FloatArrayList();
}
}
}
public void setNextReader(AtomicReaderContext context) throws IOException {
functionValues = this.valueSource.getValues(rcontext, context);
}
public void collapse(int ord, int contextDoc, int globalDoc) throws IOException {
if(needsScores || cscore) {
this.score = scorer.score();
this.collapseScore.score = score;
}
float val = functionValues.floatVal(contextDoc);
if(ord > -1) {
if(comp.test(val, ordVals[ord])) {
ords[ord] = globalDoc;
ordVals[ord] = val;
if(needsScores) {
scores[ord] = score;
}
}
} else if (this.collapsedSet.fastGet(globalDoc)) {
//Elevated doc so do nothing
} else if(this.nullPolicy == CollapsingPostFilter.NULL_POLICY_COLLAPSE) {
if(comp.test(val, nullVal)) {
nullVal = val;
nullDoc = globalDoc;
if(needsScores) {
nullScore = score;
}
}
} else if(this.nullPolicy == CollapsingPostFilter.NULL_POLICY_EXPAND) {
this.collapsedSet.fastSet(globalDoc);
if(needsScores) {
nullScores.add(score);
}
}
}
}
public static final class CollapseScore {
public float score;
}
private interface IntCompare { private interface IntCompare {
public boolean test(int i1, int i2); public boolean test(int i1, int i2);
} }

View File

@ -41,6 +41,7 @@ import org.apache.solr.common.util.NamedList;
import org.apache.solr.request.SolrRequestInfo; import org.apache.solr.request.SolrRequestInfo;
import org.apache.solr.schema.*; import org.apache.solr.schema.*;
import org.apache.solr.search.function.CollapseScoreFunction;
import org.apache.solr.search.function.distance.*; import org.apache.solr.search.function.distance.*;
import org.apache.solr.util.plugin.NamedListInitializedPlugin; import org.apache.solr.util.plugin.NamedListInitializedPlugin;
@ -221,6 +222,12 @@ public abstract class ValueSourceParser implements NamedListInitializedPlugin {
}; };
} }
}); });
addParser("cscore", new ValueSourceParser() {
@Override
public ValueSource parse(FunctionQParser fp) throws SyntaxError {
return new CollapseScoreFunction();
}
});
addParser("sum", new ValueSourceParser() { addParser("sum", new ValueSourceParser() {
@Override @Override
public ValueSource parse(FunctionQParser fp) throws SyntaxError { public ValueSource parse(FunctionQParser fp) throws SyntaxError {

View File

@ -0,0 +1,73 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.search.function;
import org.apache.lucene.queries.function.ValueSource;
import org.apache.lucene.queries.function.FunctionValues;
import org.apache.lucene.index.AtomicReaderContext;
import org.apache.solr.search.CollapsingQParserPlugin.CollapseScore;
import java.util.Map;
import java.io.IOException;
public class CollapseScoreFunction extends ValueSource {
public String description() {
return "CollapseScoreFunction";
}
public boolean equals(Object o) {
if(o instanceof CollapseScoreFunction){
return true;
} else {
return false;
}
}
public int hashCode() {
return 1213241257;
}
public FunctionValues getValues(Map context, AtomicReaderContext readerContext) throws IOException {
return new CollapseScoreFunctionValues(context);
}
public class CollapseScoreFunctionValues extends FunctionValues {
private CollapseScore cscore;
public CollapseScoreFunctionValues(Map context) {
this.cscore = (CollapseScore) context.get("CSCORE");
}
public int intVal(int doc) {
return 0;
}
public String toString(int doc) {
return Float.toString(cscore.score);
}
public float floatVal(int doc) {
return cscore.score;
}
public double doubleVal(int doc) {
return 0.0D;
}
}
}

View File

@ -343,6 +343,11 @@ public class QueryEqualityTest extends SolrTestCaseJ4 {
public void testFuncRord() throws Exception { public void testFuncRord() throws Exception {
assertFuncEquals("rord(foo_s)","rord(foo_s )"); assertFuncEquals("rord(foo_s)","rord(foo_s )");
} }
public void testFuncCscore() throws Exception {
assertFuncEquals("cscore()", "cscore( )");
}
public void testFuncTop() throws Exception { public void testFuncTop() throws Exception {
assertFuncEquals("top(sum(3,foo_i))"); assertFuncEquals("top(sum(3,foo_i))");
} }

View File

@ -95,6 +95,40 @@ public class TestCollapseQParserPlugin extends SolrTestCaseJ4 {
"//result/doc[4]/float[@name='id'][.='6.0']" "//result/doc[4]/float[@name='id'][.='6.0']"
); );
// Test value source collapse criteria
params = new ModifiableSolrParams();
params.add("q", "*:*");
params.add("fq", "{!collapse field=group_s nullPolicy=collapse min=field(test_ti)}");
params.add("sort", "test_ti desc");
assertQ(req(params), "*[count(//doc)=3]",
"//result/doc[1]/float[@name='id'][.='4.0']",
"//result/doc[2]/float[@name='id'][.='1.0']",
"//result/doc[3]/float[@name='id'][.='5.0']"
);
// Test value source collapse criteria with cscore function
params = new ModifiableSolrParams();
params.add("q", "*:*");
params.add("fq", "{!collapse field=group_s nullPolicy=collapse min=cscore()}");
params.add("defType", "edismax");
params.add("bf", "field(test_ti)");
assertQ(req(params), "*[count(//doc)=3]",
"//result/doc[1]/float[@name='id'][.='4.0']",
"//result/doc[2]/float[@name='id'][.='1.0']",
"//result/doc[3]/float[@name='id'][.='5.0']"
);
// Test value source collapse criteria with compound cscore function
params = new ModifiableSolrParams();
params.add("q", "*:*");
params.add("fq", "{!collapse field=group_s nullPolicy=collapse min=sum(cscore(),field(test_ti))}");
params.add("defType", "edismax");
params.add("bf", "field(test_ti)");
assertQ(req(params), "*[count(//doc)=3]",
"//result/doc[1]/float[@name='id'][.='4.0']",
"//result/doc[2]/float[@name='id'][.='1.0']",
"//result/doc[3]/float[@name='id'][.='5.0']"
);
//Test collapse by score with elevation //Test collapse by score with elevation