Make SortBuilders pluggable (#1856)

Add the ability for plugin authors to add custom sort builders.

Signed-off-by: Matt Weber <matt@mattweber.org>
This commit is contained in:
Matt Weber 2022-01-14 09:06:13 -08:00 committed by GitHub
parent 6dcfe8cdcc
commit e7d44c20e9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 355 additions and 30 deletions

View File

@ -0,0 +1,94 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.search.sort;
import static org.hamcrest.Matchers.equalTo;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.common.xcontent.json.JsonXContent;
import org.opensearch.plugins.Plugin;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.sort.plugin.CustomSortBuilder;
import org.opensearch.search.sort.plugin.CustomSortPlugin;
import org.opensearch.test.InternalSettingsPlugin;
import org.opensearch.test.OpenSearchIntegTestCase;
import java.util.Arrays;
import java.util.Collection;
public class SortFromPluginIT extends OpenSearchIntegTestCase {
@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
return Arrays.asList(CustomSortPlugin.class, InternalSettingsPlugin.class);
}
public void testPluginSort() throws Exception {
createIndex("test");
ensureGreen();
client().prepareIndex("test", "type", "1").setSource("field", 2).get();
client().prepareIndex("test", "type", "2").setSource("field", 1).get();
client().prepareIndex("test", "type", "3").setSource("field", 0).get();
refresh();
SearchResponse searchResponse = client().prepareSearch("test").addSort(new CustomSortBuilder("field", SortOrder.ASC)).get();
assertThat(searchResponse.getHits().getAt(0).getId(), equalTo("3"));
assertThat(searchResponse.getHits().getAt(1).getId(), equalTo("2"));
assertThat(searchResponse.getHits().getAt(2).getId(), equalTo("1"));
searchResponse = client().prepareSearch("test").addSort(new CustomSortBuilder("field", SortOrder.DESC)).get();
assertThat(searchResponse.getHits().getAt(0).getId(), equalTo("1"));
assertThat(searchResponse.getHits().getAt(1).getId(), equalTo("2"));
assertThat(searchResponse.getHits().getAt(2).getId(), equalTo("3"));
}
public void testPluginSortXContent() throws Exception {
createIndex("test");
ensureGreen();
client().prepareIndex("test", "type", "1").setSource("field", 2).get();
client().prepareIndex("test", "type", "2").setSource("field", 1).get();
client().prepareIndex("test", "type", "3").setSource("field", 0).get();
refresh();
// builder -> json -> builder
SearchResponse searchResponse = client().prepareSearch("test")
.setSource(
SearchSourceBuilder.fromXContent(
createParser(
JsonXContent.jsonXContent,
new SearchSourceBuilder().sort(new CustomSortBuilder("field", SortOrder.ASC)).toString()
)
)
)
.get();
assertThat(searchResponse.getHits().getAt(0).getId(), equalTo("3"));
assertThat(searchResponse.getHits().getAt(1).getId(), equalTo("2"));
assertThat(searchResponse.getHits().getAt(2).getId(), equalTo("1"));
searchResponse = client().prepareSearch("test")
.setSource(
SearchSourceBuilder.fromXContent(
createParser(
JsonXContent.jsonXContent,
new SearchSourceBuilder().sort(new CustomSortBuilder("field", SortOrder.DESC)).toString()
)
)
)
.get();
assertThat(searchResponse.getHits().getAt(0).getId(), equalTo("1"));
assertThat(searchResponse.getHits().getAt(1).getId(), equalTo("2"));
assertThat(searchResponse.getHits().getAt(2).getId(), equalTo("3"));
}
}

View File

