add boost_mode to rest interface

allow user to set combine functions explicitely via boost_mode variable.
This commit is contained in:
Britta Weber 2013-08-19 16:07:38 +02:00
parent b007af1f46
commit 634f1036a0
3 changed files with 85 additions and 0 deletions

View File

@ -44,6 +44,8 @@ public class FunctionScoreQueryBuilder extends BaseQueryBuilder implements Boost
private Float maxBoost;
private String scoreMode;
private String boostMode;
private ArrayList<FilterBuilder> filters = new ArrayList<FilterBuilder>();
private ArrayList<ScoreFunctionBuilder> scoreFunctions = new ArrayList<ScoreFunctionBuilder>();
@ -74,6 +76,11 @@ public class FunctionScoreQueryBuilder extends BaseQueryBuilder implements Boost
this.scoreMode = scoreMode;
return this;
}
public FunctionScoreQueryBuilder boostMode(String boostMode) {
this.boostMode = boostMode;
return this;
}
public FunctionScoreQueryBuilder maxBoost(float maxBoost) {
this.maxBoost = maxBoost;
@ -124,6 +131,9 @@ public class FunctionScoreQueryBuilder extends BaseQueryBuilder implements Boost
if (scoreMode != null) {
builder.field("score_mode", scoreMode);
}
if (boostMode != null) {
builder.field("boost_mode", boostMode);
}
if (maxBoost != null) {
builder.field("max_boost", maxBoost);
}

View File

@ -25,6 +25,7 @@ import org.elasticsearch.common.Strings;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.lucene.search.MatchAllDocsFilter;
import org.elasticsearch.common.lucene.search.XConstantScoreQuery;
import org.elasticsearch.common.lucene.search.function.CombineFunction;
import org.elasticsearch.common.lucene.search.function.FiltersFunctionScoreQuery;
import org.elasticsearch.common.lucene.search.function.FunctionScoreQuery;
import org.elasticsearch.common.lucene.search.function.ScoreFunction;
@ -67,6 +68,7 @@ public class FunctionScoreQueryParser implements QueryParser {
String currentFieldName = null;
XContentParser.Token token;
CombineFunction combineFunction = null;
while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
if (token == XContentParser.Token.FIELD_NAME) {
currentFieldName = parser.currentName();
@ -76,6 +78,8 @@ public class FunctionScoreQueryParser implements QueryParser {
query = new XConstantScoreQuery(parseContext.parseInnerFilter());
} else if ("score_mode".equals(currentFieldName) || "scoreMode".equals(currentFieldName)) {
scoreMode = parseScoreMode(parseContext, parser);
} else if ("boost_mode".equals(currentFieldName) || "boostMode".equals(currentFieldName)) {
combineFunction = parseBoostMode(parseContext, parser);
} else if ("max_boost".equals(currentFieldName) || "maxBoost".equals(currentFieldName)) {
maxBoost = parser.floatValue();
} else if ("boost".equals(currentFieldName)) {
@ -101,6 +105,9 @@ public class FunctionScoreQueryParser implements QueryParser {
// provided. In this case we create a FunctionScoreQuery.
if (filterFunctions.size() == 1 && filterFunctions.get(0).filter == null) {
FunctionScoreQuery theQuery = new FunctionScoreQuery(query, filterFunctions.get(0).function);
if (combineFunction != null) {
theQuery.setCombineFunction(combineFunction);
}
theQuery.setBoost(boost);
theQuery.setMaxBoost(maxBoost);
return theQuery;
@ -108,6 +115,9 @@ public class FunctionScoreQueryParser implements QueryParser {
} else {
FiltersFunctionScoreQuery functionScoreQuery = new FiltersFunctionScoreQuery(query, scoreMode,
filterFunctions.toArray(new FiltersFunctionScoreQuery.FilterFunction[filterFunctions.size()]), maxBoost);
if (combineFunction != null) {
functionScoreQuery.setCombineFunction(combineFunction);
}
functionScoreQuery.setBoost(boost);
return functionScoreQuery;
}
@ -166,4 +176,14 @@ public class FunctionScoreQueryParser implements QueryParser {
throw new QueryParsingException(parseContext.index(), NAME + " illegal score_mode [" + scoreMode + "]");
}
}
private CombineFunction parseBoostMode(QueryParseContext parseContext, XContentParser parser) throws IOException {
String boostMode = parser.text();
for (CombineFunction cf : CombineFunction.values()) {
if (cf.getName().equals(boostMode)) {
return cf;
}
}
throw new QueryParsingException(parseContext.index(), NAME + " illegal boost_mode [" + boostMode + "]");
}
}

View File

@ -25,6 +25,7 @@ import org.elasticsearch.action.index.IndexRequestBuilder;
import org.elasticsearch.action.search.SearchPhaseExecutionException;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.SearchType;
import org.elasticsearch.common.lucene.search.function.CombineFunction;
import org.elasticsearch.index.query.MatchAllFilterBuilder;
import org.elasticsearch.index.query.functionscore.DecayFunctionBuilder;
import org.elasticsearch.index.query.functionscore.exp.ExponentialDecayFunctionBuilder;
@ -147,6 +148,60 @@ public class DecayFunctionScoreTests extends AbstractSharedClusterTest {
assertThat(sh.getAt(0).getId(), equalTo("1"));
assertThat(sh.getAt(1).getId(), equalTo("2"));
}
@Test
public void testBoostModeSettingWorks() throws Exception {
createIndexMapped("test", "type1", "test", "string", "loc", "geo_point");
ensureYellow();
List<IndexRequestBuilder> indexBuilders = new ArrayList<IndexRequestBuilder>();
indexBuilders.add(new IndexRequestBuilder(client())
.setType("type1")
.setId("1")
.setIndex("test")
.setSource(
jsonBuilder().startObject().field("test", "value").startObject("loc").field("lat", 11).field("lon", 21).endObject()
.endObject()));
indexBuilders.add(new IndexRequestBuilder(client())
.setType("type1")
.setId("2")
.setIndex("test")
.setSource(
jsonBuilder().startObject().field("test", "value value").startObject("loc").field("lat", 11).field("lon", 20).endObject()
.endObject()));
IndexRequestBuilder[] builders = indexBuilders.toArray(new IndexRequestBuilder[indexBuilders.size()]);
indexRandom("test", false, builders);
refresh();
// Test Gauss
List<Float> lonlat = new ArrayList<Float>();
lonlat.add(new Float(20));
lonlat.add(new Float(11));
DecayFunctionBuilder fb = new GaussDecayFunctionBuilder("loc", lonlat, "1000km");
ActionFuture<SearchResponse> response = client().search(
searchRequest().searchType(SearchType.QUERY_THEN_FETCH).source(
searchSource().explain(true).query(functionScoreQuery(termQuery("test", "value")).add(fb).boostMode(CombineFunction.MULT.getName()))));
SearchResponse sr = response.actionGet();
SearchHits sh = sr.getHits();
assertThat(sh.getTotalHits(), equalTo((long) (2)));
assertThat(sh.getAt(0).getId(), equalTo("1"));
assertThat(sh.getAt(1).getId(), equalTo("2"));
// Test Exp
response = client().search(
searchRequest().searchType(SearchType.QUERY_THEN_FETCH).source(
searchSource().explain(true).query(functionScoreQuery(termQuery("test", "value")).add(fb).boostMode(CombineFunction.PLAIN.getName()))));
sr = response.actionGet();
sh = sr.getHits();
assertThat(sh.getTotalHits(), equalTo((long) (2)));
assertThat(sh.getAt(0).getId(), equalTo("2"));
assertThat(sh.getAt(1).getId(), equalTo("1"));
}
@Test(expected = SearchPhaseExecutionException.class)
public void testExceptionThrownIfScaleLE0() throws Exception {