Refactor percolator's QueryAnalyzer to use QueryVisitors

Lucene now allows us to explore the structure of a query using QueryVisitors,
delegating the knowledge of how to recurse through and collect terms to the
query implementations themselves. The percolator currently has a home-grown
external version of this API to construct sets of matching terms that must be
present in a document in order for it to possibly match the query.

This commit removes the home-grown implementation in favour of one using
QueryVisitor. This has the added benefit of making interval queries available
for percolator pre-filtering. Due to a bug in multi-term intervals (LUCENE-9050)
it also includes a clone of some of the lucene intervals logic, that can be removed
once upstream has been fixed.

Closes #45639
This commit is contained in:
Alan Woodward 2019-11-20 09:21:01 +00:00 committed by GitHub
parent 9c0ec7ce23
commit c6b31162ba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 1134 additions and 446 deletions

View File

@ -446,10 +446,9 @@ public class PercolatorFieldMapper extends FieldMapper {
ParseContext.Document doc = context.doc();
FieldType pft = (FieldType) this.fieldType();
QueryAnalyzer.Result result;
try {
Version indexVersion = context.mapperService().getIndexSettings().getIndexVersionCreated();
result = QueryAnalyzer.analyze(query, indexVersion);
} catch (QueryAnalyzer.UnsupportedQueryException e) {
Version indexVersion = context.mapperService().getIndexSettings().getIndexVersionCreated();
result = QueryAnalyzer.analyze(query, indexVersion);
if (result == QueryAnalyzer.Result.UNKNOWN) {
doc.add(new Field(pft.extractionResultField.name(), EXTRACTION_FAILED, extractionResultField.fieldType()));
return;
}

View File

@ -19,84 +19,41 @@
package org.elasticsearch.percolator;
import org.apache.lucene.document.BinaryRange;
import org.apache.lucene.index.PrefixCodedTerms;
import org.apache.lucene.index.Term;
import org.apache.lucene.queries.BlendedTermQuery;
import org.apache.lucene.queries.CommonTermsQuery;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanClause.Occur;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.BoostQuery;
import org.apache.lucene.search.ConstantScoreQuery;
import org.apache.lucene.search.DisjunctionMaxQuery;
import org.apache.lucene.search.IndexOrDocValuesQuery;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.MultiPhraseQuery;
import org.apache.lucene.search.PhraseQuery;
import org.apache.lucene.search.PointRangeQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.SynonymQuery;
import org.apache.lucene.search.TermInSetQuery;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.spans.SpanFirstQuery;
import org.apache.lucene.search.spans.SpanNearQuery;
import org.apache.lucene.search.spans.SpanNotQuery;
import org.apache.lucene.search.spans.SpanOrQuery;
import org.apache.lucene.search.spans.SpanQuery;
import org.apache.lucene.search.spans.SpanTermQuery;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.NumericUtils;
import org.elasticsearch.Version;
import org.elasticsearch.common.logging.LoggerMessageFormat;
import org.elasticsearch.common.lucene.search.function.FunctionScoreQuery;
import org.elasticsearch.index.search.ESToParentBlockJoinQuery;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
import static java.util.stream.Collectors.toSet;
final class QueryAnalyzer {
private static final Map<Class<? extends Query>, BiFunction<Query, Version, Result>> queryProcessors;
static {
Map<Class<? extends Query>, BiFunction<Query, Version, Result>> map = new HashMap<>();
map.put(MatchNoDocsQuery.class, matchNoDocsQuery());
map.put(MatchAllDocsQuery.class, matchAllDocsQuery());
map.put(ConstantScoreQuery.class, constantScoreQuery());
map.put(BoostQuery.class, boostQuery());
map.put(TermQuery.class, termQuery());
map.put(TermInSetQuery.class, termInSetQuery());
map.put(CommonTermsQuery.class, commonTermsQuery());
map.put(BlendedTermQuery.class, blendedTermQuery());
map.put(PhraseQuery.class, phraseQuery());
map.put(MultiPhraseQuery.class, multiPhraseQuery());
map.put(SpanTermQuery.class, spanTermQuery());
map.put(SpanNearQuery.class, spanNearQuery());
map.put(SpanOrQuery.class, spanOrQuery());
map.put(SpanFirstQuery.class, spanFirstQuery());
map.put(SpanNotQuery.class, spanNotQuery());
map.put(BooleanQuery.class, booleanQuery());
map.put(DisjunctionMaxQuery.class, disjunctionMaxQuery());
map.put(SynonymQuery.class, synonymQuery());
map.put(FunctionScoreQuery.class, functionScoreQuery());
map.put(PointRangeQuery.class, pointRangeQuery());
map.put(IndexOrDocValuesQuery.class, indexOrDocValuesQuery());
map.put(ESToParentBlockJoinQuery.class, toParentBlockJoinQuery());
queryProcessors = Collections.unmodifiableMap(map);
}
private QueryAnalyzer() {
}
@ -128,299 +85,155 @@ final class QueryAnalyzer {
* @param indexVersion The create version of the index containing the percolator queries.
*/
static Result analyze(Query query, Version indexVersion) {
Class<?> queryClass = query.getClass();
if (queryClass.isAnonymousClass()) {
// Sometimes queries have anonymous classes in that case we need the direct super class.
// (for example blended term query)
queryClass = queryClass.getSuperclass();
ResultBuilder builder = new ResultBuilder(indexVersion, false);
query.visit(builder);
return builder.getResult();
}
private static final Set<Class<?>> verifiedQueries = new HashSet<>(Arrays.asList(
TermQuery.class, TermInSetQuery.class, SynonymQuery.class, SpanTermQuery.class, SpanOrQuery.class,
BooleanQuery.class, DisjunctionMaxQuery.class, ConstantScoreQuery.class, BoostQuery.class,
BlendedTermQuery.class
));
private static boolean isVerified(Query query) {
if (query instanceof FunctionScoreQuery) {
return ((FunctionScoreQuery)query).getMinScore() == null;
}
BiFunction<Query, Version, Result> queryProcessor = queryProcessors.get(queryClass);
if (queryProcessor != null) {
return queryProcessor.apply(query, indexVersion);
} else {
throw new UnsupportedQueryException(query);
for (Class<?> cls : verifiedQueries) {
if (cls.isAssignableFrom(query.getClass())) {
return true;
}
}
return false;
}
private static BiFunction<Query, Version, Result> matchNoDocsQuery() {
return (query, version) -> new Result(true, Collections.emptySet(), 0);
}
private static class ResultBuilder extends QueryVisitor {
private static BiFunction<Query, Version, Result> matchAllDocsQuery() {
return (query, version) -> new Result(true, true);
}
final boolean conjunction;
final Version version;
List<ResultBuilder> children = new ArrayList<>();
boolean verified = true;
int minimumShouldMatch = 0;
List<Result> terms = new ArrayList<>();
private static BiFunction<Query, Version, Result> constantScoreQuery() {
return (query, boosts) -> {
Query wrappedQuery = ((ConstantScoreQuery) query).getQuery();
return analyze(wrappedQuery, boosts);
};
}
private ResultBuilder(Version version, boolean conjunction) {
this.conjunction = conjunction;
this.version = version;
}
private static BiFunction<Query, Version, Result> boostQuery() {
return (query, version) -> {
Query wrappedQuery = ((BoostQuery) query).getQuery();
return analyze(wrappedQuery, version);
};
}
@Override
public String toString() {
return (conjunction ? "CONJ" : "DISJ") + children + terms + "~" + minimumShouldMatch;
}
private static BiFunction<Query, Version, Result> termQuery() {
return (query, version) -> {
TermQuery termQuery = (TermQuery) query;
return new Result(true, Collections.singleton(new QueryExtraction(termQuery.getTerm())), 1);
};
}
private static BiFunction<Query, Version, Result> termInSetQuery() {
return (query, version) -> {
TermInSetQuery termInSetQuery = (TermInSetQuery) query;
Set<QueryExtraction> terms = new HashSet<>();
PrefixCodedTerms.TermIterator iterator = termInSetQuery.getTermData().iterator();
for (BytesRef term = iterator.next(); term != null; term = iterator.next()) {
terms.add(new QueryExtraction(new Term(iterator.field(), term)));
Result getResult() {
List<Result> partialResults = new ArrayList<>();
if (terms.size() > 0) {
partialResults.add(conjunction ? handleConjunction(terms, version) :
handleDisjunction(terms, minimumShouldMatch, version));
}
return new Result(true, terms, Math.min(1, terms.size()));
};
}
private static BiFunction<Query, Version, Result> synonymQuery() {
return (query, version) -> {
Set<QueryExtraction> terms = ((SynonymQuery) query).getTerms().stream().map(QueryExtraction::new).collect(toSet());
return new Result(true, terms, Math.min(1, terms.size()));
};
}
private static BiFunction<Query, Version, Result> commonTermsQuery() {
return (query, version) -> {
Set<QueryExtraction> terms = ((CommonTermsQuery) query).getTerms().stream().map(QueryExtraction::new).collect(toSet());
return new Result(false, terms, Math.min(1, terms.size()));
};
}
private static BiFunction<Query, Version, Result> blendedTermQuery() {
return (query, version) -> {
Set<QueryExtraction> terms = ((BlendedTermQuery) query).getTerms().stream().map(QueryExtraction::new).collect(toSet());
return new Result(true, terms, Math.min(1, terms.size()));
};
}
private static BiFunction<Query, Version, Result> phraseQuery() {
return (query, version) -> {
Term[] terms = ((PhraseQuery) query).getTerms();
if (terms.length == 0) {
return new Result(true, Collections.emptySet(), 0);
if (children.isEmpty() == false) {
List<Result> childResults = children.stream().map(ResultBuilder::getResult).collect(Collectors.toList());
partialResults.addAll(childResults);
}
if (version.onOrAfter(Version.V_6_1_0)) {
Set<QueryExtraction> extractions = Arrays.stream(terms).map(QueryExtraction::new).collect(toSet());
return new Result(false, extractions, extractions.size());
} else {
// the longest term is likely to be the rarest,
// so from a performance perspective it makes sense to extract that
Term longestTerm = terms[0];
for (Term term : terms) {
if (longestTerm.bytes().length < term.bytes().length) {
longestTerm = term;
}
}
return new Result(false, Collections.singleton(new QueryExtraction(longestTerm)), 1);
if (partialResults.isEmpty()) {
return verified ? Result.MATCH_NONE : Result.UNKNOWN;
}
};
}
private static BiFunction<Query, Version, Result> multiPhraseQuery() {
return (query, version) -> {
Term[][] terms = ((MultiPhraseQuery) query).getTermArrays();
if (terms.length == 0) {
return new Result(true, Collections.emptySet(), 0);
}
// This query has the same problem as boolean queries when it comes to duplicated terms
// So to keep things simple, we just rewrite to a boolean query
BooleanQuery.Builder builder = new BooleanQuery.Builder();
for (Term[] termArr : terms) {
BooleanQuery.Builder subBuilder = new BooleanQuery.Builder();
for (Term term : termArr) {
subBuilder.add(new TermQuery(term), Occur.SHOULD);
}
builder.add(subBuilder.build(), Occur.FILTER);
}
// Make sure to unverify the result
return booleanQuery().apply(builder.build(), version).unverify();
};
}
private static BiFunction<Query, Version, Result> spanTermQuery() {
return (query, version) -> {
Term term = ((SpanTermQuery) query).getTerm();
return new Result(true, Collections.singleton(new QueryExtraction(term)), 1);
};
}
private static BiFunction<Query, Version, Result> spanNearQuery() {
return (query, version) -> {
SpanNearQuery spanNearQuery = (SpanNearQuery) query;
if (version.onOrAfter(Version.V_6_1_0)) {
// This has the same problem as boolean queries when it comes to duplicated clauses
// so we rewrite to a boolean query to keep things simple.
BooleanQuery.Builder builder = new BooleanQuery.Builder();
for (SpanQuery clause : spanNearQuery.getClauses()) {
builder.add(clause, Occur.FILTER);
}
// make sure to unverify the result
return booleanQuery().apply(builder.build(), version).unverify();
} else {
Result bestClause = null;
for (SpanQuery clause : spanNearQuery.getClauses()) {
Result temp = analyze(clause, version);
bestClause = selectBestResult(temp, bestClause);
}
return bestClause;
}
};
}
private static BiFunction<Query, Version, Result> spanOrQuery() {
return (query, version) -> {
SpanOrQuery spanOrQuery = (SpanOrQuery) query;
// handle it like a boolean query to not dulplicate eg. logic
// about duplicated terms
BooleanQuery.Builder builder = new BooleanQuery.Builder();
for (SpanQuery clause : spanOrQuery.getClauses()) {
builder.add(clause, Occur.SHOULD);
}
return booleanQuery().apply(builder.build(), version);
};
}
private static BiFunction<Query, Version, Result> spanNotQuery() {
return (query, version) -> {
Result result = analyze(((SpanNotQuery) query).getInclude(), version);
return new Result(false, result.extractions, result.minimumShouldMatch);
};
}
private static BiFunction<Query, Version, Result> spanFirstQuery() {
return (query, version) -> {
Result result = analyze(((SpanFirstQuery) query).getMatch(), version);
return new Result(false, result.extractions, result.minimumShouldMatch);
};
}
private static BiFunction<Query, Version, Result> booleanQuery() {
return (query, version) -> {
BooleanQuery bq = (BooleanQuery) query;
int minimumShouldMatch = bq.getMinimumNumberShouldMatch();
List<Query> requiredClauses = new ArrayList<>();
List<Query> optionalClauses = new ArrayList<>();
boolean hasProhibitedClauses = false;
for (BooleanClause clause : bq.clauses()) {
if (clause.isRequired()) {
requiredClauses.add(clause.getQuery());
} else if (clause.isProhibited()) {
hasProhibitedClauses = true;
} else {
assert clause.getOccur() == Occur.SHOULD;
optionalClauses.add(clause.getQuery());
}
}
if (minimumShouldMatch > optionalClauses.size()
|| (requiredClauses.isEmpty() && optionalClauses.isEmpty())) {
return new Result(false, Collections.emptySet(), 0);
}
if (requiredClauses.size() > 0) {
if (minimumShouldMatch > 0) {
// mix of required clauses and required optional clauses, we turn it into
// a pure conjunction by moving the optional clauses to a sub query to
// simplify logic
BooleanQuery.Builder minShouldMatchQuery = new BooleanQuery.Builder();
minShouldMatchQuery.setMinimumNumberShouldMatch(minimumShouldMatch);
for (Query q : optionalClauses) {
minShouldMatchQuery.add(q, Occur.SHOULD);
}
requiredClauses.add(minShouldMatchQuery.build());
optionalClauses.clear();
minimumShouldMatch = 0;
} else {
optionalClauses.clear(); // only matter for scoring, not matching
}
}
// Now we now have either a pure conjunction or a pure disjunction, with at least one clause
Result result;
if (requiredClauses.size() > 0) {
assert optionalClauses.isEmpty();
assert minimumShouldMatch == 0;
result = handleConjunctionQuery(requiredClauses, version);
if (partialResults.size() == 1) {
result = partialResults.get(0);
} else {
assert requiredClauses.isEmpty();
if (minimumShouldMatch == 0) {
// Lucene always requires one matching clause for disjunctions
minimumShouldMatch = 1;
}
result = handleDisjunctionQuery(optionalClauses, minimumShouldMatch, version);
result = conjunction ? handleConjunction(partialResults, version)
: handleDisjunction(partialResults, minimumShouldMatch, version);
}
if (hasProhibitedClauses) {
if (verified == false) {
result = result.unverify();
}
return result;
};
}
@Override
public QueryVisitor getSubVisitor(Occur occur, Query parent) {
this.verified = isVerified(parent);
if (occur == Occur.MUST || occur == Occur.FILTER) {
ResultBuilder builder = new ResultBuilder(version, true);
children.add(builder);
return builder;
}
if (occur == Occur.MUST_NOT) {
this.verified = false;
return QueryVisitor.EMPTY_VISITOR;
}
int minimumShouldMatch = 0;
if (parent instanceof BooleanQuery) {
BooleanQuery bq = (BooleanQuery) parent;
if (bq.getMinimumNumberShouldMatch() == 0
&& bq.clauses().stream().anyMatch(c -> c.getOccur() == Occur.MUST || c.getOccur() == Occur.FILTER)) {
return QueryVisitor.EMPTY_VISITOR;
}
minimumShouldMatch = bq.getMinimumNumberShouldMatch();
}
ResultBuilder child = new ResultBuilder(version, false);
child.minimumShouldMatch = minimumShouldMatch;
children.add(child);
return child;
}
@Override
public void visitLeaf(Query query) {
if (query instanceof MatchAllDocsQuery) {
terms.add(new Result(true, true));
}
else if (query instanceof MatchNoDocsQuery) {
terms.add(Result.MATCH_NONE);
}
else if (query instanceof PointRangeQuery) {
terms.add(pointRangeQuery((PointRangeQuery)query));
}
else {
terms.add(Result.UNKNOWN);
}
}
@Override
public void consumeTerms(Query query, Term... terms) {
boolean verified = isVerified(query);
Set<QueryExtraction> qe = Arrays.stream(terms).map(QueryExtraction::new).collect(Collectors.toSet());
if (qe.size() > 0) {
if (version.before(Version.V_6_1_0) && conjunction) {
Optional<QueryExtraction> longest = qe.stream()
.filter(q -> q.term != null)
.max(Comparator.comparingInt(q -> q.term.bytes().length));
if (longest.isPresent()) {
qe = Collections.singleton(longest.get());
}
}
this.terms.add(new Result(verified, qe, conjunction ? qe.size() : 1));
}
}
}
private static BiFunction<Query, Version, Result> disjunctionMaxQuery() {
return (query, version) -> {
List<Query> disjuncts = ((DisjunctionMaxQuery) query).getDisjuncts();
if (disjuncts.isEmpty()) {
return new Result(false, Collections.emptySet(), 0);
} else {
return handleDisjunctionQuery(disjuncts, 1, version);
}
};
}
private static Result pointRangeQuery(PointRangeQuery query) {
if (query.getNumDims() != 1) {
return Result.UNKNOWN;
}
private static BiFunction<Query, Version, Result> functionScoreQuery() {
return (query, version) -> {
FunctionScoreQuery functionScoreQuery = (FunctionScoreQuery) query;
Result result = analyze(functionScoreQuery.getSubQuery(), version);
byte[] lowerPoint = query.getLowerPoint();
byte[] upperPoint = query.getUpperPoint();
// If min_score is specified we can't guarantee upfront that this percolator query matches,
// so in that case we set verified to false.
// (if it matches with the percolator document matches with the extracted terms.
// Min score filters out docs, which is different than the functions, which just influences the score.)
boolean verified = result.verified && functionScoreQuery.getMinScore() == null;
if (result.matchAllDocs) {
return new Result(result.matchAllDocs, verified);
} else {
return new Result(verified, result.extractions, result.minimumShouldMatch);
}
};
}
// Need to check whether upper is not smaller than lower, otherwise NumericUtils.subtract(...) fails IAE
// If upper is really smaller than lower then we deal with like MatchNoDocsQuery. (verified and no extractions)
if (new BytesRef(lowerPoint).compareTo(new BytesRef(upperPoint)) > 0) {
return new Result(true, Collections.emptySet(), 0);
}
private static BiFunction<Query, Version, Result> pointRangeQuery() {
return (query, version) -> {
PointRangeQuery pointRangeQuery = (PointRangeQuery) query;
if (pointRangeQuery.getNumDims() != 1) {
throw new UnsupportedQueryException(query);
}
byte[] lowerPoint = pointRangeQuery.getLowerPoint();
byte[] upperPoint = pointRangeQuery.getUpperPoint();
// Need to check whether upper is not smaller than lower, otherwise NumericUtils.subtract(...) fails IAE
// If upper is really smaller than lower then we deal with like MatchNoDocsQuery. (verified and no extractions)
if (new BytesRef(lowerPoint).compareTo(new BytesRef(upperPoint)) > 0) {
return new Result(true, Collections.emptySet(), 0);
}
byte[] interval = new byte[16];
NumericUtils.subtract(16, 0, prepad(upperPoint), prepad(lowerPoint), interval);
return new Result(false, Collections.singleton(new QueryExtraction(
new Range(pointRangeQuery.getField(), lowerPoint, upperPoint, interval))), 1);
};
byte[] interval = new byte[16];
NumericUtils.subtract(16, 0, prepad(upperPoint), prepad(lowerPoint), interval);
return new Result(false, Collections.singleton(new QueryExtraction(
new Range(query.getField(), lowerPoint, upperPoint, interval))), 1);
}
private static byte[] prepad(byte[] original) {
@ -430,58 +243,16 @@ final class QueryAnalyzer {
return result;
}
private static BiFunction<Query, Version, Result> indexOrDocValuesQuery() {
return (query, version) -> {
IndexOrDocValuesQuery indexOrDocValuesQuery = (IndexOrDocValuesQuery) query;
return analyze(indexOrDocValuesQuery.getIndexQuery(), version);
};
}
private static BiFunction<Query, Version, Result> toParentBlockJoinQuery() {
return (query, version) -> {
ESToParentBlockJoinQuery toParentBlockJoinQuery = (ESToParentBlockJoinQuery) query;
Result result = analyze(toParentBlockJoinQuery.getChildQuery(), version);
return new Result(false, result.extractions, result.minimumShouldMatch);
};
}
private static Result handleConjunctionQuery(List<Query> conjunctions, Version version) {
UnsupportedQueryException uqe = null;
List<Result> results = new ArrayList<>(conjunctions.size());
boolean success = false;
for (Query query : conjunctions) {
try {
Result subResult = analyze(query, version);
if (subResult.isMatchNoDocs()) {
return subResult;
}
results.add(subResult);
success = true;
} catch (UnsupportedQueryException e) {
uqe = e;
}
}
if (success == false) {
// No clauses could be extracted
if (uqe != null) {
throw uqe;
} else {
// Empty conjunction
return new Result(true, Collections.emptySet(), 0);
}
}
Result result = handleConjunction(results, version);
if (uqe != null) {
result = result.unverify();
}
return result;
}
private static Result handleConjunction(List<Result> conjunctions, Version version) {
private static Result handleConjunction(List<Result> conjunctionsWithUnknowns, Version version) {
List<Result> conjunctions = conjunctionsWithUnknowns.stream().filter(r -> r.isUnknown() == false).collect(Collectors.toList());
if (conjunctions.isEmpty()) {
throw new IllegalArgumentException("Must have at least on conjunction sub result");
if (conjunctionsWithUnknowns.isEmpty()) {
throw new IllegalArgumentException("Must have at least on conjunction sub result");
}
return conjunctionsWithUnknowns.get(0); // all conjunctions are unknown, so just return the first one
}
if (conjunctionsWithUnknowns.size() == 1) {
return conjunctionsWithUnknowns.get(0);
}
if (version.onOrAfter(Version.V_6_1_0)) {
for (Result subResult : conjunctions) {
@ -490,7 +261,7 @@ final class QueryAnalyzer {
}
}
int msm = 0;
boolean verified = true;
boolean verified = conjunctionsWithUnknowns.size() == conjunctions.size();
boolean matchAllDocs = true;
boolean hasDuplicateTerms = false;
Set<QueryExtraction> extractions = new HashSet<>();
@ -523,21 +294,19 @@ final class QueryAnalyzer {
resultMsm = 1;
} else {
resultMsm = 0;
verified = false;
break;
}
}
if (extractions.contains(queryExtraction)) {
resultMsm = 0;
resultMsm = Math.max(0, resultMsm - 1);
verified = false;
break;
}
}
msm += resultMsm;
if (result.verified == false
// If some inner extractions are optional, the result can't be verified
|| result.minimumShouldMatch < result.extractions.size()) {
// If some inner extractions are optional, the result can't be verified
|| result.minimumShouldMatch < result.extractions.size()) {
verified = false;
}
matchAllDocs &= result.matchAllDocs;
@ -548,6 +317,7 @@ final class QueryAnalyzer {
} else {
return new Result(verified, extractions, hasDuplicateTerms ? 1 : msm);
}
} else {
Result bestClause = null;
for (Result result : conjunctions) {
@ -557,17 +327,13 @@ final class QueryAnalyzer {
}
}
private static Result handleDisjunctionQuery(List<Query> disjunctions, int requiredShouldClauses, Version version) {
List<Result> subResults = new ArrayList<>();
for (Query query : disjunctions) {
// if either query fails extraction, we need to propagate as we could miss hits otherwise
Result subResult = analyze(query, version);
subResults.add(subResult);
}
return handleDisjunction(subResults, requiredShouldClauses, version);
}
private static Result handleDisjunction(List<Result> disjunctions, int requiredShouldClauses, Version version) {
if (disjunctions.stream().anyMatch(Result::isUnknown)) {
return Result.UNKNOWN;
}
if (disjunctions.size() == 1) {
return disjunctions.get(0);
}
// Keep track of the msm for each clause:
List<Integer> clauses = new ArrayList<>(disjunctions.size());
boolean verified;
@ -764,7 +530,7 @@ final class QueryAnalyzer {
final int minimumShouldMatch;
final boolean matchAllDocs;
private Result(boolean matchAllDocs, boolean verified, Set<QueryExtraction> extractions, int minimumShouldMatch) {
Result(boolean matchAllDocs, boolean verified, Set<QueryExtraction> extractions, int minimumShouldMatch) {
if (minimumShouldMatch > extractions.size()) {
throw new IllegalArgumentException("minimumShouldMatch can't be greater than the number of extractions: "
+ minimumShouldMatch + " > " + extractions.size());
@ -783,6 +549,11 @@ final class QueryAnalyzer {
this(matchAllDocs, verified, Collections.emptySet(), 0);
}
@Override
public String toString() {
return extractions.toString();
}
Result unverify() {
if (verified) {
return new Result(matchAllDocs, false, extractions, minimumShouldMatch);
@ -791,9 +562,37 @@ final class QueryAnalyzer {
}
}
boolean isUnknown() {
return false;
}
boolean isMatchNoDocs() {
return matchAllDocs == false && extractions.isEmpty();
}
static final Result UNKNOWN = new Result(false, false, Collections.emptySet(), 0){
@Override
boolean isUnknown() {
return true;
}
@Override
boolean isMatchNoDocs() {
return false;
}
@Override
public String toString() {
return "UNKNOWN";
}
};
static final Result MATCH_NONE = new Result(false, true, Collections.emptySet(), 0) {
@Override
boolean isMatchNoDocs() {
return true;
}
};
}
static class QueryExtraction {
@ -846,26 +645,6 @@ final class QueryAnalyzer {
}
}
/**
* Exception indicating that none or some query terms couldn't extracted from a percolator query.
*/
static class UnsupportedQueryException extends RuntimeException {
private final Query unsupportedQuery;
UnsupportedQueryException(Query unsupportedQuery) {
super(LoggerMessageFormat.format("no query terms can be extracted from query [{}]", unsupportedQuery));
this.unsupportedQuery = unsupportedQuery;
}
/**
* The actual Lucene query that was unsupported and caused this exception to be thrown.
*/
Query getUnsupportedQuery() {
return unsupportedQuery;
}
}
static class Range {
final String fieldName;

View File

@ -892,7 +892,7 @@ public class PercolatorFieldMapperTests extends ESSingleNodeTestCase {
assertThat(values.get(1), equalTo("field\0value2"));
assertThat(values.get(2), equalTo("field\0value3"));
int msm = doc.rootDoc().getFields(fieldType.minimumShouldMatchField.name())[0].numericValue().intValue();
assertThat(msm, equalTo(2));
assertThat(msm, equalTo(3));
qb = boolQuery()
.must(boolQuery().must(termQuery("field", "value1")).must(termQuery("field", "value2")))
@ -916,7 +916,7 @@ public class PercolatorFieldMapperTests extends ESSingleNodeTestCase {
assertThat(values.get(3), equalTo("field\0value4"));
assertThat(values.get(4), equalTo("field\0value5"));
msm = doc.rootDoc().getFields(fieldType.minimumShouldMatchField.name())[0].numericValue().intValue();
assertThat(msm, equalTo(2));
assertThat(msm, equalTo(4));
qb = boolQuery()
.minimumShouldMatch(3)

View File

@ -29,6 +29,10 @@ import org.apache.lucene.document.SortedNumericDocValuesField;
import org.apache.lucene.index.Term;
import org.apache.lucene.queries.BlendedTermQuery;
import org.apache.lucene.queries.CommonTermsQuery;
import org.apache.lucene.queries.XIntervals;
import org.apache.lucene.queries.intervals.IntervalQuery;
import org.apache.lucene.queries.intervals.Intervals;
import org.apache.lucene.queries.intervals.IntervalsSource;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.BoostQuery;
@ -73,7 +77,6 @@ import java.util.Set;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import static org.elasticsearch.percolator.QueryAnalyzer.UnsupportedQueryException;
import static org.elasticsearch.percolator.QueryAnalyzer.analyze;
import static org.elasticsearch.percolator.QueryAnalyzer.selectBestResult;
import static org.hamcrest.Matchers.equalTo;
@ -450,7 +453,7 @@ public class QueryAnalyzerTests extends ESTestCase {
builder.add(termQuery2, BooleanClause.Occur.SHOULD);
builder.add(termQuery3, BooleanClause.Occur.SHOULD);
result = analyze(builder.build(), Version.CURRENT);
assertThat("Minimum match has not impact on whether the result is verified", result.verified, is(true));
assertThat("Minimum match has no impact on whether the result is verified", result.verified, is(true));
assertThat("msm is at least two so result.minimumShouldMatch should 2 too", result.minimumShouldMatch, equalTo(msm));
builder = new BooleanQuery.Builder();
@ -827,18 +830,14 @@ public class QueryAnalyzerTests extends ESTestCase {
public void testExtractQueryMetadata_unsupportedQuery() {
TermRangeQuery termRangeQuery = new TermRangeQuery("_field", null, null, true, false);
UnsupportedQueryException e = expectThrows(UnsupportedQueryException.class,
() -> analyze(termRangeQuery, Version.CURRENT));
assertThat(e.getUnsupportedQuery(), sameInstance(termRangeQuery));
assertEquals(Result.UNKNOWN, analyze(termRangeQuery, Version.CURRENT));
TermQuery termQuery1 = new TermQuery(new Term("_field", "_term"));
BooleanQuery.Builder builder = new BooleanQuery.Builder();
builder.add(termQuery1, BooleanClause.Occur.SHOULD);
builder.add(termRangeQuery, BooleanClause.Occur.SHOULD);
BooleanQuery bq = builder.build();
e = expectThrows(UnsupportedQueryException.class, () -> analyze(bq, Version.CURRENT));
assertThat(e.getUnsupportedQuery(), sameInstance(termRangeQuery));
assertEquals(Result.UNKNOWN, analyze(bq, Version.CURRENT));
}
public void testExtractQueryMetadata_unsupportedQueryInBoolQueryWithMustClauses() {
@ -870,8 +869,7 @@ public class QueryAnalyzerTests extends ESTestCase {
builder.add(unsupportedQuery, BooleanClause.Occur.MUST);
builder.add(unsupportedQuery, BooleanClause.Occur.MUST);
BooleanQuery bq2 = builder.build();
UnsupportedQueryException e = expectThrows(UnsupportedQueryException.class, () -> analyze(bq2, Version.CURRENT));
assertThat(e.getUnsupportedQuery(), sameInstance(unsupportedQuery));
assertEquals(Result.UNKNOWN, analyze(bq2, Version.CURRENT));
}
public void testExtractQueryMetadata_disjunctionMaxQuery() {
@ -1173,10 +1171,10 @@ public class QueryAnalyzerTests extends ESTestCase {
public void testTooManyPointDimensions() {
// For now no extraction support for geo queries:
Query query1 = LatLonPoint.newBoxQuery("_field", 0, 1, 0, 1);
expectThrows(UnsupportedQueryException.class, () -> analyze(query1, Version.CURRENT));
assertEquals(Result.UNKNOWN, analyze(query1, Version.CURRENT));
Query query2 = LongPoint.newRangeQuery("_field", new long[]{0, 0, 0}, new long[]{1, 1, 1});
expectThrows(UnsupportedQueryException.class, () -> analyze(query2, Version.CURRENT));
assertEquals(Result.UNKNOWN, analyze(query2, Version.CURRENT));
}
public void testPointRangeQuery_lowerUpperReversed() {
@ -1338,7 +1336,7 @@ public class QueryAnalyzerTests extends ESTestCase {
Result result = analyze(builder.build(), Version.CURRENT);
assertThat(result.verified, is(false));
assertThat(result.matchAllDocs, is(false));
assertThat(result.minimumShouldMatch, equalTo(2));
assertThat(result.minimumShouldMatch, equalTo(4));
assertTermsEqual(result.extractions, new Term("field", "value1"), new Term("field", "value2"),
new Term("field", "value3"), new Term("field", "value4"));
@ -1375,16 +1373,10 @@ public class QueryAnalyzerTests extends ESTestCase {
public void testEmptyQueries() {
BooleanQuery.Builder builder = new BooleanQuery.Builder();
Result result = analyze(builder.build(), Version.CURRENT);
assertThat(result.verified, is(false));
assertThat(result.matchAllDocs, is(false));
assertThat(result.minimumShouldMatch, equalTo(0));
assertThat(result.extractions.size(), equalTo(0));
assertEquals(result, Result.MATCH_NONE);
result = analyze(new DisjunctionMaxQuery(Collections.emptyList(), 0f), Version.CURRENT);
assertThat(result.verified, is(false));
assertThat(result.matchAllDocs, is(false));
assertThat(result.minimumShouldMatch, equalTo(0));
assertThat(result.extractions.size(), equalTo(0));
assertEquals(result, Result.MATCH_NONE);
}
private static void assertDimension(byte[] expected, Consumer<byte[]> consumer) {
@ -1410,4 +1402,108 @@ public class QueryAnalyzerTests extends ESTestCase {
return queryExtractions;
}
public void testIntervalQueries() {
IntervalsSource source = Intervals.or(Intervals.term("term1"), Intervals.term("term2"));
Result result = analyze(new IntervalQuery("field", source), Version.CURRENT);
assertThat(result.verified, is(false));
assertThat(result.matchAllDocs, is(false));
assertThat(result.minimumShouldMatch, equalTo(1));
assertTermsEqual(result.extractions, new Term("field", "term1"), new Term("field", "term2"));
source = Intervals.ordered(Intervals.term("term1"), Intervals.term("term2"),
Intervals.or(Intervals.term("term3"), Intervals.term("term4")));
result = analyze(new IntervalQuery("field", source), Version.CURRENT);
assertThat(result.verified, is(false));
assertThat(result.matchAllDocs, is(false));
assertThat(result.minimumShouldMatch, equalTo(3));
assertTermsEqual(result.extractions, new Term("field", "term1"), new Term("field", "term2"),
new Term("field", "term3"), new Term("field", "term4"));
source = Intervals.ordered(Intervals.term("term1"), XIntervals.wildcard(new BytesRef("a*")));
result = analyze(new IntervalQuery("field", source), Version.CURRENT);
assertThat(result.verified, is(false));
assertThat(result.matchAllDocs, is(false));
assertThat(result.minimumShouldMatch, equalTo(1));
assertTermsEqual(result.extractions, new Term("field", "term1"));
source = Intervals.ordered(XIntervals.wildcard(new BytesRef("a*")));
result = analyze(new IntervalQuery("field", source), Version.CURRENT);
assertEquals(Result.UNKNOWN, result);
source = Intervals.or(Intervals.term("b"), XIntervals.wildcard(new BytesRef("a*")));
result = analyze(new IntervalQuery("field", source), Version.CURRENT);
assertEquals(Result.UNKNOWN, result);
source = Intervals.ordered(Intervals.term("term1"), XIntervals.prefix(new BytesRef("a")));
result = analyze(new IntervalQuery("field", source), Version.CURRENT);
assertThat(result.verified, is(false));
assertThat(result.matchAllDocs, is(false));
assertThat(result.minimumShouldMatch, equalTo(1));
assertTermsEqual(result.extractions, new Term("field", "term1"));
source = Intervals.ordered(XIntervals.prefix(new BytesRef("a")));
result = analyze(new IntervalQuery("field", source), Version.CURRENT);
assertEquals(Result.UNKNOWN, result);
source = Intervals.or(Intervals.term("b"), XIntervals.prefix(new BytesRef("a")));
result = analyze(new IntervalQuery("field", source), Version.CURRENT);
assertEquals(Result.UNKNOWN, result);
source = Intervals.containedBy(Intervals.term("a"), Intervals.ordered(Intervals.term("b"), Intervals.term("c")));
result = analyze(new IntervalQuery("field", source), Version.CURRENT);
assertThat(result.verified, is(false));
assertThat(result.matchAllDocs, is(false));
assertThat(result.minimumShouldMatch, equalTo(3));
assertTermsEqual(result.extractions, new Term("field", "a"), new Term("field", "b"), new Term("field", "c"));
source = Intervals.containing(Intervals.term("a"), Intervals.ordered(Intervals.term("b"), Intervals.term("c")));
result = analyze(new IntervalQuery("field", source), Version.CURRENT);
assertThat(result.verified, is(false));
assertThat(result.matchAllDocs, is(false));
assertThat(result.minimumShouldMatch, equalTo(3));
assertTermsEqual(result.extractions, new Term("field", "a"), new Term("field", "b"), new Term("field", "c"));
source = Intervals.overlapping(Intervals.term("a"), Intervals.ordered(Intervals.term("b"), Intervals.term("c")));
result = analyze(new IntervalQuery("field", source), Version.CURRENT);
assertThat(result.verified, is(false));
assertThat(result.matchAllDocs, is(false));
assertThat(result.minimumShouldMatch, equalTo(3));
assertTermsEqual(result.extractions, new Term("field", "a"), new Term("field", "b"), new Term("field", "c"));
source = Intervals.within(Intervals.term("a"), 2, Intervals.ordered(Intervals.term("b"), Intervals.term("c")));
result = analyze(new IntervalQuery("field", source), Version.CURRENT);
assertThat(result.verified, is(false));
assertThat(result.matchAllDocs, is(false));
assertThat(result.minimumShouldMatch, equalTo(3));
assertTermsEqual(result.extractions, new Term("field", "a"), new Term("field", "b"), new Term("field", "c"));
source = Intervals.notContainedBy(Intervals.term("a"), Intervals.ordered(Intervals.term("b"), Intervals.term("c")));
result = analyze(new IntervalQuery("field", source), Version.CURRENT);
assertThat(result.verified, is(false));
assertThat(result.matchAllDocs, is(false));
assertThat(result.minimumShouldMatch, equalTo(1));
assertTermsEqual(result.extractions, new Term("field", "a"));
source = Intervals.notContaining(Intervals.term("a"), Intervals.ordered(Intervals.term("b"), Intervals.term("c")));
result = analyze(new IntervalQuery("field", source), Version.CURRENT);
assertThat(result.verified, is(false));
assertThat(result.matchAllDocs, is(false));
assertThat(result.minimumShouldMatch, equalTo(1));
assertTermsEqual(result.extractions, new Term("field", "a"));
source = Intervals.nonOverlapping(Intervals.term("a"), Intervals.ordered(Intervals.term("b"), Intervals.term("c")));
result = analyze(new IntervalQuery("field", source), Version.CURRENT);
assertThat(result.verified, is(false));
assertThat(result.matchAllDocs, is(false));
assertThat(result.minimumShouldMatch, equalTo(1));
assertTermsEqual(result.extractions, new Term("field", "a"));
source = Intervals.notWithin(Intervals.term("a"), 2, Intervals.ordered(Intervals.term("b"), Intervals.term("c")));
result = analyze(new IntervalQuery("field", source), Version.CURRENT);
assertThat(result.verified, is(false));
assertThat(result.matchAllDocs, is(false));
assertThat(result.minimumShouldMatch, equalTo(1));
assertTermsEqual(result.extractions, new Term("field", "a"));
}
}

View File

@ -30,6 +30,7 @@ import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.BoostQuery;
import org.apache.lucene.search.DisjunctionMaxQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.InPlaceMergeSorter;
@ -39,6 +40,8 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
/**
* BlendedTermQuery can be used to unify term statistics across
@ -245,6 +248,17 @@ public abstract class BlendedTermQuery extends Query {
return builder.toString();
}
@Override
public void visit(QueryVisitor visitor) {
Set<String> fields = Arrays.stream(terms).map(Term::field).collect(Collectors.toSet());
for (String field : fields) {
if (visitor.acceptField(field) == false) {
return;
}
}
visitor.getSubVisitor(BooleanClause.Occur.SHOULD, this).consumeTerms(this, terms);
}
private class TermAndBoost implements Comparable<TermAndBoost> {
protected final Term term;
protected float boost;

View File

@ -0,0 +1,797 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.lucene.queries;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.queries.intervals.IntervalIterator;
import org.apache.lucene.queries.intervals.IntervalQuery;
import org.apache.lucene.queries.intervals.Intervals;
import org.apache.lucene.queries.intervals.IntervalsSource;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.MatchesIterator;
import org.apache.lucene.search.MatchesUtils;
import org.apache.lucene.search.PrefixQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.WildcardQuery;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.PriorityQueue;
import org.apache.lucene.util.automaton.CompiledAutomaton;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
/**
* Replacement for {@link Intervals#wildcard(BytesRef)} and {@link Intervals#prefix(BytesRef)}
* until LUCENE-9050 is merged
*/
public final class XIntervals {
private XIntervals() {}
public static IntervalsSource wildcard(BytesRef wildcard) {
CompiledAutomaton ca = new CompiledAutomaton(WildcardQuery.toAutomaton(new Term("", wildcard)));
return new MultiTermIntervalsSource(ca, 128, wildcard.utf8ToString());
}
public static IntervalsSource prefix(BytesRef prefix) {
CompiledAutomaton ca = new CompiledAutomaton(PrefixQuery.toAutomaton(prefix));
return new MultiTermIntervalsSource(ca, 128, prefix.utf8ToString());
}
static class MultiTermIntervalsSource extends IntervalsSource {
private final CompiledAutomaton automaton;
private final int maxExpansions;
private final String pattern;
MultiTermIntervalsSource(CompiledAutomaton automaton, int maxExpansions, String pattern) {
this.automaton = automaton;
if (maxExpansions > BooleanQuery.getMaxClauseCount()) {
throw new IllegalArgumentException("maxExpansions [" + maxExpansions
+ "] cannot be greater than BooleanQuery.getMaxClauseCount [" + BooleanQuery.getMaxClauseCount() + "]");
}
this.maxExpansions = maxExpansions;
this.pattern = pattern;
}
@Override
public IntervalIterator intervals(String field, LeafReaderContext ctx) throws IOException {
Terms terms = ctx.reader().terms(field);
if (terms == null) {
return null;
}
List<IntervalIterator> subSources = new ArrayList<>();
TermsEnum te = automaton.getTermsEnum(terms);
BytesRef term;
int count = 0;
while ((term = te.next()) != null) {
subSources.add(TermIntervalsSource.intervals(term, te));
if (++count > maxExpansions) {
throw new IllegalStateException("Automaton [" + this.pattern + "] expanded to too many terms (limit "
+ maxExpansions + ")");
}
}
if (subSources.size() == 0) {
return null;
}
return new DisjunctionIntervalIterator(subSources);
}
@Override
public MatchesIterator matches(String field, LeafReaderContext ctx, int doc) throws IOException {
Terms terms = ctx.reader().terms(field);
if (terms == null) {
return null;
}
List<MatchesIterator> subMatches = new ArrayList<>();
TermsEnum te = automaton.getTermsEnum(terms);
BytesRef term;
int count = 0;
while ((term = te.next()) != null) {
MatchesIterator mi = XIntervals.TermIntervalsSource.matches(te, doc);
if (mi != null) {
subMatches.add(mi);
if (count++ > maxExpansions) {
throw new IllegalStateException("Automaton " + term + " expanded to too many terms (limit " + maxExpansions + ")");
}
}
}
return MatchesUtils.disjunction(subMatches);
}
@Override
public void visit(String field, QueryVisitor visitor) {
visitor.visitLeaf(new IntervalQuery(field, this));
}
@Override
public int minExtent() {
return 1;
}
@Override
public Collection<IntervalsSource> pullUpDisjunctions() {
return Collections.singleton(this);
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
MultiTermIntervalsSource that = (MultiTermIntervalsSource) o;
return maxExpansions == that.maxExpansions &&
Objects.equals(automaton, that.automaton) &&
Objects.equals(pattern, that.pattern);
}
@Override
public int hashCode() {
return Objects.hash(automaton, maxExpansions, pattern);
}
@Override
public String toString() {
return "MultiTerm(" + pattern + ")";
}
}
static class DisiWrapper {
public final DocIdSetIterator iterator;
public final IntervalIterator intervals;
public final long cost;
public final float matchCost; // the match cost for two-phase iterators, 0 otherwise
public int doc; // the current doc, used for comparison
public DisiWrapper next; // reference to a next element, see #topList
// An approximation of the iterator, or the iterator itself if it does not
// support two-phase iteration
public final DocIdSetIterator approximation;
DisiWrapper(IntervalIterator iterator) {
this.intervals = iterator;
this.iterator = iterator;
this.cost = iterator.cost();
this.doc = -1;
this.approximation = iterator;
this.matchCost = iterator.matchCost();
}
}
static final class DisiPriorityQueue implements Iterable<DisiWrapper> {
static int leftNode(int node) {
return ((node + 1) << 1) - 1;
}
static int rightNode(int leftNode) {
return leftNode + 1;
}
static int parentNode(int node) {
return ((node + 1) >>> 1) - 1;
}
private final DisiWrapper[] heap;
private int size;
DisiPriorityQueue(int maxSize) {
heap = new DisiWrapper[maxSize];
size = 0;
}
public int size() {
return size;
}
public DisiWrapper top() {
return heap[0];
}
/** Get the list of scorers which are on the current doc. */
DisiWrapper topList() {
final DisiWrapper[] heap = this.heap;
final int size = this.size;
DisiWrapper list = heap[0];
list.next = null;
if (size >= 3) {
list = topList(list, heap, size, 1);
list = topList(list, heap, size, 2);
} else if (size == 2 && heap[1].doc == list.doc) {
list = prepend(heap[1], list);
}
return list;
}
// prepend w1 (iterator) to w2 (list)
private DisiWrapper prepend(DisiWrapper w1, DisiWrapper w2) {
w1.next = w2;
return w1;
}
private DisiWrapper topList(DisiWrapper list, DisiWrapper[] heap,
int size, int i) {
final DisiWrapper w = heap[i];
if (w.doc == list.doc) {
list = prepend(w, list);
final int left = leftNode(i);
final int right = left + 1;
if (right < size) {
list = topList(list, heap, size, left);
list = topList(list, heap, size, right);
} else if (left < size && heap[left].doc == list.doc) {
list = prepend(heap[left], list);
}
}
return list;
}
public DisiWrapper add(DisiWrapper entry) {
final DisiWrapper[] heap = this.heap;
final int size = this.size;
heap[size] = entry;
upHeap(size);
this.size = size + 1;
return heap[0];
}
public DisiWrapper pop() {
final DisiWrapper[] heap = this.heap;
final DisiWrapper result = heap[0];
final int i = --size;
heap[0] = heap[i];
heap[i] = null;
downHeap(i);
return result;
}
DisiWrapper updateTop() {
downHeap(size);
return heap[0];
}
void upHeap(int i) {
final DisiWrapper node = heap[i];
final int nodeDoc = node.doc;
int j = parentNode(i);
while (j >= 0 && nodeDoc < heap[j].doc) {
heap[i] = heap[j];
i = j;
j = parentNode(j);
}
heap[i] = node;
}
void downHeap(int size) {
int i = 0;
final DisiWrapper node = heap[0];
int j = leftNode(i);
if (j < size) {
int k = rightNode(j);
if (k < size && heap[k].doc < heap[j].doc) {
j = k;
}
if (heap[j].doc < node.doc) {
do {
heap[i] = heap[j];
i = j;
j = leftNode(i);
k = rightNode(j);
if (k < size && heap[k].doc < heap[j].doc) {
j = k;
}
} while (j < size && heap[j].doc < node.doc);
heap[i] = node;
}
}
}
@Override
public Iterator<DisiWrapper> iterator() {
return Arrays.asList(heap).subList(0, size).iterator();
}
}
static class DisjunctionDISIApproximation extends DocIdSetIterator {
final DisiPriorityQueue subIterators;
final long cost;
DisjunctionDISIApproximation(DisiPriorityQueue subIterators) {
this.subIterators = subIterators;
long cost = 0;
for (DisiWrapper w : subIterators) {
cost += w.cost;
}
this.cost = cost;
}
@Override
public long cost() {
return cost;
}
@Override
public int docID() {
return subIterators.top().doc;
}
@Override
public int nextDoc() throws IOException {
DisiWrapper top = subIterators.top();
final int doc = top.doc;
do {
top.doc = top.approximation.nextDoc();
top = subIterators.updateTop();
} while (top.doc == doc);
return top.doc;
}
@Override
public int advance(int target) throws IOException {
DisiWrapper top = subIterators.top();
do {
top.doc = top.approximation.advance(target);
top = subIterators.updateTop();
} while (top.doc < target);
return top.doc;
}
}
static class DisjunctionIntervalIterator extends IntervalIterator {
final DocIdSetIterator approximation;
final PriorityQueue<IntervalIterator> intervalQueue;
final DisiPriorityQueue disiQueue;
final List<IntervalIterator> iterators;
final float matchCost;
IntervalIterator current = EMPTY;
DisjunctionIntervalIterator(List<IntervalIterator> iterators) {
this.disiQueue = new DisiPriorityQueue(iterators.size());
for (IntervalIterator it : iterators) {
disiQueue.add(new DisiWrapper(it));
}
this.approximation = new DisjunctionDISIApproximation(disiQueue);
this.iterators = iterators;
this.intervalQueue = new PriorityQueue<IntervalIterator>(iterators.size()) {
@Override
protected boolean lessThan(IntervalIterator a, IntervalIterator b) {
return a.end() < b.end() || (a.end() == b.end() && a.start() >= b.start());
}
};
float costsum = 0;
for (IntervalIterator it : iterators) {
costsum += it.cost();
}
this.matchCost = costsum;
}
@Override
public float matchCost() {
return matchCost;
}
@Override
public int start() {
return current.start();
}
@Override
public int end() {
return current.end();
}
@Override
public int gaps() {
return current.gaps();
}
private void reset() throws IOException {
intervalQueue.clear();
for (DisiWrapper dw = disiQueue.topList(); dw != null; dw = dw.next) {
dw.intervals.nextInterval();
intervalQueue.add(dw.intervals);
}
current = EMPTY;
}
@Override
public int nextInterval() throws IOException {
if (current == EMPTY || current == EXHAUSTED) {
if (intervalQueue.size() > 0) {
current = intervalQueue.top();
}
return current.start();
}
int start = current.start(), end = current.end();
while (intervalQueue.size() > 0 && contains(intervalQueue.top(), start, end)) {
IntervalIterator it = intervalQueue.pop();
if (it != null && it.nextInterval() != NO_MORE_INTERVALS) {
intervalQueue.add(it);
}
}
if (intervalQueue.size() == 0) {
current = EXHAUSTED;
return NO_MORE_INTERVALS;
}
current = intervalQueue.top();
return current.start();
}
private boolean contains(IntervalIterator it, int start, int end) {
return start >= it.start() && start <= it.end() && end >= it.start() && end <= it.end();
}
@Override
public int docID() {
return approximation.docID();
}
@Override
public int nextDoc() throws IOException {
int doc = approximation.nextDoc();
reset();
return doc;
}
@Override
public int advance(int target) throws IOException {
int doc = approximation.advance(target);
reset();
return doc;
}
@Override
public long cost() {
return approximation.cost();
}
}
private static final IntervalIterator EMPTY = new IntervalIterator() {
@Override
public int docID() {
throw new UnsupportedOperationException();
}
@Override
public int nextDoc() {
throw new UnsupportedOperationException();
}
@Override
public int advance(int target) {
throw new UnsupportedOperationException();
}
@Override
public long cost() {
throw new UnsupportedOperationException();
}
@Override
public int start() {
return -1;
}
@Override
public int end() {
return -1;
}
@Override
public int gaps() {
throw new UnsupportedOperationException();
}
@Override
public int nextInterval() {
return NO_MORE_INTERVALS;
}
@Override
public float matchCost() {
return 0;
}
};
private static final IntervalIterator EXHAUSTED = new IntervalIterator() {
@Override
public int docID() {
throw new UnsupportedOperationException();
}
@Override
public int nextDoc() {
throw new UnsupportedOperationException();
}
@Override
public int advance(int target) {
throw new UnsupportedOperationException();
}
@Override
public long cost() {
throw new UnsupportedOperationException();
}
@Override
public int start() {
return NO_MORE_INTERVALS;
}
@Override
public int end() {
return NO_MORE_INTERVALS;
}
@Override
public int gaps() {
throw new UnsupportedOperationException();
}
@Override
public int nextInterval() {
return NO_MORE_INTERVALS;
}
@Override
public float matchCost() {
return 0;
}
};
static class TermIntervalsSource extends IntervalsSource {
final BytesRef term;
TermIntervalsSource(BytesRef term) {
this.term = term;
}
@Override
public IntervalIterator intervals(String field, LeafReaderContext ctx) throws IOException {
Terms terms = ctx.reader().terms(field);
if (terms == null)
return null;
if (terms.hasPositions() == false) {
throw new IllegalArgumentException("Cannot create an IntervalIterator over field " + field
+ " because it has no indexed positions");
}
TermsEnum te = terms.iterator();
if (te.seekExact(term) == false) {
return null;
}
return intervals(term, te);
}
static IntervalIterator intervals(BytesRef term, TermsEnum te) throws IOException {
PostingsEnum pe = te.postings(null, PostingsEnum.POSITIONS);
float cost = termPositionsCost(te);
return new IntervalIterator() {
@Override
public int docID() {
return pe.docID();
}
@Override
public int nextDoc() throws IOException {
int doc = pe.nextDoc();
reset();
return doc;
}
@Override
public int advance(int target) throws IOException {
int doc = pe.advance(target);
reset();
return doc;
}
@Override
public long cost() {
return pe.cost();
}
int pos = -1, upto;
@Override
public int start() {
return pos;
}
@Override
public int end() {
return pos;
}
@Override
public int gaps() {
return 0;
}
@Override
public int nextInterval() throws IOException {
if (upto <= 0)
return pos = NO_MORE_INTERVALS;
upto--;
return pos = pe.nextPosition();
}
@Override
public float matchCost() {
return cost;
}
private void reset() throws IOException {
if (pe.docID() == NO_MORE_DOCS) {
upto = -1;
pos = NO_MORE_INTERVALS;
}
else {
upto = pe.freq();
pos = -1;
}
}
@Override
public String toString() {
return term.utf8ToString() + ":" + super.toString();
}
};
}
@Override
public MatchesIterator matches(String field, LeafReaderContext ctx, int doc) throws IOException {
Terms terms = ctx.reader().terms(field);
if (terms == null)
return null;
if (terms.hasPositions() == false) {
throw new IllegalArgumentException("Cannot create an IntervalIterator over field " + field
+ " because it has no indexed positions");
}
TermsEnum te = terms.iterator();
if (te.seekExact(term) == false) {
return null;
}
return matches(te, doc);
}
static MatchesIterator matches(TermsEnum te, int doc) throws IOException {
PostingsEnum pe = te.postings(null, PostingsEnum.OFFSETS);
if (pe.advance(doc) != doc) {
return null;
}
return new MatchesIterator() {
int upto = pe.freq();
int pos = -1;
@Override
public boolean next() throws IOException {
if (upto <= 0) {
pos = IntervalIterator.NO_MORE_INTERVALS;
return false;
}
upto--;
pos = pe.nextPosition();
return true;
}
@Override
public int startPosition() {
return pos;
}
@Override
public int endPosition() {
return pos;
}
@Override
public int startOffset() throws IOException {
return pe.startOffset();
}
@Override
public int endOffset() throws IOException {
return pe.endOffset();
}
@Override
public MatchesIterator getSubMatches() {
return null;
}
@Override
public Query getQuery() {
throw new UnsupportedOperationException();
}
};
}
@Override
public int minExtent() {
return 1;
}
@Override
public Collection<IntervalsSource> pullUpDisjunctions() {
return Collections.singleton(this);
}
@Override
public int hashCode() {
return Objects.hash(term);
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
TermIntervalsSource that = (TermIntervalsSource) o;
return Objects.equals(term, that.term);
}
@Override
public String toString() {
return term.utf8ToString();
}
@Override
public void visit(String field, QueryVisitor visitor) {
visitor.consumeTerms(new IntervalQuery(field, this), new Term(field, term));
}
private static final int TERM_POSNS_SEEK_OPS_PER_DOC = 128;
private static final int TERM_OPS_PER_POS = 7;
static float termPositionsCost(TermsEnum termsEnum) throws IOException {
int docFreq = termsEnum.docFreq();
assert docFreq > 0;
long totalTermFreq = termsEnum.totalTermFreq();
float expOccurrencesInMatchingDoc = totalTermFreq / (float) docFreq;
return TERM_POSNS_SEEK_OPS_PER_DOC + expOccurrencesInMatchingDoc * TERM_OPS_PER_POS;
}
}
}

View File

@ -33,6 +33,7 @@ import org.apache.lucene.document.Field;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.IndexableField;
import org.apache.lucene.index.Term;
import org.apache.lucene.queries.XIntervals;
import org.apache.lucene.queries.intervals.Intervals;
import org.apache.lucene.queries.intervals.IntervalsSource;
import org.apache.lucene.search.AutomatonQuery;
@ -426,7 +427,7 @@ public class TextFieldMapper extends FieldMapper {
public IntervalsSource intervals(BytesRef term) {
if (term.length > maxChars) {
return Intervals.prefix(term);
return XIntervals.prefix(term);
}
if (term.length >= minChars) {
return Intervals.fixField(name(), Intervals.term(term));
@ -436,7 +437,7 @@ public class TextFieldMapper extends FieldMapper {
sb.append("?");
}
String wildcardTerm = sb.toString();
return Intervals.or(Intervals.fixField(name(), Intervals.wildcard(new BytesRef(wildcardTerm))), Intervals.term(term));
return Intervals.or(Intervals.fixField(name(), XIntervals.wildcard(new BytesRef(wildcardTerm))), Intervals.term(term));
}
@Override
@ -680,7 +681,7 @@ public class TextFieldMapper extends FieldMapper {
if (prefixFieldType != null) {
return prefixFieldType.intervals(normalizedTerm);
}
return Intervals.prefix(normalizedTerm);
return XIntervals.prefix(normalizedTerm);
}
IntervalBuilder builder = new IntervalBuilder(name(), analyzer == null ? searchAnalyzer() : analyzer);
return builder.analyzeText(text, maxGaps, ordered);

View File

@ -20,6 +20,7 @@
package org.elasticsearch.index.query;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.queries.XIntervals;
import org.apache.lucene.queries.intervals.FilteredIntervalsSource;
import org.apache.lucene.queries.intervals.IntervalIterator;
import org.apache.lucene.queries.intervals.Intervals;
@ -585,12 +586,12 @@ public abstract class IntervalsSourceProvider implements NamedWriteable, ToXCont
}
BytesRef normalizedTerm = analyzer.normalize(useField, pattern);
// TODO Intervals.wildcard() should take BytesRef
source = Intervals.fixField(useField, Intervals.wildcard(normalizedTerm));
source = Intervals.fixField(useField, XIntervals.wildcard(normalizedTerm));
}
else {
checkPositions(fieldType);
BytesRef normalizedTerm = analyzer.normalize(fieldType.name(), pattern);
source = Intervals.wildcard(normalizedTerm);
source = XIntervals.wildcard(normalizedTerm);
}
return source;
}

View File

@ -19,11 +19,12 @@
package org.elasticsearch.index.query;
import org.apache.lucene.queries.XIntervals;
import org.apache.lucene.queries.intervals.IntervalQuery;
import org.apache.lucene.queries.intervals.Intervals;
import org.apache.lucene.search.BoostQuery;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.queries.intervals.IntervalQuery;
import org.apache.lucene.queries.intervals.Intervals;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.common.Strings;
@ -395,7 +396,7 @@ public class IntervalQueryBuilderTests extends AbstractQueryTestCase<IntervalQue
String json = "{ \"intervals\" : { \"" + STRING_FIELD_NAME + "\": { " +
"\"prefix\" : { \"prefix\" : \"term\" } } } }";
IntervalQueryBuilder builder = (IntervalQueryBuilder) parseQuery(json);
Query expected = new IntervalQuery(STRING_FIELD_NAME, Intervals.prefix(new BytesRef("term")));
Query expected = new IntervalQuery(STRING_FIELD_NAME, XIntervals.prefix(new BytesRef("term")));
assertEquals(expected, builder.toQuery(createShardContext()));
String no_positions_json = "{ \"intervals\" : { \"" + NO_POSITIONS_FIELD + "\": { " +
@ -422,7 +423,7 @@ public class IntervalQueryBuilderTests extends AbstractQueryTestCase<IntervalQue
"\"prefix\" : { \"prefix\" : \"t\" } } } }";
builder = (IntervalQueryBuilder) parseQuery(short_prefix_json);
expected = new IntervalQuery(PREFIXED_FIELD, Intervals.or(
Intervals.fixField(PREFIXED_FIELD + "._index_prefix", Intervals.wildcard(new BytesRef("t?"))),
Intervals.fixField(PREFIXED_FIELD + "._index_prefix", XIntervals.wildcard(new BytesRef("t?"))),
Intervals.term("t")));
assertEquals(expected, builder.toQuery(createShardContext()));
@ -454,7 +455,7 @@ public class IntervalQueryBuilderTests extends AbstractQueryTestCase<IntervalQue
"\"wildcard\" : { \"pattern\" : \"Te?m\" } } } }";
IntervalQueryBuilder builder = (IntervalQueryBuilder) parseQuery(json);
Query expected = new IntervalQuery(STRING_FIELD_NAME, Intervals.wildcard(new BytesRef("te?m")));
Query expected = new IntervalQuery(STRING_FIELD_NAME, XIntervals.wildcard(new BytesRef("te?m")));
assertEquals(expected, builder.toQuery(createShardContext()));
String no_positions_json = "{ \"intervals\" : { \"" + NO_POSITIONS_FIELD + "\": { " +
@ -468,14 +469,14 @@ public class IntervalQueryBuilderTests extends AbstractQueryTestCase<IntervalQue
"\"wildcard\" : { \"pattern\" : \"Te?m\", \"analyzer\" : \"keyword\" } } } }";
builder = (IntervalQueryBuilder) parseQuery(keyword_json);
expected = new IntervalQuery(STRING_FIELD_NAME, Intervals.wildcard(new BytesRef("Te?m")));
expected = new IntervalQuery(STRING_FIELD_NAME, XIntervals.wildcard(new BytesRef("Te?m")));
assertEquals(expected, builder.toQuery(createShardContext()));
String fixed_field_json = "{ \"intervals\" : { \"" + STRING_FIELD_NAME + "\": { " +
"\"wildcard\" : { \"pattern\" : \"Te?m\", \"use_field\" : \"masked_field\" } } } }";
builder = (IntervalQueryBuilder) parseQuery(fixed_field_json);
expected = new IntervalQuery(STRING_FIELD_NAME, Intervals.fixField(MASKED_FIELD, Intervals.wildcard(new BytesRef("te?m"))));
expected = new IntervalQuery(STRING_FIELD_NAME, Intervals.fixField(MASKED_FIELD, XIntervals.wildcard(new BytesRef("te?m"))));
assertEquals(expected, builder.toQuery(createShardContext()));
String fixed_field_json_no_positions = "{ \"intervals\" : { \"" + STRING_FIELD_NAME + "\": { " +
@ -490,7 +491,7 @@ public class IntervalQueryBuilderTests extends AbstractQueryTestCase<IntervalQue
builder = (IntervalQueryBuilder) parseQuery(fixed_field_analyzer_json);
expected = new IntervalQuery(STRING_FIELD_NAME, Intervals.fixField(MASKED_FIELD,
Intervals.wildcard(new BytesRef("Te?m"))));
XIntervals.wildcard(new BytesRef("Te?m"))));
assertEquals(expected, builder.toQuery(createShardContext()));
}