@ -33,6 +33,7 @@
package org.opensearch.plugins; package org.opensearch.plugins;
import org.apache.lucene.search.Query; import org.apache.lucene.search.Query;
import org.apache.lucene.search.Sort;
import org.opensearch.common.CheckedFunction; import org.opensearch.common.CheckedFunction;
import org.opensearch.common.ParseField; import org.opensearch.common.ParseField;
import org.opensearch.common.io.stream.NamedWriteable; import org.opensearch.common.io.stream.NamedWriteable;
@ -62,6 +63,8 @@ import org.opensearch.search.fetch.FetchSubPhase;
import org.opensearch.search.fetch.subphase.highlight.Highlighter; import org.opensearch.search.fetch.subphase.highlight.Highlighter;
import org.opensearch.search.rescore.Rescorer; import org.opensearch.search.rescore.Rescorer;
import org.opensearch.search.rescore.RescorerBuilder; import org.opensearch.search.rescore.RescorerBuilder;
import org.opensearch.search.sort.SortBuilder;
import org.opensearch.search.sort.SortParser;
import org.opensearch.search.suggest.Suggest; import org.opensearch.search.suggest.Suggest;
import org.opensearch.search.suggest.Suggester; import org.opensearch.search.suggest.Suggester;
import org.opensearch.search.suggest.SuggestionBuilder; import org.opensearch.search.suggest.SuggestionBuilder;
@ -138,6 +141,13 @@ public interface SearchPlugin {
return emptyList(); return emptyList();
} }
/**
* The new {@link Sort}s defined by this plugin.
*/
default List<SortSpec<?>> getSorts() {
return emptyList();
}
/** /**
* The new {@link Aggregation}s added by this plugin. * The new {@link Aggregation}s added by this plugin.
*/ */
@ -290,6 +300,38 @@ public interface SearchPlugin {
} }
} }
/**
* Specification of custom {@link Sort}.
*/
class SortSpec<T extends SortBuilder<T>> extends SearchExtensionSpec<T, SortParser<T>> {
/**
* Specification of custom {@link Sort}.
*
* @param name holds the names by which this sort might be parsed. The {@link ParseField#getPreferredName()} is special as it
* is the name by under which the reader is registered. So it is the name that the sort should use as its
* {@link NamedWriteable#getWriteableName()} too.
* @param reader the reader registered for this sort's builder. Typically a reference to a constructor that takes a
* {@link StreamInput}
* @param parser the parser the reads the sort builder from xcontent
*/
public SortSpec(ParseField name, Writeable.Reader<T> reader, SortParser<T> parser) {
super(name, reader, parser);
}
/**
* Specification of custom {@link Sort}.
*
* @param name the name by which this sort might be parsed or deserialized. Make sure that the query builder returns this name for
* {@link NamedWriteable#getWriteableName()}.
* @param reader the reader registered for this sort's builder. Typically a reference to a constructor that takes a
* {@link StreamInput}
* @param parser the parser the reads the sort builder from xcontent
*/
public SortSpec(String name, Writeable.Reader<T> reader, SortParser<T> parser) {
super(name, reader, parser);
}
}
/** /**
* Specification for an {@link Aggregation}. * Specification for an {@link Aggregation}.
*/ */

View File

