LUCENE-7623: Add FunctionMatchQuery and FunctionScoreQuery

This commit is contained in:
Alan Woodward 2017-01-15 18:37:41 +00:00
parent ceaeb42a1f
commit fc2e0fd133
7 changed files with 611 additions and 16 deletions

View File

@ -63,6 +63,11 @@ Other
======================= Lucene 6.5.0 =======================
New Features
* LUCENE-7623: Add FunctionScoreQuery and FunctionMatchQuery (Alan Woodward,
Adrien Grand, David Smiley)
Bug Fixes
* LUCENE-7630: Fix (Edge)NGramTokenFilter to no longer drop payloads

View File

@ -20,7 +20,9 @@ package org.apache.lucene.search;
import java.io.IOException;
import java.util.Objects;
import java.util.function.DoubleToLongFunction;
import java.util.function.DoubleUnaryOperator;
import java.util.function.LongToDoubleFunction;
import java.util.function.ToDoubleBiFunction;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.LeafReaderContext;
@ -173,6 +175,69 @@ public abstract class DoubleValuesSource {
public boolean needsScores() {
return false;
}
@Override
public String toString() {
return "constant(" + value + ")";
}
};
}
/**
* Creates a DoubleValuesSource that is a function of another DoubleValuesSource
*/
public static DoubleValuesSource function(DoubleValuesSource in, DoubleUnaryOperator function) {
return new DoubleValuesSource() {
@Override
public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws IOException {
DoubleValues inputs = in.getValues(ctx, scores);
return new DoubleValues() {
@Override
public double doubleValue() throws IOException {
return function.applyAsDouble(inputs.doubleValue());
}
@Override
public boolean advanceExact(int doc) throws IOException {
return inputs.advanceExact(doc);
}
};
}
@Override
public boolean needsScores() {
return in.needsScores();
}
};
}
/**
* Creates a DoubleValuesSource that is a function of another DoubleValuesSource and a score
* @param in the DoubleValuesSource to use as an input
* @param function a function of the form (source, score) == result
*/
public static DoubleValuesSource scoringFunction(DoubleValuesSource in, ToDoubleBiFunction<Double, Double> function) {
return new DoubleValuesSource() {
@Override
public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws IOException {
DoubleValues inputs = in.getValues(ctx, scores);
return new DoubleValues() {
@Override
public double doubleValue() throws IOException {
return function.applyAsDouble(inputs.doubleValue(), scores.doubleValue());
}
@Override
public boolean advanceExact(int doc) throws IOException {
return inputs.advanceExact(doc);
}
};
}
@Override
public boolean needsScores() {
return true;
}
};
}
@ -221,7 +286,17 @@ public abstract class DoubleValuesSource {
@Override
public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws IOException {
final NumericDocValues values = DocValues.getNumeric(ctx.reader(), field);
return toDoubleValues(values, decoder::applyAsDouble);
return new DoubleValues() {
@Override
public double doubleValue() throws IOException {
return decoder.applyAsDouble(values.longValue());
}
@Override
public boolean advanceExact(int target) throws IOException {
return values.advanceExact(target);
}
};
}
@Override
@ -288,21 +363,6 @@ public abstract class DoubleValuesSource {
}
}
private static DoubleValues toDoubleValues(NumericDocValues in, LongToDoubleFunction map) {
return new DoubleValues() {
@Override
public double doubleValue() throws IOException {
return map.applyAsDouble(in.longValue());
}
@Override
public boolean advanceExact(int target) throws IOException {
return in.advanceExact(target);
}
};
}
private static NumericDocValues asNumericDocValues(DoubleValuesHolder in, DoubleToLongFunction converter) {
return new NumericDocValues() {
@Override

View File

@ -0,0 +1,99 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.queries.function;
import java.io.IOException;
import java.util.Objects;
import java.util.function.DoublePredicate;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.ConstantScoreScorer;
import org.apache.lucene.search.ConstantScoreWeight;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.DoubleValues;
import org.apache.lucene.search.DoubleValuesSource;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TwoPhaseIterator;
import org.apache.lucene.search.Weight;
/**
* A query that retrieves all documents with a {@link DoubleValues} value matching a predicate
*
* This query works by a linear scan of the index, and is best used in
* conjunction with other queries that can restrict the number of
* documents visited
*/
public final class FunctionMatchQuery extends Query {
private final DoubleValuesSource source;
private final DoublePredicate filter;
/**
* Create a FunctionMatchQuery
* @param source a {@link DoubleValuesSource} to use for values
* @param filter the predicate to match against
*/
public FunctionMatchQuery(DoubleValuesSource source, DoublePredicate filter) {
this.source = source;
this.filter = filter;
}
@Override
public String toString(String field) {
return "FunctionMatchQuery(" + source.toString() + ")";
}
@Override
public Weight createWeight(IndexSearcher searcher, boolean needsScores, float boost) throws IOException {
return new ConstantScoreWeight(this, boost) {
@Override
public Scorer scorer(LeafReaderContext context) throws IOException {
DoubleValues values = source.getValues(context, null);
DocIdSetIterator approximation = DocIdSetIterator.all(context.reader().maxDoc());
TwoPhaseIterator twoPhase = new TwoPhaseIterator(approximation) {
@Override
public boolean matches() throws IOException {
return values.advanceExact(approximation.docID()) && filter.test(values.doubleValue());
}
@Override
public float matchCost() {
return 100; // TODO maybe DoubleValuesSource should have a matchCost?
}
};
return new ConstantScoreScorer(this, score(), twoPhase);
}
};
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
FunctionMatchQuery that = (FunctionMatchQuery) o;
return Objects.equals(source, that.source) && Objects.equals(filter, that.filter);
}
@Override
public int hashCode() {
return Objects.hash(source, filter);
}
}

View File

@ -0,0 +1,151 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.queries.function;
import java.io.IOException;
import java.util.Objects;
import java.util.Set;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.DoubleValues;
import org.apache.lucene.search.DoubleValuesSource;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.FilterScorer;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
/**
* A query that wraps another query, and uses a DoubleValuesSource to
* replace or modify the wrapped query's score
*
* If the DoubleValuesSource doesn't return a value for a particular document,
* then that document will be given a score of 0.
*/
public final class FunctionScoreQuery extends Query {
private final Query in;
private final DoubleValuesSource source;
/**
* Create a new FunctionScoreQuery
* @param in the query to wrap
* @param source a source of scores
*/
public FunctionScoreQuery(Query in, DoubleValuesSource source) {
this.in = in;
this.source = source;
}
@Override
public Weight createWeight(IndexSearcher searcher, boolean needsScores, float boost) throws IOException {
Weight inner = in.createWeight(searcher, needsScores && source.needsScores(), 1f);
if (needsScores == false)
return inner;
return new FunctionScoreWeight(this, inner, source, boost);
}
@Override
public Query rewrite(IndexReader reader) throws IOException {
Query rewritten = in.rewrite(reader);
if (rewritten == in)
return this;
return new FunctionScoreQuery(rewritten, source);
}
@Override
public String toString(String field) {
return "FunctionScoreQuery(" + in.toString(field) + ", scored by " + source.toString() + ")";
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
FunctionScoreQuery that = (FunctionScoreQuery) o;
return Objects.equals(in, that.in) &&
Objects.equals(source, that.source);
}
@Override
public int hashCode() {
return Objects.hash(in, source);
}
private static class FunctionScoreWeight extends Weight {
final Weight inner;
final DoubleValuesSource valueSource;
final float boost;
FunctionScoreWeight(Query query, Weight inner, DoubleValuesSource valueSource, float boost) {
super(query);
this.inner = inner;
this.valueSource = valueSource;
this.boost = boost;
}
@Override
public void extractTerms(Set<Term> terms) {
this.inner.extractTerms(terms);
}
@Override
public Explanation explain(LeafReaderContext context, int doc) throws IOException {
Scorer scorer = inner.scorer(context);
if (scorer.iterator().advance(doc) != doc)
return Explanation.noMatch("No match");
DoubleValues scores = valueSource.getValues(context, DoubleValuesSource.fromScorer(scorer));
scores.advanceExact(doc);
Explanation scoreExpl = scoreExplanation(context, doc, scores);
if (boost == 1f)
return scoreExpl;
return Explanation.match(scoreExpl.getValue() * boost, "product of:",
Explanation.match(boost, "boost"), scoreExpl);
}
private Explanation scoreExplanation(LeafReaderContext context, int doc, DoubleValues scores) throws IOException {
if (valueSource.needsScores() == false)
return Explanation.match((float) scores.doubleValue(), valueSource.toString());
float score = (float) scores.doubleValue();
return Explanation.match(score, "computed from:",
Explanation.match(score, valueSource.toString()),
inner.explain(context, doc));
}
@Override
public Scorer scorer(LeafReaderContext context) throws IOException {
Scorer in = inner.scorer(context);
if (in == null)
return null;
DoubleValues scores = valueSource.getValues(context, DoubleValuesSource.fromScorer(in));
return new FilterScorer(in) {
@Override
public float score() throws IOException {
if (scores.advanceExact(docID()))
return (float) (scores.doubleValue() * boost);
else
return 0;
}
};
}
}
}

View File

@ -0,0 +1,61 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.queries.function;
import java.io.IOException;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.DoubleValuesSource;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.QueryUtils;
import org.apache.lucene.search.TopDocs;
import org.junit.AfterClass;
import org.junit.BeforeClass;
public class TestFunctionMatchQuery extends FunctionTestSetup {
static IndexReader reader;
static IndexSearcher searcher;
@BeforeClass
public static void beforeClass() throws Exception {
createIndex(true);
reader = DirectoryReader.open(dir);
searcher = new IndexSearcher(reader);
}
@AfterClass
public static void afterClass() throws Exception {
reader.close();
}
public void testRangeMatching() throws IOException {
DoubleValuesSource in = DoubleValuesSource.fromFloatField(FLOAT_FIELD);
FunctionMatchQuery fmq = new FunctionMatchQuery(in, d -> d >= 2 && d < 4);
TopDocs docs = searcher.search(fmq, 10);
assertEquals(2, docs.totalHits);
assertEquals(9, docs.scoreDocs[0].doc);
assertEquals(13, docs.scoreDocs[1].doc);
QueryUtils.check(random(), fmq, searcher, rarely());
}
}

View File

@ -0,0 +1,105 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.queries.function;
import java.io.IOException;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.BaseExplanationTestCase;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.BoostQuery;
import org.apache.lucene.search.DoubleValuesSource;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.similarities.BM25Similarity;
import org.apache.lucene.search.similarities.ClassicSimilarity;
public class TestFunctionScoreExplanations extends BaseExplanationTestCase {
public void testOneTerm() throws Exception {
Query q = new TermQuery(new Term(FIELD, "w1"));
FunctionScoreQuery fsq = new FunctionScoreQuery(q, DoubleValuesSource.constant(5));
qtest(fsq, new int[] { 0,1,2,3 });
}
public void testBoost() throws Exception {
Query q = new TermQuery(new Term(FIELD, "w1"));
FunctionScoreQuery csq = new FunctionScoreQuery(q, DoubleValuesSource.constant(5));
qtest(new BoostQuery(csq, 4), new int[] { 0,1,2,3 });
}
public void testTopLevelBoost() throws Exception {
Query q = new TermQuery(new Term(FIELD, "w1"));
FunctionScoreQuery csq = new FunctionScoreQuery(q, DoubleValuesSource.constant(5));
BooleanQuery.Builder bqB = new BooleanQuery.Builder();
bqB.add(new MatchAllDocsQuery(), BooleanClause.Occur.MUST);
bqB.add(csq, BooleanClause.Occur.MUST);
BooleanQuery bq = bqB.build();
qtest(new BoostQuery(bq, 6), new int[] { 0,1,2,3 });
}
public void testExplanationsIncludingScore() throws Exception {
DoubleValuesSource scores = DoubleValuesSource.function(DoubleValuesSource.SCORES, v -> v * 2);
Query q = new TermQuery(new Term(FIELD, "w1"));
FunctionScoreQuery csq = new FunctionScoreQuery(q, scores);
qtest(csq, new int[] { 0, 1, 2, 3 });
Explanation e1 = searcher.explain(q, 0);
Explanation e = searcher.explain(csq, 0);
assertEquals(e.getDetails().length, 2);
assertEquals(e1.getValue() * 2, e.getValue(), 0.00001);
}
public void testSubExplanations() throws IOException {
Query query = new FunctionScoreQuery(new MatchAllDocsQuery(), DoubleValuesSource.constant(5));
IndexSearcher searcher = newSearcher(BaseExplanationTestCase.searcher.getIndexReader());
searcher.setSimilarity(new BM25Similarity());
Explanation expl = searcher.explain(query, 0);
assertEquals("constant(5.0)", expl.getDescription());
assertEquals(0, expl.getDetails().length);
query = new BoostQuery(query, 2);
expl = searcher.explain(query, 0);
assertEquals(2, expl.getDetails().length);
// function
assertEquals(5f, expl.getDetails()[1].getValue(), 0f);
// boost
assertEquals("boost", expl.getDetails()[0].getDescription());
assertEquals(2f, expl.getDetails()[0].getValue(), 0f);
searcher.setSimilarity(new ClassicSimilarity()); // in order to have a queryNorm != 1
expl = searcher.explain(query, 0);
assertEquals(2, expl.getDetails().length);
// function
assertEquals(5f, expl.getDetails()[1].getValue(), 0f);
// boost
assertEquals("boost", expl.getDetails()[0].getDescription());
assertEquals(2f, expl.getDetails()[0].getValue(), 0f);
}
}

View File

@ -0,0 +1,114 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.queries.function;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.BoostQuery;
import org.apache.lucene.search.DoubleValuesSource;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryUtils;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TopDocs;
import org.junit.AfterClass;
import org.junit.BeforeClass;
public class TestFunctionScoreQuery extends FunctionTestSetup {
static IndexReader reader;
static IndexSearcher searcher;
@BeforeClass
public static void beforeClass() throws Exception {
createIndex(true);
reader = DirectoryReader.open(dir);
searcher = new IndexSearcher(reader);
}
@AfterClass
public static void afterClass() throws Exception {
reader.close();
}
// FunctionQuery equivalent
public void testSimpleSourceScore() throws Exception {
FunctionScoreQuery q = new FunctionScoreQuery(new TermQuery(new Term(TEXT_FIELD, "first")),
DoubleValuesSource.fromIntField(INT_FIELD));
QueryUtils.check(random(), q, searcher, rarely());
int expectedDocs[] = new int[]{ 4, 7, 9 };
TopDocs docs = searcher.search(q, 4);
assertEquals(expectedDocs.length, docs.totalHits);
for (int i = 0; i < expectedDocs.length; i++) {
assertEquals(docs.scoreDocs[i].doc, expectedDocs[i]);
}
}
// CustomScoreQuery and BoostedQuery equivalent
public void testScoreModifyingSource() throws Exception {
DoubleValuesSource iii = DoubleValuesSource.fromIntField("iii");
DoubleValuesSource score = DoubleValuesSource.scoringFunction(iii, (v, s) -> v * s);
BooleanQuery bq = new BooleanQuery.Builder()
.add(new TermQuery(new Term(TEXT_FIELD, "first")), BooleanClause.Occur.SHOULD)
.add(new TermQuery(new Term(TEXT_FIELD, "text")), BooleanClause.Occur.SHOULD)
.build();
TopDocs plain = searcher.search(bq, 1);
FunctionScoreQuery fq = new FunctionScoreQuery(bq, score);
QueryUtils.check(random(), fq, searcher, rarely());
int[] expectedDocs = new int[]{ 4, 7, 9, 8, 12 };
TopDocs docs = searcher.search(fq, 5);
assertEquals(plain.totalHits, docs.totalHits);
for (int i = 0; i < expectedDocs.length; i++) {
assertEquals(expectedDocs[i], docs.scoreDocs[i].doc);
}
}
// check boosts with non-distributive score source
public void testBoostsAreAppliedLast() throws Exception {
DoubleValuesSource scores
= DoubleValuesSource.function(DoubleValuesSource.SCORES, v -> Math.log(v + 4));
Query q1 = new FunctionScoreQuery(new TermQuery(new Term(TEXT_FIELD, "text")), scores);
TopDocs plain = searcher.search(q1, 5);
Query boosted = new BoostQuery(q1, 2);
TopDocs afterboost = searcher.search(boosted, 5);
assertEquals(plain.totalHits, afterboost.totalHits);
for (int i = 0; i < 5; i++) {
assertEquals(plain.scoreDocs[i].doc, afterboost.scoreDocs[i].doc);
assertEquals(plain.scoreDocs[i].score, afterboost.scoreDocs[i].score / 2, 0.0001);
}
}
}