From 91c256361ec587a950d499f24306a37a07c513df Mon Sep 17 00:00:00 2001 From: kimchy Date: Tue, 3 Aug 2010 18:14:02 +0300 Subject: [PATCH] support for custom script based sorting --- .../action/search/SearchRequestBuilder.java | 25 ++++ .../DoubleFieldsFunctionDataComparator.java | 107 +++++++++++++++++ .../StringFieldsFunctionDataComparator.java | 112 ++++++++++++++++++ .../search/builder/SearchSourceBuilder.java | 80 +++++++++++-- .../search/query/SortParseElement.java | 43 ++++++- .../search/sort/SimpleSortTests.java | 39 ++++++ 6 files changed, 398 insertions(+), 8 deletions(-) create mode 100644 modules/elasticsearch/src/main/java/org/elasticsearch/index/field/function/sort/DoubleFieldsFunctionDataComparator.java create mode 100644 modules/elasticsearch/src/main/java/org/elasticsearch/index/field/function/sort/StringFieldsFunctionDataComparator.java diff --git a/modules/elasticsearch/src/main/java/org/elasticsearch/client/action/search/SearchRequestBuilder.java b/modules/elasticsearch/src/main/java/org/elasticsearch/client/action/search/SearchRequestBuilder.java index 1d4747714db..29620d37b41 100644 --- a/modules/elasticsearch/src/main/java/org/elasticsearch/client/action/search/SearchRequestBuilder.java +++ b/modules/elasticsearch/src/main/java/org/elasticsearch/client/action/search/SearchRequestBuilder.java @@ -34,6 +34,7 @@ import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.facets.AbstractFacetBuilder; import org.elasticsearch.search.highlight.HighlightBuilder; +import javax.annotation.Nullable; import java.util.Map; /** @@ -266,6 +267,30 @@ public class SearchRequestBuilder extends BaseRequestBuilder params) { + sourceBuilder().sortScript(script, type, order, params); + return this; + } + /** * Adds the fields to load and return as part of the search request. If none are specified, * the source of the document will be returned. diff --git a/modules/elasticsearch/src/main/java/org/elasticsearch/index/field/function/sort/DoubleFieldsFunctionDataComparator.java b/modules/elasticsearch/src/main/java/org/elasticsearch/index/field/function/sort/DoubleFieldsFunctionDataComparator.java new file mode 100644 index 00000000000..9384f50e43d --- /dev/null +++ b/modules/elasticsearch/src/main/java/org/elasticsearch/index/field/function/sort/DoubleFieldsFunctionDataComparator.java @@ -0,0 +1,107 @@ +/* + * Licensed to Elastic Search and Shay Banon under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Elastic Search licenses this + * file to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.index.field.function.sort; + +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.search.FieldComparator; +import org.apache.lucene.search.FieldComparatorSource; +import org.elasticsearch.index.field.function.FieldsFunction; + +import java.io.IOException; +import java.util.Map; + +/** + * @author kimchy (shay.banon) + */ +// LUCENE MONITOR: Monitor against FieldComparator.Double +public class DoubleFieldsFunctionDataComparator extends FieldComparator { + + public static FieldComparatorSource comparatorSource(FieldsFunction fieldsFunction, Map params) { + return new InnerSource(fieldsFunction, params); + } + + private static class InnerSource extends FieldComparatorSource { + + private final FieldsFunction fieldsFunction; + + private final Map params; + + private InnerSource(FieldsFunction fieldsFunction, Map params) { + this.fieldsFunction = fieldsFunction; + this.params = params; + } + + @Override public FieldComparator newComparator(String fieldname, int numHits, int sortPos, boolean reversed) throws IOException { + return new DoubleFieldsFunctionDataComparator(numHits, fieldsFunction, params); + } + } + + private final FieldsFunction fieldsFunction; + + private final Map params; + + private final double[] values; + private double bottom; + + public DoubleFieldsFunctionDataComparator(int numHits, FieldsFunction fieldsFunction, Map params) { + this.fieldsFunction = fieldsFunction; + this.params = params; + values = new double[numHits]; + } + + @Override public void setNextReader(IndexReader reader, int docBase) throws IOException { + fieldsFunction.setNextReader(reader); + } + + @Override public int compare(int slot1, int slot2) { + final double v1 = values[slot1]; + final double v2 = values[slot2]; + if (v1 > v2) { + return 1; + } else if (v1 < v2) { + return -1; + } else { + return 0; + } + } + + @Override public int compareBottom(int doc) { + final double v2 = ((Number) fieldsFunction.execute(doc, params)).doubleValue(); + if (bottom > v2) { + return 1; + } else if (bottom < v2) { + return -1; + } else { + return 0; + } + } + + @Override public void copy(int slot, int doc) { + values[slot] = ((Number) fieldsFunction.execute(doc, params)).doubleValue(); + } + + @Override public void setBottom(final int bottom) { + this.bottom = values[bottom]; + } + + @Override public Comparable value(int slot) { + return Double.valueOf(values[slot]); + } +} diff --git a/modules/elasticsearch/src/main/java/org/elasticsearch/index/field/function/sort/StringFieldsFunctionDataComparator.java b/modules/elasticsearch/src/main/java/org/elasticsearch/index/field/function/sort/StringFieldsFunctionDataComparator.java new file mode 100644 index 00000000000..fd7648382ee --- /dev/null +++ b/modules/elasticsearch/src/main/java/org/elasticsearch/index/field/function/sort/StringFieldsFunctionDataComparator.java @@ -0,0 +1,112 @@ +/* + * Licensed to Elastic Search and Shay Banon under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Elastic Search licenses this + * file to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.index.field.function.sort; + +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.search.FieldComparator; +import org.apache.lucene.search.FieldComparatorSource; +import org.elasticsearch.index.field.function.FieldsFunction; + +import java.io.IOException; +import java.util.Map; + +/** + * @author kimchy (shay.banon) + */ +public class StringFieldsFunctionDataComparator extends FieldComparator { + + public static FieldComparatorSource comparatorSource(FieldsFunction fieldsFunction, Map params) { + return new InnerSource(fieldsFunction, params); + } + + private static class InnerSource extends FieldComparatorSource { + + private final FieldsFunction fieldsFunction; + + private final Map params; + + private InnerSource(FieldsFunction fieldsFunction, Map params) { + this.fieldsFunction = fieldsFunction; + this.params = params; + } + + @Override public FieldComparator newComparator(String fieldname, int numHits, int sortPos, boolean reversed) throws IOException { + return new StringFieldsFunctionDataComparator(numHits, fieldsFunction, params); + } + } + + private final FieldsFunction fieldsFunction; + + private final Map params; + + private String[] values; + + private String bottom; + + public StringFieldsFunctionDataComparator(int numHits, FieldsFunction fieldsFunction, Map params) { + this.fieldsFunction = fieldsFunction; + this.params = params; + values = new String[numHits]; + } + + @Override public void setNextReader(IndexReader reader, int docBase) throws IOException { + fieldsFunction.setNextReader(reader); + } + + @Override public int compare(int slot1, int slot2) { + final String val1 = values[slot1]; + final String val2 = values[slot2]; + if (val1 == null) { + if (val2 == null) { + return 0; + } + return -1; + } else if (val2 == null) { + return 1; + } + + return val1.compareTo(val2); + } + + @Override public int compareBottom(int doc) { + final String val2 = fieldsFunction.execute(doc, params).toString(); + if (bottom == null) { + if (val2 == null) { + return 0; + } + return -1; + } else if (val2 == null) { + return 1; + } + return bottom.compareTo(val2); + } + + @Override public void copy(int slot, int doc) { + values[slot] = fieldsFunction.execute(doc, params).toString(); + } + + @Override public void setBottom(final int bottom) { + this.bottom = values[bottom]; + } + + @Override public Comparable value(int slot) { + return values[slot]; + } +} diff --git a/modules/elasticsearch/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java b/modules/elasticsearch/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java index d1e31549b10..dd791d0f1be 100644 --- a/modules/elasticsearch/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java +++ b/modules/elasticsearch/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java @@ -33,6 +33,7 @@ import org.elasticsearch.search.facets.AbstractFacetBuilder; import org.elasticsearch.search.highlight.HighlightBuilder; import org.elasticsearch.search.query.SortParseElement; +import javax.annotation.Nullable; import java.io.IOException; import java.util.ArrayList; import java.util.List; @@ -82,6 +83,8 @@ public class SearchSourceBuilder implements ToXContent { private List sortFields; + private List sortScripts; + private List fieldNames; private List scriptFields; @@ -179,6 +182,22 @@ public class SearchSourceBuilder implements ToXContent { return sort(name, false); } + /** + * Adds a sort script. + * + * @param script The script to execute. + * @param type The type of the result (can either be "string" or "number"). + * @param order The order. + * @param params Optional parameters to the script. + */ + public SearchSourceBuilder sortScript(String script, String type, Order order, @Nullable Map params) { + if (sortScripts == null) { + sortScripts = Lists.newArrayList(); + } + sortScripts.add(new ScriptSortTuple(script, type, params, order == Order.DESC)); + return this; + } + /** * Add a sort against the given field name and if it should be revered or not. * @@ -359,16 +378,33 @@ public class SearchSourceBuilder implements ToXContent { builder.endObject(); } - if (sortFields != null) { + if (sortFields != null || sortScripts != null) { builder.field("sort"); builder.startObject(); - for (SortTuple sortTuple : sortFields) { - builder.field(sortTuple.fieldName()); - builder.startObject(); - if (sortTuple.reverse()) { - builder.field("reverse", true); + if (sortFields != null) { + for (SortTuple sortTuple : sortFields) { + builder.field(sortTuple.fieldName()); + builder.startObject(); + if (sortTuple.reverse()) { + builder.field("reverse", true); + } + builder.endObject(); + } + } + if (sortScripts != null) { + for (ScriptSortTuple scriptSort : sortScripts) { + builder.startObject("_script"); + builder.field("script", scriptSort.script()); + builder.field("type", scriptSort.type()); + if (scriptSort.params() != null) { + builder.field("params"); + builder.map(scriptSort.params()); + } + if (scriptSort.reverse()) { + builder.field("reverse", true); + } + builder.endObject(); } - builder.endObject(); } builder.endObject(); } @@ -439,4 +475,34 @@ public class SearchSourceBuilder implements ToXContent { return reverse; } } + + private static class ScriptSortTuple { + private final String script; + private final String type; + private final Map params; + private final boolean reverse; + + private ScriptSortTuple(String script, String type, Map params, boolean reverse) { + this.script = script; + this.type = type; + this.params = params; + this.reverse = reverse; + } + + public String script() { + return script; + } + + public String type() { + return type; + } + + public Map params() { + return params; + } + + public boolean reverse() { + return reverse; + } + } } diff --git a/modules/elasticsearch/src/main/java/org/elasticsearch/search/query/SortParseElement.java b/modules/elasticsearch/src/main/java/org/elasticsearch/search/query/SortParseElement.java index ea0ceec2bc1..00d82b3b012 100644 --- a/modules/elasticsearch/src/main/java/org/elasticsearch/search/query/SortParseElement.java +++ b/modules/elasticsearch/src/main/java/org/elasticsearch/search/query/SortParseElement.java @@ -19,10 +19,15 @@ package org.elasticsearch.search.query; +import org.apache.lucene.search.FieldComparatorSource; import org.apache.lucene.search.Sort; import org.apache.lucene.search.SortField; import org.elasticsearch.common.collect.Lists; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.field.function.FieldsFunction; +import org.elasticsearch.index.field.function.script.ScriptFieldsFunction; +import org.elasticsearch.index.field.function.sort.DoubleFieldsFunctionDataComparator; +import org.elasticsearch.index.field.function.sort.StringFieldsFunctionDataComparator; import org.elasticsearch.index.mapper.FieldMapper; import org.elasticsearch.search.SearchParseElement; import org.elasticsearch.search.SearchParseException; @@ -30,6 +35,7 @@ import org.elasticsearch.search.internal.SearchContext; import java.io.IOException; import java.util.List; +import java.util.Map; /** * @author kimchy (shay.banon) @@ -41,6 +47,7 @@ public class SortParseElement implements SearchParseElement { private static final SortField SORT_DOC = new SortField(null, SortField.DOC); private static final SortField SORT_DOC_REVERSE = new SortField(null, SortField.DOC, true); + public static final String SCRIPT_FIELD_NAME = "_script"; public static final String SCORE_FIELD_NAME = "_score"; public static final String DOC_FIELD_NAME = "_doc"; @@ -73,6 +80,9 @@ public class SortParseElement implements SearchParseElement { String fieldName = parser.currentName(); boolean reverse = false; String innerJsonName = null; + String script = null; + String type = null; + Map params = null; token = parser.nextToken(); if (token == XContentParser.Token.VALUE_STRING) { String direction = parser.text(); @@ -88,11 +98,42 @@ public class SortParseElement implements SearchParseElement { } else if (token.isValue()) { if ("reverse".equals(innerJsonName)) { reverse = parser.booleanValue(); + } else if ("order".equals(innerJsonName)) { + if ("asc".equals(parser.text())) { + reverse = SCORE_FIELD_NAME.equals(fieldName); + } else if ("desc".equals(parser.text())) { + reverse = !SCORE_FIELD_NAME.equals(fieldName); + } + } else if ("script".equals(innerJsonName)) { + script = parser.text(); + } else if ("type".equals(innerJsonName)) { + type = parser.text(); + } else if ("params".equals(innerJsonName)) { + params = parser.map(); } } } } - addSortField(context, sortFields, fieldName, reverse); + if (SCRIPT_FIELD_NAME.equals(fieldName)) { + if (script == null) { + throw new SearchParseException(context, "_script sorting requires setting the script to sort by"); + } + if (type == null) { + throw new SearchParseException(context, "_script sorting requires setting the type of the script"); + } + FieldsFunction fieldsFunction = new ScriptFieldsFunction(script, context.scriptService(), context.mapperService(), context.fieldDataCache()); + FieldComparatorSource fieldComparatorSource; + if ("string".equals(type)) { + fieldComparatorSource = StringFieldsFunctionDataComparator.comparatorSource(fieldsFunction, params); + } else if ("number".equals(type)) { + fieldComparatorSource = DoubleFieldsFunctionDataComparator.comparatorSource(fieldsFunction, params); + } else { + throw new SearchParseException(context, "custom script sort type [" + type + "] not supported"); + } + sortFields.add(new SortField(fieldName, fieldComparatorSource, reverse)); + } else { + addSortField(context, sortFields, fieldName, reverse); + } } } } diff --git a/modules/test/integration/src/test/java/org/elasticsearch/test/integration/search/sort/SimpleSortTests.java b/modules/test/integration/src/test/java/org/elasticsearch/test/integration/search/sort/SimpleSortTests.java index 213596cd660..e59f9326861 100644 --- a/modules/test/integration/src/test/java/org/elasticsearch/test/integration/search/sort/SimpleSortTests.java +++ b/modules/test/integration/src/test/java/org/elasticsearch/test/integration/search/sort/SimpleSortTests.java @@ -90,6 +90,16 @@ public class SimpleSortTests extends AbstractNodesTests { assertThat((String) searchResponse.hits().getAt(0).field("id").value(), equalTo("1")); assertThat((String) searchResponse.hits().getAt(1).field("id").value(), equalTo("2")); + searchResponse = client.prepareSearch() + .setQuery(matchAllQuery()) + .addScriptField("id", "doc['id'].value") + .addSortScript("doc['svalue'].value", "string", SearchSourceBuilder.Order.ASC) + .execute().actionGet(); + + assertThat(searchResponse.hits().getTotalHits(), equalTo(2l)); + assertThat((String) searchResponse.hits().getAt(0).field("id").value(), equalTo("1")); + assertThat((String) searchResponse.hits().getAt(1).field("id").value(), equalTo("2")); + searchResponse = client.prepareSearch() .setQuery(matchAllQuery()) .addScriptField("id", "doc['id'].value") @@ -100,6 +110,15 @@ public class SimpleSortTests extends AbstractNodesTests { assertThat((String) searchResponse.hits().getAt(0).field("id").value(), equalTo("2")); assertThat((String) searchResponse.hits().getAt(1).field("id").value(), equalTo("1")); + searchResponse = client.prepareSearch() + .setQuery(matchAllQuery()) + .addScriptField("id", "doc['id'].value") + .addSortScript("doc['svalue'].value", "string", SearchSourceBuilder.Order.DESC) + .execute().actionGet(); + + assertThat(searchResponse.hits().getTotalHits(), equalTo(2l)); + assertThat((String) searchResponse.hits().getAt(0).field("id").value(), equalTo("2")); + assertThat((String) searchResponse.hits().getAt(1).field("id").value(), equalTo("1")); searchResponse = client.prepareSearch() .setQuery(matchAllQuery()) @@ -111,6 +130,16 @@ public class SimpleSortTests extends AbstractNodesTests { assertThat((String) searchResponse.hits().getAt(0).field("id").value(), equalTo("1")); assertThat((String) searchResponse.hits().getAt(1).field("id").value(), equalTo("2")); + searchResponse = client.prepareSearch() + .setQuery(matchAllQuery()) + .addScriptField("id", "doc['id'].value") + .addSortScript("doc['ivalue'].value", "number", SearchSourceBuilder.Order.ASC) + .execute().actionGet(); + + assertThat(searchResponse.hits().getTotalHits(), equalTo(2l)); + assertThat((String) searchResponse.hits().getAt(0).field("id").value(), equalTo("1")); + assertThat((String) searchResponse.hits().getAt(1).field("id").value(), equalTo("2")); + searchResponse = client.prepareSearch() .setQuery(matchAllQuery()) .addScriptField("id", "doc['id'].value") @@ -121,6 +150,16 @@ public class SimpleSortTests extends AbstractNodesTests { assertThat((String) searchResponse.hits().getAt(0).field("id").value(), equalTo("2")); assertThat((String) searchResponse.hits().getAt(1).field("id").value(), equalTo("1")); + searchResponse = client.prepareSearch() + .setQuery(matchAllQuery()) + .addScriptField("id", "doc['id'].value") + .addSortScript("doc['ivalue'].value", "number", SearchSourceBuilder.Order.DESC) + .execute().actionGet(); + + assertThat(searchResponse.hits().getTotalHits(), equalTo(2l)); + assertThat((String) searchResponse.hits().getAt(0).field("id").value(), equalTo("2")); + assertThat((String) searchResponse.hits().getAt(1).field("id").value(), equalTo("1")); + searchResponse = client.prepareSearch() .setQuery(matchAllQuery()) .addScriptField("id", "doc['id'].value")