support for custom script based sorting

This commit is contained in:
kimchy 2010-08-03 18:14:02 +03:00
parent 3d31c38f11
commit 91c256361e
6 changed files with 398 additions and 8 deletions

View File

@ -34,6 +34,7 @@ import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.facets.AbstractFacetBuilder;
import org.elasticsearch.search.highlight.HighlightBuilder;
import javax.annotation.Nullable;
import java.util.Map;
/**
@ -266,6 +267,30 @@ public class SearchRequestBuilder extends BaseRequestBuilder<SearchRequest, Sear
return this;
}
/**
* Adds a sort script.
*
* @param script The script to execute.
* @param type The type of the result (can either be "string" or "number").
* @param order The order.
*/
public SearchRequestBuilder addSortScript(String script, String type, SearchSourceBuilder.Order order) {
return addSortScript(script, type, order, null);
}
/**
* Adds a sort script.
*
* @param script The script to execute.
* @param type The type of the result (can either be "string" or "number").
* @param order The order.
* @param params Optional parameters to the script.
*/
public SearchRequestBuilder addSortScript(String script, String type, SearchSourceBuilder.Order order, @Nullable Map<String, Object> params) {
sourceBuilder().sortScript(script, type, order, params);
return this;
}
/**
* Adds the fields to load and return as part of the search request. If none are specified,
* the source of the document will be returned.

View File

@ -0,0 +1,107 @@
/*
* Licensed to Elastic Search and Shay Banon under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. Elastic Search licenses this
* file to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.index.field.function.sort;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.FieldComparator;
import org.apache.lucene.search.FieldComparatorSource;
import org.elasticsearch.index.field.function.FieldsFunction;
import java.io.IOException;
import java.util.Map;
/**
* @author kimchy (shay.banon)
*/
// LUCENE MONITOR: Monitor against FieldComparator.Double
public class DoubleFieldsFunctionDataComparator extends FieldComparator {
public static FieldComparatorSource comparatorSource(FieldsFunction fieldsFunction, Map<String, Object> params) {
return new InnerSource(fieldsFunction, params);
}
private static class InnerSource extends FieldComparatorSource {
private final FieldsFunction fieldsFunction;
private final Map<String, Object> params;
private InnerSource(FieldsFunction fieldsFunction, Map<String, Object> params) {
this.fieldsFunction = fieldsFunction;
this.params = params;
}
@Override public FieldComparator newComparator(String fieldname, int numHits, int sortPos, boolean reversed) throws IOException {
return new DoubleFieldsFunctionDataComparator(numHits, fieldsFunction, params);
}
}
private final FieldsFunction fieldsFunction;
private final Map<String, Object> params;
private final double[] values;
private double bottom;
public DoubleFieldsFunctionDataComparator(int numHits, FieldsFunction fieldsFunction, Map<String, Object> params) {
this.fieldsFunction = fieldsFunction;
this.params = params;
values = new double[numHits];
}
@Override public void setNextReader(IndexReader reader, int docBase) throws IOException {
fieldsFunction.setNextReader(reader);
}
@Override public int compare(int slot1, int slot2) {
final double v1 = values[slot1];
final double v2 = values[slot2];
if (v1 > v2) {
return 1;
} else if (v1 < v2) {
return -1;
} else {
return 0;
}
}
@Override public int compareBottom(int doc) {
final double v2 = ((Number) fieldsFunction.execute(doc, params)).doubleValue();
if (bottom > v2) {
return 1;
} else if (bottom < v2) {
return -1;
} else {
return 0;
}
}
@Override public void copy(int slot, int doc) {
values[slot] = ((Number) fieldsFunction.execute(doc, params)).doubleValue();
}
@Override public void setBottom(final int bottom) {
this.bottom = values[bottom];
}
@Override public Comparable value(int slot) {
return Double.valueOf(values[slot]);
}
}

View File

