LUCENE-6695: Added BlendedTermQuery.

git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1692848 13f79535-47bb-0310-9956-ffa450edef68
Adrien Grand 2015-07-27 09:36:31 +00:00
parent 5d16b97396
commit e1558fef1c
8 changed files with 463 additions and 74 deletions

View File

@@ -144,6 +144,9 @@ New Features
* LUCENE-6694: Add LithuanianAnalyzer and LithuanianStemmer.
(Dainius Jocas via Robert Muir)
* LUCENE-6695: Added a new BlendedTermQuery to blend statistics across several
terms. (Simon Willnauer, Adrien Grand)
API Changes
* LUCENE-6508: Simplify Lock api, there is now just

View File

@@ -17,7 +17,6 @@ package org.apache.lucene.index;
* limitations under the License.
*/
import org.apache.lucene.codecs.BlockTermState;
import org.apache.lucene.util.BytesRef;
import java.io.IOException;

View File

@@ -0,0 +1,320 @@
package org.apache.lucene.search;
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermContext;
import org.apache.lucene.index.TermState;
import org.apache.lucene.search.BooleanClause.Occur;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.InPlaceMergeSorter;
/**
* A {@link Query} that blends index statistics across multiple terms.
* This is particularly useful when several terms should produce identical
* scores, regardless of their index statistics.
* <p>For instance, imagine that you are resolving synonyms at search time:
* all terms should produce identical scores instead of the default behavior,
* which tends to give higher scores to rare terms.
* <p>Another useful use case is cross-field search: imagine that you would
* like to search for {@code john} on two fields: {@code first_name} and
* {@code last_name}. You might not want to give a higher weight to matches
* on the field where {@code john} is rarer, in which case
* {@link BlendedTermQuery} would help as well.
* @lucene.experimental
*/
public final class BlendedTermQuery extends Query {
/** A Builder for {@link BlendedTermQuery}. */
public static class Builder {
private int numTerms = 0;
private Term[] terms = new Term[0];
private float[] boosts = new float[0];
private TermContext[] contexts = new TermContext[0];
private RewriteMethod rewriteMethod = DISJUNCTION_MAX_REWRITE;
/** Sole constructor. */
public Builder() {}
/** Set the {@link RewriteMethod}. Default is to use
* {@link BlendedTermQuery#DISJUNCTION_MAX_REWRITE}.
* @see RewriteMethod */
public Builder setRewriteMethod(RewriteMethod rewriteMethod) {
this.rewriteMethod = rewriteMethod;
return this;
}
/** Add a new {@link Term} to this builder, with a default boost of {@code 1}.
* @see #add(Term, float) */
public Builder add(Term term) {
return add(term, 1f);
}
/** Add a {@link Term} with the provided boost. The higher the boost, the
* more this term will contribute to the overall score of the
* {@link BlendedTermQuery}. */
public Builder add(Term term, float boost) {
return add(term, boost, null);
}
/**
* Expert: Add a {@link Term} with the provided boost and context.
* This method is useful if you already have a {@link TermContext}
* object constructed for the given term.
*/
public Builder add(Term term, float boost, TermContext context) {
if (numTerms >= BooleanQuery.getMaxClauseCount()) {
throw new BooleanQuery.TooManyClauses();
}
terms = ArrayUtil.grow(terms, numTerms + 1);
boosts = ArrayUtil.grow(boosts, numTerms + 1);
contexts = ArrayUtil.grow(contexts, numTerms + 1);
terms[numTerms] = new Term(term.field(), BytesRef.deepCopyOf(term.bytes()));
boosts[numTerms] = boost;
contexts[numTerms] = context;
numTerms += 1;
return this;
}
/** Build the {@link BlendedTermQuery}. */
public BlendedTermQuery build() {
return new BlendedTermQuery(
Arrays.copyOf(terms, numTerms),
Arrays.copyOf(boosts, numTerms),
Arrays.copyOf(contexts, numTerms),
rewriteMethod);
}
}
/** A {@link RewriteMethod} defines how queries for individual terms should
* be merged.
* @lucene.experimental
* @see BlendedTermQuery#BOOLEAN_REWRITE
* @see BlendedTermQuery.DisjunctionMaxRewrite */
public static abstract class RewriteMethod {
/** Sole constructor */
protected RewriteMethod() {}
/** Merge the provided sub queries into a single {@link Query} object. */
public abstract Query rewrite(Query[] subQueries);
}
/**
* A {@link RewriteMethod} that adds all sub queries to a {@link BooleanQuery}
* which has {@link BooleanQuery#isCoordDisabled() coords disabled}. This
* {@link RewriteMethod} is useful when matching on several fields is
* considered better than having a good match on a single field.
*/
public static final RewriteMethod BOOLEAN_REWRITE = new RewriteMethod() {
@Override
public Query rewrite(Query[] subQueries) {
BooleanQuery.Builder merged = new BooleanQuery.Builder();
merged.setDisableCoord(true);
for (Query query : subQueries) {
merged.add(query, Occur.SHOULD);
}
return merged.build();
}
};
/**
* A {@link RewriteMethod} that creates a {@link DisjunctionMaxQuery} out
* of the sub queries. This {@link RewriteMethod} is useful when having a
* good match on a single field is considered better than having average
* matches on several fields.
*/
public static class DisjunctionMaxRewrite extends RewriteMethod {
private final float tieBreakerMultiplier;
/** This {@link RewriteMethod} will create {@link DisjunctionMaxQuery}
* instances that have the provided tie breaker.
* @see DisjunctionMaxQuery */
public DisjunctionMaxRewrite(float tieBreakerMultiplier) {
this.tieBreakerMultiplier = tieBreakerMultiplier;
}
@Override
public Query rewrite(Query[] subQueries) {
return new DisjunctionMaxQuery(Arrays.asList(subQueries), tieBreakerMultiplier);
}
@Override
public boolean equals(Object obj) {
if (obj == null || getClass() != obj.getClass()) {
return false;
}
DisjunctionMaxRewrite that = (DisjunctionMaxRewrite) obj;
return tieBreakerMultiplier == that.tieBreakerMultiplier;
}
@Override
public int hashCode() {
return 31 * getClass().hashCode() + Float.floatToIntBits(tieBreakerMultiplier);
}
}
/** {@link DisjunctionMaxRewrite} instance with a tie-breaker of {@code 0.01}. */
public static final RewriteMethod DISJUNCTION_MAX_REWRITE = new DisjunctionMaxRewrite(0.01f);
private final Term[] terms;
private final float[] boosts;
private final TermContext[] contexts;
private final RewriteMethod rewriteMethod;
private BlendedTermQuery(Term[] terms, float[] boosts, TermContext[] contexts,
RewriteMethod rewriteMethod) {
assert terms.length == boosts.length;
assert terms.length == contexts.length;
this.terms = terms;
this.boosts = boosts;
this.contexts = contexts;
this.rewriteMethod = rewriteMethod;
// we sort terms so that equals/hashcode does not rely on the order
new InPlaceMergeSorter() {
@Override
protected void swap(int i, int j) {
Term tmpTerm = terms[i];
terms[i] = terms[j];
terms[j] = tmpTerm;
TermContext tmpContext = contexts[i];
contexts[i] = contexts[j];
contexts[j] = tmpContext;
float tmpBoost = boosts[i];
boosts[i] = boosts[j];
boosts[j] = tmpBoost;
}
@Override
protected int compare(int i, int j) {
return terms[i].compareTo(terms[j]);
}
}.sort(0, terms.length);
}
@Override
public boolean equals(Object obj) {
if (super.equals(obj) == false) {
return false;
}
BlendedTermQuery that = (BlendedTermQuery) obj;
return Arrays.equals(terms, that.terms)
&& Arrays.equals(contexts, that.contexts)
&& Arrays.equals(boosts, that.boosts)
&& rewriteMethod.equals(that.rewriteMethod);
}
@Override
public int hashCode() {
int h = super.hashCode();
h = 31 * h + Arrays.hashCode(terms);
h = 31 * h + Arrays.hashCode(contexts);
h = 31 * h + Arrays.hashCode(boosts);
h = 31 * h + rewriteMethod.hashCode();
return h;
}
@Override
public String toString(String field) {
StringBuilder builder = new StringBuilder("Blended(");
for (int i = 0; i < terms.length; ++i) {
if (i != 0) {
builder.append(" ");
}
TermQuery termQuery = new TermQuery(terms[i]);
termQuery.setBoost(boosts[i]);
builder.append(termQuery.toString(field));
}
builder.append(")");
return builder.toString();
}
@Override
public final Query rewrite(IndexReader reader) throws IOException {
final TermContext[] contexts = Arrays.copyOf(this.contexts, this.contexts.length);
for (int i = 0; i < contexts.length; ++i) {
if (contexts[i] == null || contexts[i].topReaderContext != reader.getContext()) {
contexts[i] = TermContext.build(reader.getContext(), terms[i]);
}
}
// Compute aggregated doc freq and total term freq
// df will be the max of all doc freqs
// ttf will be the sum of all total term freqs
int df = 0;
long ttf = 0;
for (TermContext ctx : contexts) {
df = Math.max(df, ctx.docFreq());
if (ctx.totalTermFreq() == -1L) {
ttf = -1L;
} else if (ttf != -1L) {
ttf += ctx.totalTermFreq();
}
}
for (int i = 0; i < contexts.length; ++i) {
contexts[i] = adjustFrequencies(contexts[i], df, ttf);
}
TermQuery[] termQueries = new TermQuery[terms.length];
for (int i = 0; i < terms.length; ++i) {
termQueries[i] = new TermQuery(terms[i], contexts[i]);
termQueries[i].setBoost(boosts[i]);
}
Query rewritten = rewriteMethod.rewrite(termQueries);
rewritten.setBoost(getBoost());
return rewritten;
}
private static TermContext adjustFrequencies(TermContext ctx, int artificialDf, long artificialTtf) {
List<LeafReaderContext> leaves = ctx.topReaderContext.leaves();
final int len;
if (leaves == null) {
len = 1;
} else {
len = leaves.size();
}
TermContext newCtx = new TermContext(ctx.topReaderContext);
for (int i = 0; i < len; ++i) {
TermState termState = ctx.get(i);
if (termState == null) {
continue;
}
newCtx.register(termState, i);
}
newCtx.accumulateStatistics(artificialDf, artificialTtf);
return newCtx;
}
}
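
For readers new to the API, a minimal usage sketch of the class added above. The field and term values are illustrative assumptions and the surrounding IndexSearcher is assumed to exist; none of this is part of the commit:

  // Illustrative helper (not part of the commit): blend two search-time
  // synonyms so that neither term is favored merely for being rarer in the index.
  static TopDocs searchSynonyms(IndexSearcher searcher) throws IOException {
    BlendedTermQuery query = new BlendedTermQuery.Builder()
        .add(new Term("body", "car"))                               // default boost of 1
        .add(new Term("body", "automobile"), 1f)                    // or an explicit per-term boost
        .setRewriteMethod(BlendedTermQuery.DISJUNCTION_MAX_REWRITE) // the default
        .build();
    // On rewrite, the query builds a TermContext per term, takes the maximum
    // doc freq and the sum of total term freqs across the terms, applies those
    // blended statistics to every term, and then merges the per-term queries
    // with the RewriteMethod (here a DisjunctionMaxQuery with a 0.01 tie-breaker).
    return searcher.search(query, 10);
  }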

View File

@@ -18,16 +18,13 @@ package org.apache.lucene.search;
*/
import java.io.IOException;
import java.util.List;
import java.util.Objects;
import org.apache.lucene.index.FilteredTermsEnum; // javadocs
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SingleTermsEnum; // javadocs
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermContext;
import org.apache.lucene.index.TermState;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.BooleanQuery.Builder;
@@ -194,7 +191,7 @@ public abstract class MultiTermQuery extends Query {
* @see #setRewriteMethod
*/
public static final class TopTermsBlendedFreqScoringRewrite extends
TopTermsRewrite<BooleanQuery.Builder> {
TopTermsRewrite<BlendedTermQuery.Builder> {
/**
* Create a TopTermsBlendedScoringBooleanQueryRewrite for at most
@@ -213,77 +210,22 @@ public abstract class MultiTermQuery extends Query {
}
@Override
protected BooleanQuery.Builder getTopLevelBuilder() {
BooleanQuery.Builder builder = new BooleanQuery.Builder();
builder.setDisableCoord(true);
protected BlendedTermQuery.Builder getTopLevelBuilder() {
BlendedTermQuery.Builder builder = new BlendedTermQuery.Builder();
builder.setRewriteMethod(BlendedTermQuery.BOOLEAN_REWRITE);
return builder;
}
@Override
protected Query build(Builder builder) {
protected Query build(BlendedTermQuery.Builder builder) {
return builder.build();
}
@Override
protected void addClause(BooleanQuery.Builder topLevel, Term term, int docCount,
protected void addClause(BlendedTermQuery.Builder topLevel, Term term, int docCount,
float boost, TermContext states) {
final TermQuery tq = new TermQuery(term, states);
tq.setBoost(boost);
topLevel.add(tq, BooleanClause.Occur.SHOULD);
topLevel.add(term, boost, states);
}
@Override
void adjustScoreTerms(IndexReader reader, ScoreTerm[] scoreTerms) {
if (scoreTerms.length <= 1) {
return;
}
int maxDf = 0;
long maxTtf = 0;
for (ScoreTerm scoreTerm : scoreTerms) {
TermContext ctx = scoreTerm.termState;
int df = ctx.docFreq();
maxDf = Math.max(df, maxDf);
long ttf = ctx.totalTermFreq();
maxTtf = ttf == -1 || maxTtf == -1 ? -1 : Math.max(ttf, maxTtf);
}
assert maxDf >= 0 : "DF must be >= 0";
if (maxDf == 0) {
return; // we are done that term doesn't exist at all
}
assert (maxTtf == -1) || (maxTtf >= maxDf);
for (int i = 0; i < scoreTerms.length; i++) {
TermContext ctx = scoreTerms[i].termState;
ctx = adjustFrequencies(ctx, maxDf, maxTtf);
ScoreTerm adjustedScoreTerm = new ScoreTerm(ctx);
adjustedScoreTerm.boost = scoreTerms[i].boost;
adjustedScoreTerm.bytes.copyBytes(scoreTerms[i].bytes);
scoreTerms[i] = adjustedScoreTerm;
}
}
}
private static TermContext adjustFrequencies(TermContext ctx, int artificialDf,
long artificialTtf) {
List<LeafReaderContext> leaves = ctx.topReaderContext.leaves();
final int len;
if (leaves == null) {
len = 1;
} else {
len = leaves.size();
}
TermContext newCtx = new TermContext(ctx.topReaderContext);
for (int i = 0; i < len; ++i) {
TermState termState = ctx.get(i);
if (termState == null) {
continue;
}
newCtx.register(termState, i);
}
newCtx.accumulateStatistics(artificialDf, artificialTtf);
return newCtx;
}
/**

View File

@@ -157,8 +157,6 @@ public abstract class TopTermsRewrite<B> extends TermCollectingRewrite<B> {
final B b = getTopLevelBuilder();
final ScoreTerm[] scoreTerms = stQueue.toArray(new ScoreTerm[stQueue.size()]);
ArrayUtil.timSort(scoreTerms, scoreTermSortByTermComp);
adjustScoreTerms(reader, scoreTerms);
for (final ScoreTerm st : scoreTerms) {
final Term term = new Term(query.field, st.bytes.toBytesRef());
@@ -167,10 +165,6 @@ public abstract class TopTermsRewrite<B> extends TermCollectingRewrite<B> {
return build(b);
}
void adjustScoreTerms(IndexReader reader, ScoreTerm[] scoreTerms) {
//no-op but allows subclasses the ability to tweak the score terms used in ranking e.g. balancing IDF.
}
@Override
public int hashCode() {
return 31 * size;

View File

@@ -17,6 +17,7 @@ package org.apache.lucene.util;
* limitations under the License.
*/
import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
@@ -236,6 +237,14 @@ public final class ArrayUtil {
return currentSize;
}
public static <T> T[] grow(T[] array, int minSize) {
assert minSize >= 0: "size must be positive (got " + minSize + "): likely integer overflow?";
if (array.length < minSize) {
return Arrays.copyOf(array, oversize(minSize, RamUsageEstimator.NUM_BYTES_OBJECT_REF));
} else
return array;
}
public static short[] grow(short[] array, int minSize) {
assert minSize >= 0: "size must be positive (got " + minSize + "): likely integer overflow?";
if (array.length < minSize) {

View File

@@ -0,0 +1,122 @@
package org.apache.lucene.search;
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import java.io.IOException;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field.Store;
import org.apache.lucene.document.StringField;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.RandomIndexWriter;
import org.apache.lucene.index.Term;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.LuceneTestCase;
public class TestBlendedTermQuery extends LuceneTestCase {
public void testEquals() {
Term t1 = new Term("foo", "bar");
BlendedTermQuery bt1 = new BlendedTermQuery.Builder()
.add(t1)
.build();
BlendedTermQuery bt2 = new BlendedTermQuery.Builder()
.add(t1)
.build();
QueryUtils.checkEqual(bt1, bt2);
bt1 = new BlendedTermQuery.Builder()
.setRewriteMethod(BlendedTermQuery.BOOLEAN_REWRITE)
.add(t1)
.build();
bt2 = new BlendedTermQuery.Builder()
.setRewriteMethod(BlendedTermQuery.DISJUNCTION_MAX_REWRITE)
.add(t1)
.build();
QueryUtils.checkUnequal(bt1, bt2);
Term t2 = new Term("foo", "baz");
bt1 = new BlendedTermQuery.Builder()
.add(t1)
.add(t2)
.build();
bt2 = new BlendedTermQuery.Builder()
.add(t2)
.add(t1)
.build();
QueryUtils.checkEqual(bt1, bt2);
float boost1 = random().nextFloat();
float boost2 = random().nextFloat();
bt1 = new BlendedTermQuery.Builder()
.add(t1, boost1)
.add(t2, boost2)
.build();
bt2 = new BlendedTermQuery.Builder()
.add(t2, boost2)
.add(t1, boost1)
.build();
QueryUtils.checkEqual(bt1, bt2);
}
public void testToString() {
assertEquals("Blended()", new BlendedTermQuery.Builder().build().toString());
Term t1 = new Term("foo", "bar");
assertEquals("Blended(foo:bar)", new BlendedTermQuery.Builder().add(t1).build().toString());
Term t2 = new Term("foo", "baz");
assertEquals("Blended(foo:bar foo:baz)", new BlendedTermQuery.Builder().add(t1).add(t2).build().toString());
assertEquals("Blended(foo:bar^4.0 foo:baz^3.0)", new BlendedTermQuery.Builder().add(t1, 4).add(t2, 3).build().toString());
}
public void testBlendedScores() throws IOException {
Directory dir = newDirectory();
RandomIndexWriter w = new RandomIndexWriter(random(), dir);
Document doc = new Document();
doc.add(new StringField("f", "a", Store.NO));
w.addDocument(doc);
doc = new Document();
doc.add(new StringField("f", "b", Store.NO));
for (int i = 0; i < 10; ++i) {
w.addDocument(doc);
}
IndexReader reader = w.getReader();
IndexSearcher searcher = newSearcher(reader);
BlendedTermQuery query = new BlendedTermQuery.Builder()
.setRewriteMethod(new BlendedTermQuery.DisjunctionMaxRewrite(0f))
.add(new Term("f", "a"))
.add(new Term("f", "b"))
.build();
TopDocs topDocs = searcher.search(query, 20);
assertEquals(11, topDocs.totalHits);
// All docs must have the same score
for (int i = 0; i < topDocs.scoreDocs.length; ++i) {
assertEquals(topDocs.scoreDocs[0].score, topDocs.scoreDocs[i].score, 0.0f);
}
reader.close();
w.close();
dir.close();
}
}
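
The test above covers the single-field synonym case; the cross-field case from the class javadoc is built the same way. A hedged sketch, in which the field names, the 0.1 tie-breaker, and the surrounding index and searcher are assumptions for illustration:

  // Illustrative only: blend "john" across first_name and last_name so a hit
  // is not rewarded merely because the term is rarer in one of the two fields.
  BlendedTermQuery crossField = new BlendedTermQuery.Builder()
      .add(new Term("first_name", "john"))
      .add(new Term("last_name", "john"))
      .setRewriteMethod(new BlendedTermQuery.DisjunctionMaxRewrite(0.1f))
      .build();
  TopDocs hits = searcher.search(crossField, 10); // 'searcher' spans both fields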

View File

@@ -19,7 +19,6 @@ package org.apache.lucene.queryparser.complexPhrase;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
@@ -30,6 +29,7 @@ import org.apache.lucene.queryparser.classic.ParseException;
import org.apache.lucene.queryparser.classic.QueryParser;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MultiTermQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
@@ -273,7 +273,7 @@ public class ComplexPhraseQueryParser extends QueryParser {
// HashSet bclauseterms=new HashSet();
Query qc = clause.getQuery();
// Rewrite this clause e.g one* becomes (one OR onerous)
qc = qc.rewrite(reader);
qc = new IndexSearcher(reader).rewrite(qc);
if (clause.getOccur().equals(BooleanClause.Occur.MUST_NOT)) {
numNegatives++;
}
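
The last hunk switches ComplexPhraseQueryParser from qc.rewrite(reader) to new IndexSearcher(reader).rewrite(qc). A single Query.rewrite call can now return a query that itself still needs rewriting (the blended top-terms rewrite above produces a BlendedTermQuery), whereas IndexSearcher.rewrite repeats the call until the query stops changing, roughly as in this simplified sketch:

  // Simplified sketch of what IndexSearcher#rewrite does: rewrite repeatedly
  // until a fixed point is reached, so multi-step rewrites (MultiTermQuery ->
  // BlendedTermQuery -> BooleanQuery/DisjunctionMaxQuery) are fully resolved.
  Query rewritten = query;
  for (Query next = rewritten.rewrite(reader); next != rewritten; next = rewritten.rewrite(reader)) {
    rewritten = next;
  }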