From 4a886dbae1f3f78e1ac1224334e8a33ac11c5edb Mon Sep 17 00:00:00 2001 From: Shay Banon Date: Thu, 4 Aug 2011 02:50:58 +0300 Subject: [PATCH] Query DSL: custom_filters_score allow to associate boost on filter instead of script, closes #1204. --- .../query/CustomFiltersScoreQueryBuilder.java | 17 +++++++++++++- .../query/CustomFiltersScoreQueryParser.java | 23 +++++++++++++++---- .../customscore/CustomScoreSearchTests.java | 20 ++++++++++++++++ 3 files changed, 55 insertions(+), 5 deletions(-) diff --git a/modules/elasticsearch/src/main/java/org/elasticsearch/index/query/CustomFiltersScoreQueryBuilder.java b/modules/elasticsearch/src/main/java/org/elasticsearch/index/query/CustomFiltersScoreQueryBuilder.java index eba7d2204bb..798e0e59a06 100644 --- a/modules/elasticsearch/src/main/java/org/elasticsearch/index/query/CustomFiltersScoreQueryBuilder.java +++ b/modules/elasticsearch/src/main/java/org/elasticsearch/index/query/CustomFiltersScoreQueryBuilder.java @@ -20,6 +20,7 @@ package org.elasticsearch.index.query; import org.elasticsearch.common.collect.Maps; +import org.elasticsearch.common.trove.list.array.TFloatArrayList; import org.elasticsearch.common.xcontent.XContentBuilder; import java.io.IOException; @@ -43,6 +44,7 @@ public class CustomFiltersScoreQueryBuilder extends BaseQueryBuilder { private ArrayList filters = new ArrayList(); private ArrayList scripts = new ArrayList(); + private TFloatArrayList boosts = new TFloatArrayList(); public CustomFiltersScoreQueryBuilder(QueryBuilder queryBuilder) { this.queryBuilder = queryBuilder; @@ -51,6 +53,14 @@ public class CustomFiltersScoreQueryBuilder extends BaseQueryBuilder { public CustomFiltersScoreQueryBuilder add(FilterBuilder filter, String script) { this.filters.add(filter); this.scripts.add(script); + this.boosts.add(-1); + return this; + } + + public CustomFiltersScoreQueryBuilder add(FilterBuilder filter, float boost) { + this.filters.add(filter); + this.scripts.add(null); + this.boosts.add(boost); return this; } @@ -104,7 +114,12 @@ public class CustomFiltersScoreQueryBuilder extends BaseQueryBuilder { builder.startObject(); builder.field("filter"); filters.get(i).toXContent(builder, params); - builder.field("script", scripts.get(i)); + String script = scripts.get(i); + if (script != null) { + builder.field("script", script); + } else { + builder.field("boost", boosts.get(i)); + } builder.endObject(); } builder.endArray(); diff --git a/modules/elasticsearch/src/main/java/org/elasticsearch/index/query/CustomFiltersScoreQueryParser.java b/modules/elasticsearch/src/main/java/org/elasticsearch/index/query/CustomFiltersScoreQueryParser.java index 905c771eb61..c4d59e1dbbe 100644 --- a/modules/elasticsearch/src/main/java/org/elasticsearch/index/query/CustomFiltersScoreQueryParser.java +++ b/modules/elasticsearch/src/main/java/org/elasticsearch/index/query/CustomFiltersScoreQueryParser.java @@ -24,7 +24,10 @@ import org.apache.lucene.search.Query; import org.elasticsearch.ElasticSearchIllegalStateException; import org.elasticsearch.common.Strings; import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.common.lucene.search.function.BoostScoreFunction; import org.elasticsearch.common.lucene.search.function.FiltersFunctionScoreQuery; +import org.elasticsearch.common.lucene.search.function.ScoreFunction; +import org.elasticsearch.common.trove.list.array.TFloatArrayList; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.script.SearchScript; import org.elasticsearch.search.internal.SearchContext; @@ -57,6 +60,7 @@ public class CustomFiltersScoreQueryParser implements QueryParser { ArrayList filters = new ArrayList(); ArrayList scripts = new ArrayList(); + TFloatArrayList boosts = new TFloatArrayList(); String currentFieldName = null; XContentParser.Token token; @@ -74,6 +78,7 @@ public class CustomFiltersScoreQueryParser implements QueryParser { while ((token = parser.nextToken()) != XContentParser.Token.END_ARRAY) { String script = null; Filter filter = null; + float fboost = Float.NaN; while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { if (token == XContentParser.Token.FIELD_NAME) { currentFieldName = parser.currentName(); @@ -84,17 +89,20 @@ public class CustomFiltersScoreQueryParser implements QueryParser { } else if (token.isValue()) { if ("script".equals(currentFieldName)) { script = parser.text(); + } else if ("boost".equals(currentFieldName)) { + fboost = parser.floatValue(); } } } - if (script == null) { - throw new QueryParsingException(parseContext.index(), "[custom_filters_score] missing 'script' in filters array element"); + if (script == null && fboost == -1) { + throw new QueryParsingException(parseContext.index(), "[custom_filters_score] missing 'script' or 'boost' in filters array element"); } if (filter == null) { throw new QueryParsingException(parseContext.index(), "[custom_filters_score] missing 'filter' in filters array element"); } filters.add(filter); scripts.add(script); + boosts.add(fboost); } } } else if (token.isValue()) { @@ -118,8 +126,15 @@ public class CustomFiltersScoreQueryParser implements QueryParser { } FiltersFunctionScoreQuery.FilterFunction[] filterFunctions = new FiltersFunctionScoreQuery.FilterFunction[filters.size()]; for (int i = 0; i < filterFunctions.length; i++) { - SearchScript searchScript = context.scriptService().search(context.lookup(), scriptLang, scripts.get(i), vars); - filterFunctions[i] = new FiltersFunctionScoreQuery.FilterFunction(filters.get(i), new CustomScoreQueryParser.ScriptScoreFunction(searchScript)); + ScoreFunction scoreFunction; + String script = scripts.get(i); + if (script != null) { + SearchScript searchScript = context.scriptService().search(context.lookup(), scriptLang, script, vars); + scoreFunction = new CustomScoreQueryParser.ScriptScoreFunction(searchScript); + } else { + scoreFunction = new BoostScoreFunction(boosts.get(i)); + } + filterFunctions[i] = new FiltersFunctionScoreQuery.FilterFunction(filters.get(i), scoreFunction); } FiltersFunctionScoreQuery functionScoreQuery = new FiltersFunctionScoreQuery(query, filterFunctions); functionScoreQuery.setBoost(boost); diff --git a/modules/test/integration/src/test/java/org/elasticsearch/test/integration/search/customscore/CustomScoreSearchTests.java b/modules/test/integration/src/test/java/org/elasticsearch/test/integration/search/customscore/CustomScoreSearchTests.java index 0c0b25b6e2d..7439beea4a9 100644 --- a/modules/test/integration/src/test/java/org/elasticsearch/test/integration/search/customscore/CustomScoreSearchTests.java +++ b/modules/test/integration/src/test/java/org/elasticsearch/test/integration/search/customscore/CustomScoreSearchTests.java @@ -174,5 +174,25 @@ public class CustomScoreSearchTests extends AbstractNodesTests { assertThat(searchResponse.hits().getAt(2).score(), equalTo(1.0f)); assertThat(searchResponse.hits().getAt(3).id(), anyOf(equalTo("1"), equalTo("3"))); assertThat(searchResponse.hits().getAt(3).score(), equalTo(1.0f)); + + searchResponse = client.prepareSearch("test") + .setQuery(customFiltersScoreQuery(matchAllQuery()) + .add(termFilter("field", "value4"), 2) + .add(termFilter("field", "value2"), 3)) + .setExplain(true) + .execute().actionGet(); + + assertThat(Arrays.toString(searchResponse.shardFailures()), searchResponse.failedShards(), equalTo(0)); + + assertThat(searchResponse.hits().totalHits(), equalTo(4l)); + assertThat(searchResponse.hits().getAt(0).id(), equalTo("2")); + assertThat(searchResponse.hits().getAt(0).score(), equalTo(3.0f)); + logger.info("--> Hit[0] {} Explanation {}", searchResponse.hits().getAt(0).id(), searchResponse.hits().getAt(0).explanation()); + assertThat(searchResponse.hits().getAt(1).id(), equalTo("4")); + assertThat(searchResponse.hits().getAt(1).score(), equalTo(2.0f)); + assertThat(searchResponse.hits().getAt(2).id(), anyOf(equalTo("1"), equalTo("3"))); + assertThat(searchResponse.hits().getAt(2).score(), equalTo(1.0f)); + assertThat(searchResponse.hits().getAt(3).id(), anyOf(equalTo("1"), equalTo("3"))); + assertThat(searchResponse.hits().getAt(3).score(), equalTo(1.0f)); } } \ No newline at end of file