refactor the custom boost factor query into a more general function boost query

This commit is contained in:
kimchy 2010-06-13 16:51:19 +03:00
parent ec481159d6
commit a9fc276a3e
8 changed files with 177 additions and 69 deletions

View File

@ -23,8 +23,8 @@ import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term; import org.apache.lucene.index.Term;
import org.apache.lucene.search.*; import org.apache.lucene.search.*;
import org.apache.lucene.search.spans.SpanTermQuery; import org.apache.lucene.search.spans.SpanTermQuery;
import org.elasticsearch.util.lucene.search.CustomBoostFactorQuery;
import org.elasticsearch.util.lucene.search.TermFilter; import org.elasticsearch.util.lucene.search.TermFilter;
import org.elasticsearch.util.lucene.search.function.FunctionScoreQuery;
import java.io.IOException; import java.io.IOException;
import java.lang.reflect.Field; import java.lang.reflect.Field;
@ -77,8 +77,8 @@ public class CustomFieldQuery extends FieldQuery {
flatten(((ConstantScoreQuery) sourceQuery).getFilter(), flatQueries); flatten(((ConstantScoreQuery) sourceQuery).getFilter(), flatQueries);
} else if (sourceQuery instanceof DeletionAwareConstantScoreQuery) { } else if (sourceQuery instanceof DeletionAwareConstantScoreQuery) {
flatten(((DeletionAwareConstantScoreQuery) sourceQuery).getFilter(), flatQueries); flatten(((DeletionAwareConstantScoreQuery) sourceQuery).getFilter(), flatQueries);
} else if (sourceQuery instanceof CustomBoostFactorQuery) { } else if (sourceQuery instanceof FunctionScoreQuery) {
flatten(((CustomBoostFactorQuery) sourceQuery).getSubQuery(), flatQueries); flatten(((FunctionScoreQuery) sourceQuery).getSubQuery(), flatQueries);
} else if (sourceQuery instanceof MultiTermQuery) { } else if (sourceQuery instanceof MultiTermQuery) {
MultiTermQuery multiTermQuery = (MultiTermQuery) sourceQuery; MultiTermQuery multiTermQuery = (MultiTermQuery) sourceQuery;
MultiTermQuery.RewriteMethod rewriteMethod = multiTermQuery.getRewriteMethod(); MultiTermQuery.RewriteMethod rewriteMethod = multiTermQuery.getRewriteMethod();

View File

@ -26,7 +26,8 @@ import org.elasticsearch.index.query.QueryParsingException;
import org.elasticsearch.index.settings.IndexSettings; import org.elasticsearch.index.settings.IndexSettings;
import org.elasticsearch.util.Strings; import org.elasticsearch.util.Strings;
import org.elasticsearch.util.inject.Inject; import org.elasticsearch.util.inject.Inject;
import org.elasticsearch.util.lucene.search.CustomBoostFactorQuery; import org.elasticsearch.util.lucene.search.function.BoostFactorFunctionProvider;
import org.elasticsearch.util.lucene.search.function.FunctionScoreQuery;
import org.elasticsearch.util.settings.Settings; import org.elasticsearch.util.settings.Settings;
import org.elasticsearch.util.xcontent.XContentParser; import org.elasticsearch.util.xcontent.XContentParser;
@ -74,8 +75,8 @@ public class CustomBoostFactorQueryParser extends AbstractIndexComponent impleme
if (query == null) { if (query == null) {
throw new QueryParsingException(index, "[constant_factor_query] requires 'query' element"); throw new QueryParsingException(index, "[constant_factor_query] requires 'query' element");
} }
CustomBoostFactorQuery customBoostFactorQuery = new CustomBoostFactorQuery(query, boostFactor); FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(query, new BoostFactorFunctionProvider(boostFactor));
customBoostFactorQuery.setBoost(boost); functionScoreQuery.setBoost(boost);
return customBoostFactorQuery; return functionScoreQuery;
} }
} }

View File