@ -112,6 +112,7 @@ import org.opensearch.plugins.SearchPlugin.ScoreFunctionSpec;
import org.opensearch.plugins.SearchPlugin.SearchExtSpec; import org.opensearch.plugins.SearchPlugin.SearchExtSpec;
import org.opensearch.plugins.SearchPlugin.SearchExtensionSpec; import org.opensearch.plugins.SearchPlugin.SearchExtensionSpec;
import org.opensearch.plugins.SearchPlugin.SignificanceHeuristicSpec; import org.opensearch.plugins.SearchPlugin.SignificanceHeuristicSpec;
import org.opensearch.plugins.SearchPlugin.SortSpec;
import org.opensearch.plugins.SearchPlugin.SuggesterSpec; import org.opensearch.plugins.SearchPlugin.SuggesterSpec;
import org.opensearch.search.aggregations.AggregationBuilder; import org.opensearch.search.aggregations.AggregationBuilder;
import org.opensearch.search.aggregations.BaseAggregationBuilder; import org.opensearch.search.aggregations.BaseAggregationBuilder;
@ -345,7 +346,7 @@ public class SearchModule {
registerScoreFunctions(plugins); registerScoreFunctions(plugins);
registerQueryParsers(plugins); registerQueryParsers(plugins);
registerRescorers(plugins); registerRescorers(plugins);
registerSorts(); registerSortParsers(plugins);
registerValueFormats(); registerValueFormats();
registerSignificanceHeuristics(plugins); registerSignificanceHeuristics(plugins);
this.valuesSourceRegistry = registerAggregations(plugins); this.valuesSourceRegistry = registerAggregations(plugins);
@ -882,13 +883,6 @@ public class SearchModule {
namedWriteables.add(new NamedWriteableRegistry.Entry(RescorerBuilder.class, spec.getName().getPreferredName(), spec.getReader())); namedWriteables.add(new NamedWriteableRegistry.Entry(RescorerBuilder.class, spec.getName().getPreferredName(), spec.getReader()));
} }
private void registerSorts() {
namedWriteables.add(new NamedWriteableRegistry.Entry(SortBuilder.class, GeoDistanceSortBuilder.NAME, GeoDistanceSortBuilder::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(SortBuilder.class, ScoreSortBuilder.NAME, ScoreSortBuilder::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(SortBuilder.class, ScriptSortBuilder.NAME, ScriptSortBuilder::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(SortBuilder.class, FieldSortBuilder.NAME, FieldSortBuilder::new));
}
private <T> void registerFromPlugin(List<SearchPlugin> plugins, Function<SearchPlugin, List<T>> producer, Consumer<T> consumer) { private <T> void registerFromPlugin(List<SearchPlugin> plugins, Function<SearchPlugin, List<T>> producer, Consumer<T> consumer) {
for (SearchPlugin plugin : plugins) { for (SearchPlugin plugin : plugins) {
for (T t : producer.apply(plugin)) { for (T t : producer.apply(plugin)) {
@ -1214,6 +1208,20 @@ public class SearchModule {
registerFromPlugin(plugins, SearchPlugin::getQueries, this::registerQuery); registerFromPlugin(plugins, SearchPlugin::getQueries, this::registerQuery);
} }
private void registerSortParsers(List<SearchPlugin> plugins) {
registerSort(new SortSpec<>(FieldSortBuilder.NAME, FieldSortBuilder::new, FieldSortBuilder::fromXContentObject));
registerSort(new SortSpec<>(ScriptSortBuilder.NAME, ScriptSortBuilder::new, ScriptSortBuilder::fromXContent));
registerSort(
new SortSpec<>(
new ParseField(GeoDistanceSortBuilder.NAME, GeoDistanceSortBuilder.ALTERNATIVE_NAME),
GeoDistanceSortBuilder::new,
GeoDistanceSortBuilder::fromXContent
)
);
registerSort(new SortSpec<>(ScoreSortBuilder.NAME, ScoreSortBuilder::new, ScoreSortBuilder::fromXContent));
registerFromPlugin(plugins, SearchPlugin::getSorts, this::registerSort);
}
private void registerIntervalsSourceProviders() { private void registerIntervalsSourceProviders() {
namedWriteables.addAll(getIntervalsSourceProviderNamedWritables()); namedWriteables.addAll(getIntervalsSourceProviderNamedWritables());
} }
@ -1260,6 +1268,17 @@ public class SearchModule {
namedXContents.add(new NamedXContentRegistry.Entry(QueryBuilder.class, spec.getName(), (p, c) -> spec.getParser().fromXContent(p))); namedXContents.add(new NamedXContentRegistry.Entry(QueryBuilder.class, spec.getName(), (p, c) -> spec.getParser().fromXContent(p)));
} }
private void registerSort(SortSpec<?> spec) {
namedWriteables.add(new NamedWriteableRegistry.Entry(SortBuilder.class, spec.getName().getPreferredName(), spec.getReader()));
namedXContents.add(
new NamedXContentRegistry.Entry(
SortBuilder.class,
spec.getName(),
(p, c) -> spec.getParser().fromXContent(p, spec.getName().getPreferredName())
)
);
}
public FetchPhase getFetchPhase() { public FetchPhase getFetchPhase() {
return new FetchPhase(fetchSubPhases); return new FetchPhase(fetchSubPhases);
} }

View File

@ -41,6 +41,7 @@ import org.apache.lucene.search.SortField;
import org.opensearch.LegacyESVersion; import org.opensearch.LegacyESVersion;
import org.opensearch.OpenSearchParseException; import org.opensearch.OpenSearchParseException;
import org.opensearch.common.ParseField; import org.opensearch.common.ParseField;
import org.opensearch.common.ParsingException;
import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.common.logging.DeprecationLogger; import org.opensearch.common.logging.DeprecationLogger;
@ -752,6 +753,27 @@ public class FieldSortBuilder extends SortBuilder<FieldSortBuilder> {
return PARSER.parse(parser, new FieldSortBuilder(fieldName), null); return PARSER.parse(parser, new FieldSortBuilder(fieldName), null);
} }
public static FieldSortBuilder fromXContentObject(XContentParser parser, String fieldName) throws IOException {
FieldSortBuilder builder = null;
String currentFieldName = null;
XContentParser.Token token;
while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
if (token == XContentParser.Token.FIELD_NAME) {
currentFieldName = parser.currentName();
} else if (token == XContentParser.Token.START_OBJECT) {
builder = fromXContent(parser, currentFieldName);
} else {
throw new ParsingException(parser.getTokenLocation(), "[" + NAME + "] does not support [" + currentFieldName + "]");
}
}
if (builder == null) {
throw new ParsingException(parser.getTokenLocation(), "Invalid " + NAME);
}
return builder;
}
private static final ObjectParser<FieldSortBuilder, Void> PARSER = new ObjectParser<>(NAME); private static final ObjectParser<FieldSortBuilder, Void> PARSER = new ObjectParser<>(NAME);
static { static {

View File

@ -41,6 +41,7 @@ import org.opensearch.common.ParsingException;
import org.opensearch.common.Strings; import org.opensearch.common.Strings;
import org.opensearch.common.io.stream.NamedWriteable; import org.opensearch.common.io.stream.NamedWriteable;
import org.opensearch.common.lucene.search.Queries; import org.opensearch.common.lucene.search.Queries;
import org.opensearch.common.xcontent.NamedObjectNotFoundException;
import org.opensearch.common.xcontent.ToXContentObject; import org.opensearch.common.xcontent.ToXContentObject;
import org.opensearch.common.xcontent.XContentParser; import org.opensearch.common.xcontent.XContentParser;
import org.opensearch.index.fielddata.IndexFieldData.XFieldComparatorSource.Nested; import org.opensearch.index.fielddata.IndexFieldData.XFieldComparatorSource.Nested;
@ -53,13 +54,10 @@ import org.opensearch.search.DocValueFormat;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.Optional; import java.util.Optional;
import static java.util.Collections.unmodifiableMap;
import static org.opensearch.index.query.AbstractQueryBuilder.parseInnerQueryBuilder; import static org.opensearch.index.query.AbstractQueryBuilder.parseInnerQueryBuilder;
public abstract class SortBuilder<T extends SortBuilder<T>> implements NamedWriteable, ToXContentObject, Rewriteable<SortBuilder<?>> { public abstract class SortBuilder<T extends SortBuilder<T>> implements NamedWriteable, ToXContentObject, Rewriteable<SortBuilder<?>> {
@ -71,17 +69,6 @@ public abstract class SortBuilder<T extends SortBuilder<T>> implements NamedWrit
public static final ParseField NESTED_FILTER_FIELD = new ParseField("nested_filter"); public static final ParseField NESTED_FILTER_FIELD = new ParseField("nested_filter");
public static final ParseField NESTED_PATH_FIELD = new ParseField("nested_path"); public static final ParseField NESTED_PATH_FIELD = new ParseField("nested_path");
private static final Map<String, Parser<?>> PARSERS;
static {
Map<String, Parser<?>> parsers = new HashMap<>();
parsers.put(ScriptSortBuilder.NAME, ScriptSortBuilder::fromXContent);
parsers.put(GeoDistanceSortBuilder.NAME, GeoDistanceSortBuilder::fromXContent);
parsers.put(GeoDistanceSortBuilder.ALTERNATIVE_NAME, GeoDistanceSortBuilder::fromXContent);
parsers.put(ScoreSortBuilder.NAME, ScoreSortBuilder::fromXContent);
// FieldSortBuilder gets involved if the user specifies a name that isn't one of these.
PARSERS = unmodifiableMap(parsers);
}
/** /**
* Create a {@linkplain SortFieldAndFormat} from this builder. * Create a {@linkplain SortFieldAndFormat} from this builder.
*/ */
@ -155,9 +142,10 @@ public abstract class SortBuilder<T extends SortBuilder<T>> implements NamedWrit
SortOrder order = SortOrder.fromString(parser.text()); SortOrder order = SortOrder.fromString(parser.text());
sortFields.add(fieldOrScoreSort(fieldName).order(order)); sortFields.add(fieldOrScoreSort(fieldName).order(order));
} else { } else {
if (PARSERS.containsKey(fieldName)) { try {
sortFields.add(PARSERS.get(fieldName).fromXContent(parser, fieldName)); SortBuilder<?> sort = parser.namedObject(SortBuilder.class, fieldName, null);
} else { sortFields.add(sort);
} catch (NamedObjectNotFoundException err) {
sortFields.add(FieldSortBuilder.fromXContent(parser, fieldName)); sortFields.add(FieldSortBuilder.fromXContent(parser, fieldName));
} }
} }
@ -290,11 +278,6 @@ public abstract class SortBuilder<T extends SortBuilder<T>> implements NamedWrit
} }
} }
@FunctionalInterface
private interface Parser<T extends SortBuilder<?>> {
T fromXContent(XContentParser parser, String elementName) throws IOException;
}
@Override @Override
public String toString() { public String toString() {
return Strings.toString(this, true, true); return Strings.toString(this, true, true);

View File

@ -0,0 +1,23 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.search.sort;
import org.opensearch.common.xcontent.XContentParser;
import java.io.IOException;
@FunctionalInterface
public interface SortParser<SB extends SortBuilder<SB>> {
/**
* Creates a new {@link SortBuilder} from the sort held by the
* {@link XContentParser}. The state on the parser contained in this context
* will be changed as a side effect of this method call
*/
SB fromXContent(XContentParser parser, String elementName) throws IOException;
}

View File

@ -127,6 +127,7 @@ public class SortBuilderTests extends OpenSearchTestCase {
result = parseSort(json); result = parseSort(json);
assertEquals(1, result.size()); assertEquals(1, result.size());
sortBuilder = result.get(0); sortBuilder = result.get(0);
assertWarnings("Deprecated field [_geoDistance] used, expected [_geo_distance] instead");
assertEquals(new GeoDistanceSortBuilder("pin.location", 40, -70), sortBuilder); assertEquals(new GeoDistanceSortBuilder("pin.location", 40, -70), sortBuilder);
json = "{ \"sort\" : [" + "{\"_geo_distance\" : {" + "\"pin.location\" : \"40,-70\" } }" + "] }"; json = "{ \"sort\" : [" + "{\"_geo_distance\" : {" + "\"pin.location\" : \"40,-70\" } }" + "] }";

View File

@ -0,0 +1,119 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.search.sort.plugin;
import static org.opensearch.common.xcontent.ConstructingObjectParser.constructorArg;
import org.opensearch.common.ParseField;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.common.xcontent.ConstructingObjectParser;
import org.opensearch.common.xcontent.ObjectParser;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentParser;
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.search.sort.BucketedSort;
import org.opensearch.search.sort.SortBuilder;
import org.opensearch.search.sort.SortBuilders;
import org.opensearch.search.sort.SortFieldAndFormat;
import org.opensearch.search.sort.SortOrder;
import java.io.IOException;
import java.util.Objects;
/**
* Custom sort builder that just rewrites to a basic field sort
*/
public class CustomSortBuilder extends SortBuilder<CustomSortBuilder> {
public static String NAME = "_custom";
public static ParseField SORT_FIELD = new ParseField("sort_field");
public final String field;
public final SortOrder order;
public CustomSortBuilder(String field, SortOrder order) {
this.field = field;
this.order = order;
}
public CustomSortBuilder(StreamInput in) throws IOException {
this.field = in.readString();
this.order = in.readOptionalWriteable(SortOrder::readFromStream);
}
@Override
public void writeTo(final StreamOutput out) throws IOException {
out.writeString(field);
out.writeOptionalWriteable(order);
}
@Override
public String getWriteableName() {
return NAME;
}
@Override
public SortBuilder<?> rewrite(final QueryRewriteContext ctx) throws IOException {
return SortBuilders.fieldSort(field).order(order);
}
@Override
protected SortFieldAndFormat build(final QueryShardContext context) throws IOException {
throw new IllegalStateException("rewrite");
}
@Override
public BucketedSort buildBucketedSort(final QueryShardContext context, final int bucketSize, final BucketedSort.ExtraData extra)
throws IOException {
throw new IllegalStateException("rewrite");
}
@Override
public boolean equals(Object object) {
if (this == object) {
return true;
}
if (object == null || getClass() != object.getClass()) {
return false;
}
CustomSortBuilder other = (CustomSortBuilder) object;
return Objects.equals(field, other.field) && Objects.equals(order, other.order);
}
@Override
public int hashCode() {
return Objects.hash(field, order);
}
@Override
public XContentBuilder toXContent(final XContentBuilder builder, final Params params) throws IOException {
builder.startObject();
builder.startObject(NAME);
builder.field(SORT_FIELD.getPreferredName(), field);
builder.field(ORDER_FIELD.getPreferredName(), order);
builder.endObject();
builder.endObject();
return builder;
}
public static CustomSortBuilder fromXContent(XContentParser parser, String elementName) {
return PARSER.apply(parser, null);
}
private static final ConstructingObjectParser<CustomSortBuilder, Void> PARSER = new ConstructingObjectParser<>(
NAME,
a -> new CustomSortBuilder((String) a[0], (SortOrder) a[1])
);
static {
PARSER.declareField(constructorArg(), XContentParser::text, SORT_FIELD, ObjectParser.ValueType.STRING);
PARSER.declareField(constructorArg(), p -> SortOrder.fromString(p.text()), ORDER_FIELD, ObjectParser.ValueType.STRING);
}
}

View File

@ -0,0 +1,22 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.search.sort.plugin;
import org.opensearch.plugins.Plugin;
import org.opensearch.plugins.SearchPlugin;
import java.util.Collections;
import java.util.List;
public class CustomSortPlugin extends Plugin implements SearchPlugin {
@Override
public List<SortSpec<?>> getSorts() {
return Collections.singletonList(new SortSpec<>(CustomSortBuilder.NAME, CustomSortBuilder::new, CustomSortBuilder::fromXContent));
}
}