Add regexp interval source (#1917)

* Add regexp interval source

Add a regexp interval source provider so people can use regular
expressions inside of intervals queries.

Signed-off-by: Matt Weber <matt@mattweber.org>

* Fixes

- register regexp interval in SearchModule
- use fully-qualified name for lucene RegExp
- get rid of unnecessary variable

Signed-off-by: Matt Weber <matt@mattweber.org>
This commit is contained in:
Matt Weber 2022-02-07 15:18:43 -08:00 committed by GitHub
parent 9c9e218ae6
commit b9420d8f70
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 347 additions and 1 deletions

View File

@ -39,6 +39,7 @@ import org.apache.lucene.queries.intervals.Intervals;
import org.apache.lucene.queries.intervals.IntervalsSource; import org.apache.lucene.queries.intervals.IntervalsSource;
import org.apache.lucene.search.FuzzyQuery; import org.apache.lucene.search.FuzzyQuery;
import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.automaton.CompiledAutomaton;
import org.opensearch.LegacyESVersion; import org.opensearch.LegacyESVersion;
import org.opensearch.Version; import org.opensearch.Version;
import org.opensearch.common.ParseField; import org.opensearch.common.ParseField;
@ -101,12 +102,14 @@ public abstract class IntervalsSourceProvider implements NamedWriteable, ToXCont
return Prefix.fromXContent(parser); return Prefix.fromXContent(parser);
case "wildcard": case "wildcard":
return Wildcard.fromXContent(parser); return Wildcard.fromXContent(parser);
case "regexp":
return Regexp.fromXContent(parser);
case "fuzzy": case "fuzzy":
return Fuzzy.fromXContent(parser); return Fuzzy.fromXContent(parser);
} }
throw new ParsingException( throw new ParsingException(
parser.getTokenLocation(), parser.getTokenLocation(),
"Unknown interval type [" + parser.currentName() + "], expecting one of [match, any_of, all_of, prefix, wildcard]" "Unknown interval type [" + parser.currentName() + "], expecting one of [match, any_of, all_of, prefix, wildcard, regexp]"
); );
} }
@ -631,6 +634,155 @@ public abstract class IntervalsSourceProvider implements NamedWriteable, ToXCont
} }
} }
public static class Regexp extends IntervalsSourceProvider {
public static final String NAME = "regexp";
public static final int DEFAULT_FLAGS_VALUE = RegexpFlag.ALL.value();
private final String pattern;
private final int flags;
private final String useField;
private final Integer maxExpansions;
public Regexp(String pattern, int flags, String useField, Integer maxExpansions) {
this.pattern = pattern;
this.flags = flags;
this.useField = useField;
this.maxExpansions = (maxExpansions != null && maxExpansions > 0) ? maxExpansions : null;
}
public Regexp(StreamInput in) throws IOException {
this.pattern = in.readString();
this.flags = in.readVInt();
this.useField = in.readOptionalString();
this.maxExpansions = in.readOptionalVInt();
}
@Override
public IntervalsSource getSource(QueryShardContext context, MappedFieldType fieldType) {
final org.apache.lucene.util.automaton.RegExp regexp = new org.apache.lucene.util.automaton.RegExp(pattern, flags);
final CompiledAutomaton automaton = new CompiledAutomaton(regexp.toAutomaton());
if (useField != null) {
fieldType = context.fieldMapper(useField);
assert fieldType != null;
checkPositions(fieldType);
IntervalsSource regexpSource = maxExpansions == null
? Intervals.multiterm(automaton, regexp.toString())
: Intervals.multiterm(automaton, maxExpansions, regexp.toString());
return Intervals.fixField(useField, regexpSource);
} else {
checkPositions(fieldType);
return maxExpansions == null
? Intervals.multiterm(automaton, regexp.toString())
: Intervals.multiterm(automaton, maxExpansions, regexp.toString());
}
}
private void checkPositions(MappedFieldType type) {
if (type.getTextSearchInfo().hasPositions() == false) {
throw new IllegalArgumentException("Cannot create intervals over field [" + type.name() + "] with no positions indexed");
}
}
@Override
public void extractFields(Set<String> fields) {
if (useField != null) {
fields.add(useField);
}
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Regexp regexp = (Regexp) o;
return Objects.equals(pattern, regexp.pattern)
&& Objects.equals(flags, regexp.flags)
&& Objects.equals(useField, regexp.useField)
&& Objects.equals(maxExpansions, regexp.maxExpansions);
}
@Override
public int hashCode() {
return Objects.hash(pattern, flags, useField, maxExpansions);
}
@Override
public String getWriteableName() {
return NAME;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(pattern);
out.writeVInt(flags);
out.writeOptionalString(useField);
out.writeOptionalVInt(maxExpansions);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(NAME);
builder.field("pattern", pattern);
if (flags != DEFAULT_FLAGS_VALUE) {
builder.field("flags_value", flags);
}
if (useField != null) {
builder.field("use_field", useField);
}
if (maxExpansions != null) {
builder.field("max_expansions", maxExpansions);
}
builder.endObject();
return builder;
}
private static final ConstructingObjectParser<Regexp, Void> PARSER = new ConstructingObjectParser<>(NAME, args -> {
String pattern = (String) args[0];
String flags = (String) args[1];
Integer flagsValue = (Integer) args[2];
String useField = (String) args[3];
Integer maxExpansions = (Integer) args[4];
if (flagsValue != null) {
return new Regexp(pattern, flagsValue, useField, maxExpansions);
} else if (flags != null) {
return new Regexp(pattern, RegexpFlag.resolveValue(flags), useField, maxExpansions);
} else {
return new Regexp(pattern, DEFAULT_FLAGS_VALUE, useField, maxExpansions);
}
});
static {
PARSER.declareString(constructorArg(), new ParseField("pattern"));
PARSER.declareString(optionalConstructorArg(), new ParseField("flags"));
PARSER.declareInt(optionalConstructorArg(), new ParseField("flags_value"));
PARSER.declareString(optionalConstructorArg(), new ParseField("use_field"));
PARSER.declareInt(optionalConstructorArg(), new ParseField("max_expansions"));
}
public static Regexp fromXContent(XContentParser parser) throws IOException {
return PARSER.parse(parser, null);
}
String getPattern() {
return pattern;
}
int getFlags() {
return flags;
}
String getUseField() {
return useField;
}
Integer getMaxExpansions() {
return maxExpansions;
}
}
public static class Wildcard extends IntervalsSourceProvider { public static class Wildcard extends IntervalsSourceProvider {
public static final String NAME = "wildcard"; public static final String NAME = "wildcard";

View File

@ -1254,6 +1254,11 @@ public class SearchModule {
IntervalsSourceProvider.Wildcard.NAME, IntervalsSourceProvider.Wildcard.NAME,
IntervalsSourceProvider.Wildcard::new IntervalsSourceProvider.Wildcard::new
), ),
new NamedWriteableRegistry.Entry(
IntervalsSourceProvider.class,
IntervalsSourceProvider.Regexp.NAME,
IntervalsSourceProvider.Regexp::new
),
new NamedWriteableRegistry.Entry( new NamedWriteableRegistry.Entry(
IntervalsSourceProvider.class, IntervalsSourceProvider.class,
IntervalsSourceProvider.Fuzzy.NAME, IntervalsSourceProvider.Fuzzy.NAME,

View File

@ -42,6 +42,8 @@ import org.apache.lucene.search.FuzzyQuery;
import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query; import org.apache.lucene.search.Query;
import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.automaton.CompiledAutomaton;
import org.apache.lucene.util.automaton.RegExp;
import org.opensearch.common.ParsingException; import org.opensearch.common.ParsingException;
import org.opensearch.common.Strings; import org.opensearch.common.Strings;
import org.opensearch.common.compress.CompressedXContent; import org.opensearch.common.compress.CompressedXContent;
@ -686,6 +688,114 @@ public class IntervalQueryBuilderTests extends AbstractQueryTestCase<IntervalQue
}); });
} }
private static IntervalsSource buildRegexpSource(String pattern, int flags, Integer maxExpansions) {
final RegExp regexp = new RegExp(pattern, flags);
CompiledAutomaton automaton = new CompiledAutomaton(regexp.toAutomaton());
if (maxExpansions != null) {
return Intervals.multiterm(automaton, maxExpansions, regexp.toString());
} else {
return Intervals.multiterm(automaton, regexp.toString());
}
}
public void testRegexp() throws IOException {
final int DEFAULT_FLAGS = RegexpFlag.ALL.value();
String json = "{ \"intervals\" : { \"" + TEXT_FIELD_NAME + "\": { " + "\"regexp\" : { \"pattern\" : \"te.m\" } } } }";
IntervalQueryBuilder builder = (IntervalQueryBuilder) parseQuery(json);
Query expected = new IntervalQuery(TEXT_FIELD_NAME, buildRegexpSource("te.m", DEFAULT_FLAGS, null));
assertEquals(expected, builder.toQuery(createShardContext()));
String no_positions_json = "{ \"intervals\" : { \""
+ NO_POSITIONS_FIELD
+ "\": { "
+ "\"regexp\" : { \"pattern\" : \"[Tt]erm\" } } } }";
expectThrows(IllegalArgumentException.class, () -> {
IntervalQueryBuilder builder1 = (IntervalQueryBuilder) parseQuery(no_positions_json);
builder1.toQuery(createShardContext());
});
String fixed_field_json = "{ \"intervals\" : { \""
+ TEXT_FIELD_NAME
+ "\": { "
+ "\"regexp\" : { \"pattern\" : \"te.m\", \"use_field\" : \"masked_field\" } } } }";
builder = (IntervalQueryBuilder) parseQuery(fixed_field_json);
expected = new IntervalQuery(TEXT_FIELD_NAME, Intervals.fixField(MASKED_FIELD, buildRegexpSource("te.m", DEFAULT_FLAGS, null)));
assertEquals(expected, builder.toQuery(createShardContext()));
String fixed_field_json_no_positions = "{ \"intervals\" : { \""
+ TEXT_FIELD_NAME
+ "\": { "
+ "\"regexp\" : { \"pattern\" : \"te.m\", \"use_field\" : \""
+ NO_POSITIONS_FIELD
+ "\" } } } }";
expectThrows(IllegalArgumentException.class, () -> {
IntervalQueryBuilder builder1 = (IntervalQueryBuilder) parseQuery(fixed_field_json_no_positions);
builder1.toQuery(createShardContext());
});
String flags_json = "{ \"intervals\" : { \""
+ TEXT_FIELD_NAME
+ "\": { "
+ "\"regexp\" : { \"pattern\" : \"te.m\", \"flags\" : \"NONE\" } } } }";
builder = (IntervalQueryBuilder) parseQuery(flags_json);
expected = new IntervalQuery(TEXT_FIELD_NAME, buildRegexpSource("te.m", RegexpFlag.NONE.value(), null));
assertEquals(expected, builder.toQuery(createShardContext()));
String flags_value_json = "{ \"intervals\" : { \""
+ TEXT_FIELD_NAME
+ "\": { "
+ "\"regexp\" : { \"pattern\" : \"te.m\", \"flags_value\" : \""
+ RegexpFlag.ANYSTRING.value()
+ "\" } } } }";
builder = (IntervalQueryBuilder) parseQuery(flags_value_json);
expected = new IntervalQuery(TEXT_FIELD_NAME, buildRegexpSource("te.m", RegexpFlag.ANYSTRING.value(), null));
assertEquals(expected, builder.toQuery(createShardContext()));
String regexp_max_expand_json = "{ \"intervals\" : { \""
+ TEXT_FIELD_NAME
+ "\": { "
+ "\"regexp\" : { \"pattern\" : \"te.m\", \"max_expansions\" : 500 } } } }";
builder = (IntervalQueryBuilder) parseQuery(regexp_max_expand_json);
expected = new IntervalQuery(TEXT_FIELD_NAME, buildRegexpSource("te.m", DEFAULT_FLAGS, 500));
assertEquals(expected, builder.toQuery(createShardContext()));
String regexp_neg_max_expand_json = "{ \"intervals\" : { \""
+ TEXT_FIELD_NAME
+ "\": { "
+ "\"regexp\" : { \"pattern\" : \"te.m\", \"max_expansions\" : -20 } } } }";
builder = (IntervalQueryBuilder) parseQuery(regexp_neg_max_expand_json);
// max expansions use default
expected = new IntervalQuery(TEXT_FIELD_NAME, buildRegexpSource("te.m", DEFAULT_FLAGS, null));
assertEquals(expected, builder.toQuery(createShardContext()));
String regexp_over_max_expand_json = "{ \"intervals\" : { \""
+ TEXT_FIELD_NAME
+ "\": { "
+ "\"regexp\" : { \"pattern\" : \"te.m\", \"max_expansions\" : "
+ (BooleanQuery.getMaxClauseCount() + 1)
+ " } } } }";
expectThrows(IllegalArgumentException.class, () -> {
IntervalQueryBuilder builder1 = (IntervalQueryBuilder) parseQuery(regexp_over_max_expand_json);
builder1.toQuery(createShardContext());
});
String regexp_max_expand_with_flags_json = "{ \"intervals\" : { \""
+ TEXT_FIELD_NAME
+ "\": { "
+ "\"regexp\" : { \"pattern\" : \"te.m\", \"flags\": \"NONE\", \"max_expansions\" : 500 } } } }";
builder = (IntervalQueryBuilder) parseQuery(regexp_max_expand_with_flags_json);
expected = new IntervalQuery(TEXT_FIELD_NAME, buildRegexpSource("te.m", RegexpFlag.NONE.value(), 500));
assertEquals(expected, builder.toQuery(createShardContext()));
}
private static IntervalsSource buildFuzzySource(String term, String label, int prefixLength, boolean transpositions, int editDistance) { private static IntervalsSource buildFuzzySource(String term, String label, int prefixLength, boolean transpositions, int editDistance) {
FuzzyQuery fq = new FuzzyQuery(new Term("field", term), editDistance, prefixLength, 128, transpositions); FuzzyQuery fq = new FuzzyQuery(new Term("field", term), editDistance, prefixLength, 128, transpositions);
return Intervals.multiterm(fq.getAutomata(), label); return Intervals.multiterm(fq.getAutomata(), label);

View File

@ -0,0 +1,79 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.index.query;
import static org.opensearch.index.query.IntervalsSourceProvider.Regexp;
import static org.opensearch.index.query.IntervalsSourceProvider.fromXContent;
import org.opensearch.common.io.stream.Writeable;
import org.opensearch.common.xcontent.XContentParser;
import org.opensearch.test.AbstractSerializingTestCase;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
public class RegexpIntervalsSourceProviderTests extends AbstractSerializingTestCase<Regexp> {
private static final List<String> FLAGS = Arrays.asList("INTERSECTION", "COMPLEMENT", "EMPTY", "ANYSTRING", "INTERVAL", "NONE");
@Override
protected Regexp createTestInstance() {
return createRandomRegexp();
}
static Regexp createRandomRegexp() {
return new Regexp(
randomAlphaOfLengthBetween(0, 3) + (randomBoolean() ? ".*?" : "." + randomAlphaOfLength(4)) + randomAlphaOfLengthBetween(0, 5),
randomBoolean() ? RegexpFlag.resolveValue(randomFrom(FLAGS)) : RegexpFlag.ALL.value(),
randomBoolean() ? randomAlphaOfLength(10) : null,
randomBoolean() ? randomIntBetween(-1, Integer.MAX_VALUE) : null
);
}
@Override
protected Regexp mutateInstance(Regexp instance) throws IOException {
String pattern = instance.getPattern();
int flags = instance.getFlags();
String useField = instance.getUseField();
Integer maxExpansions = instance.getMaxExpansions();
int ran = between(0, 3);
switch (ran) {
case 0:
pattern += randomBoolean() ? ".*?" : randomAlphaOfLength(5);
break;
case 1:
flags = (flags == RegexpFlag.ALL.value()) ? RegexpFlag.resolveValue(randomFrom(FLAGS)) : RegexpFlag.ALL.value();
break;
case 2:
useField = useField == null ? randomAlphaOfLength(5) : null;
break;
case 3:
maxExpansions = maxExpansions == null ? randomIntBetween(1, Integer.MAX_VALUE) : null;
break;
default:
throw new AssertionError("Illegal randomisation branch");
}
return new Regexp(pattern, flags, useField, maxExpansions);
}
@Override
protected Writeable.Reader<Regexp> instanceReader() {
return Regexp::new;
}
@Override
protected Regexp doParseInstance(XContentParser parser) throws IOException {
if (parser.nextToken() == XContentParser.Token.START_OBJECT) {
parser.nextToken();
}
Regexp regexp = (Regexp) fromXContent(parser);
assertEquals(XContentParser.Token.END_OBJECT, parser.nextToken());
return regexp;
}
}