@ -28,8 +28,9 @@ import org.elasticsearch.search.facets.FacetsPhase;
import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.util.collect.ImmutableMap; import org.elasticsearch.util.collect.ImmutableMap;
import org.elasticsearch.util.inject.Inject; import org.elasticsearch.util.inject.Inject;
import org.elasticsearch.util.lucene.search.CustomBoostFactorQuery;
import org.elasticsearch.util.lucene.search.TermFilter; import org.elasticsearch.util.lucene.search.TermFilter;
import org.elasticsearch.util.lucene.search.function.BoostFactorFunctionProvider;
import org.elasticsearch.util.lucene.search.function.FunctionScoreQuery;
import java.util.Map; import java.util.Map;
@ -62,7 +63,7 @@ public class QueryPhase implements SearchPhase {
throw new SearchParseException(context, "No query specified in search request"); throw new SearchParseException(context, "No query specified in search request");
} }
if (context.queryBoost() != 1.0f) { if (context.queryBoost() != 1.0f) {
context.query(new CustomBoostFactorQuery(context.query(), context.queryBoost())); context.query(new FunctionScoreQuery(context.query(), new BoostFactorFunctionProvider(context.queryBoost())));
} }
facetsPhase.preProcess(context); facetsPhase.preProcess(context);
} }

View File

@ -0,0 +1,70 @@
/*
* Licensed to Elastic Search and Shay Banon under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. Elastic Search licenses this
* file to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.util.lucene.search.function;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.Explanation;
/**
* @author kimchy (shay.banon)
*/
public class BoostFactorFunctionProvider implements FunctionProvider, Function {
private final float boost;
public BoostFactorFunctionProvider(float boost) {
this.boost = boost;
}
public float getBoost() {
return boost;
}
@Override public Function function(IndexReader reader) {
return this;
}
@Override public float score(int docId, float subQueryScore) {
return subQueryScore * boost;
}
@Override public Explanation explain(int docId, Explanation subQueryExpl) {
Explanation exp = new Explanation(boost * subQueryExpl.getValue(), "static boost function: product of:");
exp.addDetail(subQueryExpl);
exp.addDetail(new Explanation(boost, "boostFactor"));
return exp;
}
@Override public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
BoostFactorFunctionProvider that = (BoostFactorFunctionProvider) o;
if (Float.compare(that.boost, boost) != 0) return false;
return true;
}
@Override public int hashCode() {
return (boost != +0.0f ? Float.floatToIntBits(boost) : 0);
}
}

View File

@ -0,0 +1,32 @@
/*
* Licensed to Elastic Search and Shay Banon under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. Elastic Search licenses this
* file to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.util.lucene.search.function;
import org.apache.lucene.search.Explanation;
/**
* @author kimchy (shay.banon)
*/
public interface Function {
float score(int docId, float subQueryScore);
Explanation explain(int docId, Explanation subQueryExpl);
}

View File

@ -0,0 +1,30 @@
/*
* Licensed to Elastic Search and Shay Banon under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. Elastic Search licenses this
* file to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.util.lucene.search.function;
import org.apache.lucene.index.IndexReader;
/**
* @author kimchy (shay.banon)
*/
public interface FunctionProvider {
Function function(IndexReader reader);
}

View File