@ -0,0 +1,112 @@
/*
* Licensed to Elastic Search and Shay Banon under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. Elastic Search licenses this
* file to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.index.field.function.sort;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.FieldComparator;
import org.apache.lucene.search.FieldComparatorSource;
import org.elasticsearch.index.field.function.FieldsFunction;
import java.io.IOException;
import java.util.Map;
/**
* @author kimchy (shay.banon)
*/
public class StringFieldsFunctionDataComparator extends FieldComparator {
public static FieldComparatorSource comparatorSource(FieldsFunction fieldsFunction, Map<String, Object> params) {
return new InnerSource(fieldsFunction, params);
}
private static class InnerSource extends FieldComparatorSource {
private final FieldsFunction fieldsFunction;
private final Map<String, Object> params;
private InnerSource(FieldsFunction fieldsFunction, Map<String, Object> params) {
this.fieldsFunction = fieldsFunction;
this.params = params;
}
@Override public FieldComparator newComparator(String fieldname, int numHits, int sortPos, boolean reversed) throws IOException {
return new StringFieldsFunctionDataComparator(numHits, fieldsFunction, params);
}
}
private final FieldsFunction fieldsFunction;
private final Map<String, Object> params;
private String[] values;
private String bottom;
public StringFieldsFunctionDataComparator(int numHits, FieldsFunction fieldsFunction, Map<String, Object> params) {
this.fieldsFunction = fieldsFunction;
this.params = params;
values = new String[numHits];
}
@Override public void setNextReader(IndexReader reader, int docBase) throws IOException {
fieldsFunction.setNextReader(reader);
}
@Override public int compare(int slot1, int slot2) {
final String val1 = values[slot1];
final String val2 = values[slot2];
if (val1 == null) {
if (val2 == null) {
return 0;
}
return -1;
} else if (val2 == null) {
return 1;
}
return val1.compareTo(val2);
}
@Override public int compareBottom(int doc) {
final String val2 = fieldsFunction.execute(doc, params).toString();
if (bottom == null) {
if (val2 == null) {
return 0;
}
return -1;
} else if (val2 == null) {
return 1;
}
return bottom.compareTo(val2);
}
@Override public void copy(int slot, int doc) {
values[slot] = fieldsFunction.execute(doc, params).toString();
}
@Override public void setBottom(final int bottom) {
this.bottom = values[bottom];
}
@Override public Comparable value(int slot) {
return values[slot];
}
}

View File