@ -17,7 +17,7 @@
* under the License. * under the License.
*/ */
package org.elasticsearch.util.lucene.search; package org.elasticsearch.util.lucene.search.function;
import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term; import org.apache.lucene.index.Term;
@ -28,34 +28,33 @@ import java.io.IOException;
import java.util.Set; import java.util.Set;
/** /**
* A query that wraps another query and applies the provided boost values to it. Simply * A query that allows for a pluggable boost function to be applied to it.
* applied the boost factor to the score of the wrapped query.
* *
* @author kimchy (shay.banon) * @author kimchy (shay.banon)
*/ */
public class CustomBoostFactorQuery extends Query { public class FunctionScoreQuery extends Query {
private Query subQuery; private Query subQuery;
private float boostFactor; private FunctionProvider functionProvider;
public CustomBoostFactorQuery(Query subQuery, float boostFactor) { public FunctionScoreQuery(Query subQuery, FunctionProvider functionProvider) {
this.subQuery = subQuery; this.subQuery = subQuery;
this.boostFactor = boostFactor; this.functionProvider = functionProvider;
} }
public Query getSubQuery() { public Query getSubQuery() {
return subQuery; return subQuery;
} }
public float getBoostFactor() { public FunctionProvider getFunctionProvider() {
return boostFactor; return functionProvider;
} }
@Override @Override
public Query rewrite(IndexReader reader) throws IOException { public Query rewrite(IndexReader reader) throws IOException {
Query newQ = subQuery.rewrite(reader); Query newQ = subQuery.rewrite(reader);
if (newQ == subQuery) return this; if (newQ == subQuery) return this;
CustomBoostFactorQuery bq = (CustomBoostFactorQuery) this.clone(); FunctionScoreQuery bq = (FunctionScoreQuery) this.clone();
bq.subQuery = newQ; bq.subQuery = newQ;
return bq; return bq;
} }
@ -80,7 +79,7 @@ public class CustomBoostFactorQuery extends Query {
} }
public Query getQuery() { public Query getQuery() {
return CustomBoostFactorQuery.this; return FunctionScoreQuery.this;
} }
public float getValue() { public float getValue() {
@ -106,7 +105,7 @@ public class CustomBoostFactorQuery extends Query {
if (subQueryScorer == null) { if (subQueryScorer == null) {
return null; return null;
} }
return new CustomBoostFactorScorer(getSimilarity(searcher), reader, this, subQueryScorer); return new CustomBoostFactorScorer(getSimilarity(searcher), this, subQueryScorer, functionProvider.function(reader));
} }
@Override @Override
@ -116,29 +115,26 @@ public class CustomBoostFactorQuery extends Query {
return subQueryExpl; return subQueryExpl;
} }
float sc = subQueryExpl.getValue() * boostFactor; Explanation functionExplanation = functionProvider.function(reader).explain(doc, subQueryExpl);
Explanation res = new ComplexExplanation( float sc = getValue() * functionExplanation.getValue();
true, sc, CustomBoostFactorQuery.this.toString() + ", product of:"); Explanation res = new ComplexExplanation(true, sc, "custom score, product of:");
res.addDetail(subQueryExpl); res.addDetail(functionExplanation);
res.addDetail(new Explanation(boostFactor, "boostFactor")); res.addDetail(new Explanation(getValue(), "queryBoost"));
return res; return res;
} }
} }
private class CustomBoostFactorScorer extends Scorer { private class CustomBoostFactorScorer extends Scorer {
private final CustomBoostFactorWeight weight;
private final float subQueryWeight; private final float subQueryWeight;
private final Scorer scorer; private final Scorer scorer;
private final IndexReader reader; private final Function function;
private CustomBoostFactorScorer(Similarity similarity, IndexReader reader, CustomBoostFactorWeight w, private CustomBoostFactorScorer(Similarity similarity, CustomBoostFactorWeight w, Scorer scorer, Function function) throws IOException {
Scorer scorer) throws IOException {
super(similarity); super(similarity);
this.weight = w;
this.subQueryWeight = w.getValue(); this.subQueryWeight = w.getValue();
this.scorer = scorer; this.scorer = scorer;
this.reader = reader; this.function = function;
} }
@Override @Override
@ -158,52 +154,28 @@ public class CustomBoostFactorQuery extends Query {
@Override @Override
public float score() throws IOException { public float score() throws IOException {
float score = subQueryWeight * scorer.score() * boostFactor; return subQueryWeight * function.score(scorer.docID(), scorer.score());
// Current Lucene priority queues can't handle NaN and -Infinity, so
// map to -Float.MAX_VALUE. This conditional handles both -infinity
// and NaN since comparisons with NaN are always false.
return score > Float.NEGATIVE_INFINITY ? score : -Float.MAX_VALUE;
}
public Explanation explain(int doc) throws IOException {
Explanation subQueryExpl = weight.subQueryWeight.explain(reader, doc);
if (!subQueryExpl.isMatch()) {
return subQueryExpl;
}
float sc = subQueryExpl.getValue() * boostFactor;
Explanation res = new ComplexExplanation(
true, sc, CustomBoostFactorQuery.this.toString() + ", product of:");
res.addDetail(subQueryExpl);
res.addDetail(new Explanation(boostFactor, "boostFactor"));
return res;
} }
} }
public String toString(String field) { public String toString(String field) {
StringBuilder sb = new StringBuilder(); StringBuilder sb = new StringBuilder();
sb.append("CustomBoostFactor(").append(subQuery.toString(field)).append(',').append(boostFactor).append(')'); sb.append("custom score (").append(subQuery.toString(field)).append(",function=").append(functionProvider).append(')');
sb.append(ToStringUtils.boost(getBoost())); sb.append(ToStringUtils.boost(getBoost()));
return sb.toString(); return sb.toString();
} }
public boolean equals(Object o) { public boolean equals(Object o) {
if (getClass() != o.getClass()) return false; if (getClass() != o.getClass()) return false;
CustomBoostFactorQuery other = (CustomBoostFactorQuery) o; FunctionScoreQuery other = (FunctionScoreQuery) o;
return this.getBoost() == other.getBoost() return this.getBoost() == other.getBoost()
&& this.subQuery.equals(other.subQuery) && this.subQuery.equals(other.subQuery)
&& this.boostFactor == other.boostFactor; && this.functionProvider.equals(other.functionProvider);
} }
public int hashCode() { public int hashCode() {
int h = subQuery.hashCode(); return subQuery.hashCode() + 31 * functionProvider.hashCode() ^ Float.floatToIntBits(getBoost());
h ^= (h << 17) | (h >>> 16);
h += Float.floatToIntBits(boostFactor);
h ^= (h << 8) | (h >>> 25);
h += Float.floatToIntBits(getBoost());
return h;
} }
} }

View File

@ -31,6 +31,8 @@ import org.elasticsearch.index.engine.robin.RobinIndexEngine;
import org.elasticsearch.index.mapper.MapperService; import org.elasticsearch.index.mapper.MapperService;
import org.elasticsearch.index.query.IndexQueryParser; import org.elasticsearch.index.query.IndexQueryParser;
import org.elasticsearch.util.lucene.search.*; import org.elasticsearch.util.lucene.search.*;
import org.elasticsearch.util.lucene.search.function.BoostFactorFunctionProvider;
import org.elasticsearch.util.lucene.search.function.FunctionScoreQuery;
import org.testng.annotations.Test; import org.testng.annotations.Test;
import java.io.IOException; import java.io.IOException;
@ -713,10 +715,10 @@ public class SimpleIndexQueryParserTests {
@Test public void testCustomBoostFactorQueryBuilder() throws IOException { @Test public void testCustomBoostFactorQueryBuilder() throws IOException {
IndexQueryParser queryParser = newQueryParser(); IndexQueryParser queryParser = newQueryParser();
Query parsedQuery = queryParser.parse(customBoostFactorQuery(termQuery("name.last", "banon")).boostFactor(1.3f)); Query parsedQuery = queryParser.parse(customBoostFactorQuery(termQuery("name.last", "banon")).boostFactor(1.3f));
assertThat(parsedQuery, instanceOf(CustomBoostFactorQuery.class)); assertThat(parsedQuery, instanceOf(FunctionScoreQuery.class));
CustomBoostFactorQuery customBoostFactorQuery = (CustomBoostFactorQuery) parsedQuery; FunctionScoreQuery functionScoreQuery = (FunctionScoreQuery) parsedQuery;
assertThat(((TermQuery) customBoostFactorQuery.getSubQuery()).getTerm(), equalTo(new Term("name.last", "banon"))); assertThat(((TermQuery) functionScoreQuery.getSubQuery()).getTerm(), equalTo(new Term("name.last", "banon")));
assertThat((double) customBoostFactorQuery.getBoostFactor(), closeTo(1.3, 0.001)); assertThat((double) ((BoostFactorFunctionProvider) functionScoreQuery.getFunctionProvider()).getBoost(), closeTo(1.3, 0.001));
} }
@ -724,10 +726,10 @@ public class SimpleIndexQueryParserTests {
IndexQueryParser queryParser = newQueryParser(); IndexQueryParser queryParser = newQueryParser();
String query = copyToStringFromClasspath("/org/elasticsearch/index/query/xcontent/custom-boost-factor-query.json"); String query = copyToStringFromClasspath("/org/elasticsearch/index/query/xcontent/custom-boost-factor-query.json");
Query parsedQuery = queryParser.parse(query); Query parsedQuery = queryParser.parse(query);
assertThat(parsedQuery, instanceOf(CustomBoostFactorQuery.class)); assertThat(parsedQuery, instanceOf(FunctionScoreQuery.class));
CustomBoostFactorQuery customBoostFactorQuery = (CustomBoostFactorQuery) parsedQuery; FunctionScoreQuery functionScoreQuery = (FunctionScoreQuery) parsedQuery;
assertThat(((TermQuery) customBoostFactorQuery.getSubQuery()).getTerm(), equalTo(new Term("name.last", "banon"))); assertThat(((TermQuery) functionScoreQuery.getSubQuery()).getTerm(), equalTo(new Term("name.last", "banon")));
assertThat((double) customBoostFactorQuery.getBoostFactor(), closeTo(1.3, 0.001)); assertThat((double) ((BoostFactorFunctionProvider) functionScoreQuery.getFunctionProvider()).getBoost(), closeTo(1.3, 0.001));
} }
@Test public void testSpanTermQueryBuilder() throws IOException { @Test public void testSpanTermQueryBuilder() throws IOException {