@ -33,6 +33,7 @@ import org.elasticsearch.search.facets.AbstractFacetBuilder;
import org.elasticsearch.search.highlight.HighlightBuilder;
import org.elasticsearch.search.query.SortParseElement;
import javax.annotation.Nullable;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
@ -82,6 +83,8 @@ public class SearchSourceBuilder implements ToXContent {
private List<SortTuple> sortFields;
private List<ScriptSortTuple> sortScripts;
private List<String> fieldNames;
private List<ScriptField> scriptFields;
@ -179,6 +182,22 @@ public class SearchSourceBuilder implements ToXContent {
return sort(name, false);
}
/**
* Adds a sort script.
*
* @param script The script to execute.
* @param type The type of the result (can either be "string" or "number").
* @param order The order.
* @param params Optional parameters to the script.
*/
public SearchSourceBuilder sortScript(String script, String type, Order order, @Nullable Map<String, Object> params) {
if (sortScripts == null) {
sortScripts = Lists.newArrayList();
}
sortScripts.add(new ScriptSortTuple(script, type, params, order == Order.DESC));
return this;
}
/**
* Add a sort against the given field name and if it should be revered or not.
*
@ -359,16 +378,33 @@ public class SearchSourceBuilder implements ToXContent {
builder.endObject();
}
if (sortFields != null) {
if (sortFields != null || sortScripts != null) {
builder.field("sort");
builder.startObject();
for (SortTuple sortTuple : sortFields) {
builder.field(sortTuple.fieldName());
builder.startObject();
if (sortTuple.reverse()) {
builder.field("reverse", true);
if (sortFields != null) {
for (SortTuple sortTuple : sortFields) {
builder.field(sortTuple.fieldName());
builder.startObject();
if (sortTuple.reverse()) {
builder.field("reverse", true);
}
builder.endObject();
}
}
if (sortScripts != null) {
for (ScriptSortTuple scriptSort : sortScripts) {
builder.startObject("_script");
builder.field("script", scriptSort.script());
builder.field("type", scriptSort.type());
if (scriptSort.params() != null) {
builder.field("params");
builder.map(scriptSort.params());
}
if (scriptSort.reverse()) {
builder.field("reverse", true);
}
builder.endObject();
}
builder.endObject();
}
builder.endObject();
}
@ -439,4 +475,34 @@ public class SearchSourceBuilder implements ToXContent {
return reverse;
}
}
private static class ScriptSortTuple {
private final String script;
private final String type;
private final Map<String, Object> params;
private final boolean reverse;
private ScriptSortTuple(String script, String type, Map<String, Object> params, boolean reverse) {
this.script = script;
this.type = type;
this.params = params;
this.reverse = reverse;
}
public String script() {
return script;
}
public String type() {
return type;
}
public Map<String, Object> params() {
return params;
}
public boolean reverse() {
return reverse;
}
}
}

View File

@ -19,10 +19,15 @@
package org.elasticsearch.search.query;
import org.apache.lucene.search.FieldComparatorSource;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
import org.elasticsearch.common.collect.Lists;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.field.function.FieldsFunction;
import org.elasticsearch.index.field.function.script.ScriptFieldsFunction;
import org.elasticsearch.index.field.function.sort.DoubleFieldsFunctionDataComparator;
import org.elasticsearch.index.field.function.sort.StringFieldsFunctionDataComparator;
import org.elasticsearch.index.mapper.FieldMapper;
import org.elasticsearch.search.SearchParseElement;
import org.elasticsearch.search.SearchParseException;
@ -30,6 +35,7 @@ import org.elasticsearch.search.internal.SearchContext;
import java.io.IOException;
import java.util.List;
import java.util.Map;
/**
* @author kimchy (shay.banon)
@ -41,6 +47,7 @@ public class SortParseElement implements SearchParseElement {
private static final SortField SORT_DOC = new SortField(null, SortField.DOC);
private static final SortField SORT_DOC_REVERSE = new SortField(null, SortField.DOC, true);
public static final String SCRIPT_FIELD_NAME = "_script";
public static final String SCORE_FIELD_NAME = "_score";
public static final String DOC_FIELD_NAME = "_doc";
@ -73,6 +80,9 @@ public class SortParseElement implements SearchParseElement {
String fieldName = parser.currentName();
boolean reverse = false;
String innerJsonName = null;
String script = null;
String type = null;
Map<String, Object> params = null;
token = parser.nextToken();
if (token == XContentParser.Token.VALUE_STRING) {
String direction = parser.text();
@ -88,11 +98,42 @@ public class SortParseElement implements SearchParseElement {
} else if (token.isValue()) {
if ("reverse".equals(innerJsonName)) {
reverse = parser.booleanValue();
} else if ("order".equals(innerJsonName)) {
if ("asc".equals(parser.text())) {
reverse = SCORE_FIELD_NAME.equals(fieldName);
} else if ("desc".equals(parser.text())) {
reverse = !SCORE_FIELD_NAME.equals(fieldName);
}
} else if ("script".equals(innerJsonName)) {
script = parser.text();
} else if ("type".equals(innerJsonName)) {
type = parser.text();
} else if ("params".equals(innerJsonName)) {
params = parser.map();
}
}
}
}
addSortField(context, sortFields, fieldName, reverse);
if (SCRIPT_FIELD_NAME.equals(fieldName)) {
if (script == null) {
throw new SearchParseException(context, "_script sorting requires setting the script to sort by");
}
if (type == null) {
throw new SearchParseException(context, "_script sorting requires setting the type of the script");
}
FieldsFunction fieldsFunction = new ScriptFieldsFunction(script, context.scriptService(), context.mapperService(), context.fieldDataCache());
FieldComparatorSource fieldComparatorSource;
if ("string".equals(type)) {
fieldComparatorSource = StringFieldsFunctionDataComparator.comparatorSource(fieldsFunction, params);
} else if ("number".equals(type)) {
fieldComparatorSource = DoubleFieldsFunctionDataComparator.comparatorSource(fieldsFunction, params);
} else {
throw new SearchParseException(context, "custom script sort type [" + type + "] not supported");
}
sortFields.add(new SortField(fieldName, fieldComparatorSource, reverse));
} else {
addSortField(context, sortFields, fieldName, reverse);
}
}
}
}

View File

@ -90,6 +90,16 @@ public class SimpleSortTests extends AbstractNodesTests {
assertThat((String) searchResponse.hits().getAt(0).field("id").value(), equalTo("1"));
assertThat((String) searchResponse.hits().getAt(1).field("id").value(), equalTo("2"));
searchResponse = client.prepareSearch()
.setQuery(matchAllQuery())
.addScriptField("id", "doc['id'].value")
.addSortScript("doc['svalue'].value", "string", SearchSourceBuilder.Order.ASC)
.execute().actionGet();
assertThat(searchResponse.hits().getTotalHits(), equalTo(2l));
assertThat((String) searchResponse.hits().getAt(0).field("id").value(), equalTo("1"));
assertThat((String) searchResponse.hits().getAt(1).field("id").value(), equalTo("2"));
searchResponse = client.prepareSearch()
.setQuery(matchAllQuery())
.addScriptField("id", "doc['id'].value")
@ -100,6 +110,15 @@ public class SimpleSortTests extends AbstractNodesTests {
assertThat((String) searchResponse.hits().getAt(0).field("id").value(), equalTo("2"));
assertThat((String) searchResponse.hits().getAt(1).field("id").value(), equalTo("1"));
searchResponse = client.prepareSearch()
.setQuery(matchAllQuery())
.addScriptField("id", "doc['id'].value")
.addSortScript("doc['svalue'].value", "string", SearchSourceBuilder.Order.DESC)
.execute().actionGet();
assertThat(searchResponse.hits().getTotalHits(), equalTo(2l));
assertThat((String) searchResponse.hits().getAt(0).field("id").value(), equalTo("2"));
assertThat((String) searchResponse.hits().getAt(1).field("id").value(), equalTo("1"));
searchResponse = client.prepareSearch()
.setQuery(matchAllQuery())
@ -111,6 +130,16 @@ public class SimpleSortTests extends AbstractNodesTests {
assertThat((String) searchResponse.hits().getAt(0).field("id").value(), equalTo("1"));
assertThat((String) searchResponse.hits().getAt(1).field("id").value(), equalTo("2"));
searchResponse = client.prepareSearch()
.setQuery(matchAllQuery())
.addScriptField("id", "doc['id'].value")
.addSortScript("doc['ivalue'].value", "number", SearchSourceBuilder.Order.ASC)
.execute().actionGet();
assertThat(searchResponse.hits().getTotalHits(), equalTo(2l));
assertThat((String) searchResponse.hits().getAt(0).field("id").value(), equalTo("1"));
assertThat((String) searchResponse.hits().getAt(1).field("id").value(), equalTo("2"));
searchResponse = client.prepareSearch()
.setQuery(matchAllQuery())
.addScriptField("id", "doc['id'].value")
@ -121,6 +150,16 @@ public class SimpleSortTests extends AbstractNodesTests {
assertThat((String) searchResponse.hits().getAt(0).field("id").value(), equalTo("2"));
assertThat((String) searchResponse.hits().getAt(1).field("id").value(), equalTo("1"));
searchResponse = client.prepareSearch()
.setQuery(matchAllQuery())
.addScriptField("id", "doc['id'].value")
.addSortScript("doc['ivalue'].value", "number", SearchSourceBuilder.Order.DESC)
.execute().actionGet();
assertThat(searchResponse.hits().getTotalHits(), equalTo(2l));
assertThat((String) searchResponse.hits().getAt(0).field("id").value(), equalTo("2"));
assertThat((String) searchResponse.hits().getAt(1).field("id").value(), equalTo("1"));
searchResponse = client.prepareSearch()
.setQuery(matchAllQuery())
.addScriptField("id", "doc['id'].value")