From c6b31162ba152547cdfd9c7dc8074758a69adc78 Mon Sep 17 00:00:00 2001
From: Alan Woodward
Date: Wed, 20 Nov 2019 09:21:01 +0000
Subject: [PATCH] Refactor percolator's QueryAnalyzer to use QueryVisitors

Lucene now allows us to explore the structure of a query using QueryVisitors,
delegating the knowledge of how to recurse through and collect terms to the
query implementations themselves. The percolator currently has a home-grown
external version of this API to construct sets of matching terms that must be
present in a document in order for it to possibly match the query.

This commit removes the home-grown implementation in favour of one using
QueryVisitor. This has the added benefit of making interval queries available
for percolator pre-filtering. Due to a bug in multi-term intervals
(LUCENE-9050) it also includes a clone of some of the Lucene intervals logic,
which can be removed once upstream has been fixed.

Closes #45639
---
 .../percolator/PercolatorFieldMapper.java      |   7 +-
 .../percolator/QueryAnalyzer.java              | 591 ++++---------
 .../PercolatorFieldMapperTests.java            |   4 +-
 .../percolator/QueryAnalyzerTests.java         | 138 ++-
 .../lucene/queries/BlendedTermQuery.java       |  14 +
 .../org/apache/lucene/queries/XIntervals.java  | 797 ++++++++++++++++++
 .../index/mapper/TextFieldMapper.java          |   7 +-
 .../index/query/IntervalsSourceProvider.java   |   5 +-
 .../query/IntervalQueryBuilderTests.java       |  17 +-
 9 files changed, 1134 insertions(+), 446 deletions(-)
 create mode 100644 server/src/main/java/org/apache/lucene/queries/XIntervals.java

diff --git a/modules/percolator/src/main/java/org/elasticsearch/percolator/PercolatorFieldMapper.java b/modules/percolator/src/main/java/org/elasticsearch/percolator/PercolatorFieldMapper.java
index 40a9e995ad2..ad5718d2d13 100644
--- a/modules/percolator/src/main/java/org/elasticsearch/percolator/PercolatorFieldMapper.java
+++ b/modules/percolator/src/main/java/org/elasticsearch/percolator/PercolatorFieldMapper.java
@@ -446,10 +446,9 @@ public class PercolatorFieldMapper extends FieldMapper {
         ParseContext.Document doc = context.doc();
         FieldType pft = (FieldType) this.fieldType();
         QueryAnalyzer.Result result;
-        try {
-            Version indexVersion = context.mapperService().getIndexSettings().getIndexVersionCreated();
-            result = QueryAnalyzer.analyze(query, indexVersion);
-        } catch (QueryAnalyzer.UnsupportedQueryException e) {
+        Version indexVersion = context.mapperService().getIndexSettings().getIndexVersionCreated();
+        result = QueryAnalyzer.analyze(query, indexVersion);
+        if (result == QueryAnalyzer.Result.UNKNOWN) {
             doc.add(new Field(pft.extractionResultField.name(), EXTRACTION_FAILED, extractionResultField.fieldType()));
             return;
         }
diff --git a/modules/percolator/src/main/java/org/elasticsearch/percolator/QueryAnalyzer.java b/modules/percolator/src/main/java/org/elasticsearch/percolator/QueryAnalyzer.java
index ebebfa01b67..7048106d4af 100644
--- a/modules/percolator/src/main/java/org/elasticsearch/percolator/QueryAnalyzer.java
+++ b/modules/percolator/src/main/java/org/elasticsearch/percolator/QueryAnalyzer.java
@@ -19,84 +19,41 @@ package org.elasticsearch.percolator;

 import org.apache.lucene.document.BinaryRange;
-import org.apache.lucene.index.PrefixCodedTerms;
 import org.apache.lucene.index.Term;
 import org.apache.lucene.queries.BlendedTermQuery;
-import org.apache.lucene.queries.CommonTermsQuery;
-import org.apache.lucene.search.BooleanClause;
 import org.apache.lucene.search.BooleanClause.Occur;
 import org.apache.lucene.search.BooleanQuery;
 import
org.apache.lucene.search.BoostQuery; import org.apache.lucene.search.ConstantScoreQuery; import org.apache.lucene.search.DisjunctionMaxQuery; -import org.apache.lucene.search.IndexOrDocValuesQuery; import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.MatchNoDocsQuery; -import org.apache.lucene.search.MultiPhraseQuery; -import org.apache.lucene.search.PhraseQuery; import org.apache.lucene.search.PointRangeQuery; import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.search.SynonymQuery; import org.apache.lucene.search.TermInSetQuery; import org.apache.lucene.search.TermQuery; -import org.apache.lucene.search.spans.SpanFirstQuery; -import org.apache.lucene.search.spans.SpanNearQuery; -import org.apache.lucene.search.spans.SpanNotQuery; import org.apache.lucene.search.spans.SpanOrQuery; -import org.apache.lucene.search.spans.SpanQuery; import org.apache.lucene.search.spans.SpanTermQuery; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.NumericUtils; import org.elasticsearch.Version; -import org.elasticsearch.common.logging.LoggerMessageFormat; import org.elasticsearch.common.lucene.search.function.FunctionScoreQuery; -import org.elasticsearch.index.search.ESToParentBlockJoinQuery; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; -import java.util.HashMap; +import java.util.Comparator; import java.util.HashSet; import java.util.List; -import java.util.Map; import java.util.Objects; +import java.util.Optional; import java.util.Set; -import java.util.function.BiFunction; import java.util.stream.Collectors; -import static java.util.stream.Collectors.toSet; - final class QueryAnalyzer { - private static final Map, BiFunction> queryProcessors; - - static { - Map, BiFunction> map = new HashMap<>(); - map.put(MatchNoDocsQuery.class, matchNoDocsQuery()); - map.put(MatchAllDocsQuery.class, matchAllDocsQuery()); - map.put(ConstantScoreQuery.class, constantScoreQuery()); - map.put(BoostQuery.class, boostQuery()); - map.put(TermQuery.class, termQuery()); - map.put(TermInSetQuery.class, termInSetQuery()); - map.put(CommonTermsQuery.class, commonTermsQuery()); - map.put(BlendedTermQuery.class, blendedTermQuery()); - map.put(PhraseQuery.class, phraseQuery()); - map.put(MultiPhraseQuery.class, multiPhraseQuery()); - map.put(SpanTermQuery.class, spanTermQuery()); - map.put(SpanNearQuery.class, spanNearQuery()); - map.put(SpanOrQuery.class, spanOrQuery()); - map.put(SpanFirstQuery.class, spanFirstQuery()); - map.put(SpanNotQuery.class, spanNotQuery()); - map.put(BooleanQuery.class, booleanQuery()); - map.put(DisjunctionMaxQuery.class, disjunctionMaxQuery()); - map.put(SynonymQuery.class, synonymQuery()); - map.put(FunctionScoreQuery.class, functionScoreQuery()); - map.put(PointRangeQuery.class, pointRangeQuery()); - map.put(IndexOrDocValuesQuery.class, indexOrDocValuesQuery()); - map.put(ESToParentBlockJoinQuery.class, toParentBlockJoinQuery()); - queryProcessors = Collections.unmodifiableMap(map); - } - private QueryAnalyzer() { } @@ -128,299 +85,155 @@ final class QueryAnalyzer { * @param indexVersion The create version of the index containing the percolator queries. */ static Result analyze(Query query, Version indexVersion) { - Class queryClass = query.getClass(); - if (queryClass.isAnonymousClass()) { - // Sometimes queries have anonymous classes in that case we need the direct super class. 
-        // (for example blended term query)
-        queryClass = queryClass.getSuperclass();
+        ResultBuilder builder = new ResultBuilder(indexVersion, false);
+        query.visit(builder);
+        return builder.getResult();
+    }
+
+    private static final Set<Class<?>> verifiedQueries = new HashSet<>(Arrays.asList(
+        TermQuery.class, TermInSetQuery.class, SynonymQuery.class, SpanTermQuery.class, SpanOrQuery.class,
+        BooleanQuery.class, DisjunctionMaxQuery.class, ConstantScoreQuery.class, BoostQuery.class,
+        BlendedTermQuery.class
+    ));
+
+    private static boolean isVerified(Query query) {
+        if (query instanceof FunctionScoreQuery) {
+            return ((FunctionScoreQuery)query).getMinScore() == null;
         }
-        BiFunction<Query, Version, Result> queryProcessor = queryProcessors.get(queryClass);
-        if (queryProcessor != null) {
-            return queryProcessor.apply(query, indexVersion);
-        } else {
-            throw new UnsupportedQueryException(query);
+        for (Class<?> cls : verifiedQueries) {
+            if (cls.isAssignableFrom(query.getClass())) {
+                return true;
+            }
         }
+        return false;
     }

-    private static BiFunction<Query, Version, Result> matchNoDocsQuery() {
-        return (query, version) -> new Result(true, Collections.emptySet(), 0);
-    }
+    private static class ResultBuilder extends QueryVisitor {

-    private static BiFunction<Query, Version, Result> matchAllDocsQuery() {
-        return (query, version) -> new Result(true, true);
-    }
+        final boolean conjunction;
+        final Version version;
+        List<ResultBuilder> children = new ArrayList<>();
+        boolean verified = true;
+        int minimumShouldMatch = 0;
+        List<Result> terms = new ArrayList<>();

-    private static BiFunction<Query, Version, Result> constantScoreQuery() {
-        return (query, boosts) -> {
-            Query wrappedQuery = ((ConstantScoreQuery) query).getQuery();
-            return analyze(wrappedQuery, boosts);
-        };
-    }
+        private ResultBuilder(Version version, boolean conjunction) {
+            this.conjunction = conjunction;
+            this.version = version;
+        }

-    private static BiFunction<Query, Version, Result> boostQuery() {
-        return (query, version) -> {
-            Query wrappedQuery = ((BoostQuery) query).getQuery();
-            return analyze(wrappedQuery, version);
-        };
-    }
+        @Override
+        public String toString() {
+            return (conjunction ? "CONJ" : "DISJ") + children + terms + "~" + minimumShouldMatch;
+        }

-    private static BiFunction<Query, Version, Result> termQuery() {
-        return (query, version) -> {
-            TermQuery termQuery = (TermQuery) query;
-            return new Result(true, Collections.singleton(new QueryExtraction(termQuery.getTerm())), 1);
-        };
-    }
-
-    private static BiFunction<Query, Version, Result> termInSetQuery() {
-        return (query, version) -> {
-            TermInSetQuery termInSetQuery = (TermInSetQuery) query;
-            Set<QueryExtraction> terms = new HashSet<>();
-            PrefixCodedTerms.TermIterator iterator = termInSetQuery.getTermData().iterator();
-            for (BytesRef term = iterator.next(); term != null; term = iterator.next()) {
-                terms.add(new QueryExtraction(new Term(iterator.field(), term)));
+        Result getResult() {
+            List<Result> partialResults = new ArrayList<>();
+            if (terms.size() > 0) {
+                partialResults.add(conjunction ?
handleConjunction(terms, version) : + handleDisjunction(terms, minimumShouldMatch, version)); } - return new Result(true, terms, Math.min(1, terms.size())); - }; - } - - private static BiFunction synonymQuery() { - return (query, version) -> { - Set terms = ((SynonymQuery) query).getTerms().stream().map(QueryExtraction::new).collect(toSet()); - return new Result(true, terms, Math.min(1, terms.size())); - }; - } - - private static BiFunction commonTermsQuery() { - return (query, version) -> { - Set terms = ((CommonTermsQuery) query).getTerms().stream().map(QueryExtraction::new).collect(toSet()); - return new Result(false, terms, Math.min(1, terms.size())); - }; - } - - private static BiFunction blendedTermQuery() { - return (query, version) -> { - Set terms = ((BlendedTermQuery) query).getTerms().stream().map(QueryExtraction::new).collect(toSet()); - return new Result(true, terms, Math.min(1, terms.size())); - }; - } - - private static BiFunction phraseQuery() { - return (query, version) -> { - Term[] terms = ((PhraseQuery) query).getTerms(); - if (terms.length == 0) { - return new Result(true, Collections.emptySet(), 0); + if (children.isEmpty() == false) { + List childResults = children.stream().map(ResultBuilder::getResult).collect(Collectors.toList()); + partialResults.addAll(childResults); } - if (version.onOrAfter(Version.V_6_1_0)) { - Set extractions = Arrays.stream(terms).map(QueryExtraction::new).collect(toSet()); - return new Result(false, extractions, extractions.size()); - } else { - // the longest term is likely to be the rarest, - // so from a performance perspective it makes sense to extract that - Term longestTerm = terms[0]; - for (Term term : terms) { - if (longestTerm.bytes().length < term.bytes().length) { - longestTerm = term; - } - } - return new Result(false, Collections.singleton(new QueryExtraction(longestTerm)), 1); + if (partialResults.isEmpty()) { + return verified ? Result.MATCH_NONE : Result.UNKNOWN; } - }; - } - - private static BiFunction multiPhraseQuery() { - return (query, version) -> { - Term[][] terms = ((MultiPhraseQuery) query).getTermArrays(); - if (terms.length == 0) { - return new Result(true, Collections.emptySet(), 0); - } - - // This query has the same problem as boolean queries when it comes to duplicated terms - // So to keep things simple, we just rewrite to a boolean query - BooleanQuery.Builder builder = new BooleanQuery.Builder(); - for (Term[] termArr : terms) { - BooleanQuery.Builder subBuilder = new BooleanQuery.Builder(); - for (Term term : termArr) { - subBuilder.add(new TermQuery(term), Occur.SHOULD); - } - builder.add(subBuilder.build(), Occur.FILTER); - } - // Make sure to unverify the result - return booleanQuery().apply(builder.build(), version).unverify(); - }; - } - - private static BiFunction spanTermQuery() { - return (query, version) -> { - Term term = ((SpanTermQuery) query).getTerm(); - return new Result(true, Collections.singleton(new QueryExtraction(term)), 1); - }; - } - - private static BiFunction spanNearQuery() { - return (query, version) -> { - SpanNearQuery spanNearQuery = (SpanNearQuery) query; - if (version.onOrAfter(Version.V_6_1_0)) { - // This has the same problem as boolean queries when it comes to duplicated clauses - // so we rewrite to a boolean query to keep things simple. 
- BooleanQuery.Builder builder = new BooleanQuery.Builder(); - for (SpanQuery clause : spanNearQuery.getClauses()) { - builder.add(clause, Occur.FILTER); - } - // make sure to unverify the result - return booleanQuery().apply(builder.build(), version).unverify(); - } else { - Result bestClause = null; - for (SpanQuery clause : spanNearQuery.getClauses()) { - Result temp = analyze(clause, version); - bestClause = selectBestResult(temp, bestClause); - } - return bestClause; - } - }; - } - - private static BiFunction spanOrQuery() { - return (query, version) -> { - SpanOrQuery spanOrQuery = (SpanOrQuery) query; - // handle it like a boolean query to not dulplicate eg. logic - // about duplicated terms - BooleanQuery.Builder builder = new BooleanQuery.Builder(); - for (SpanQuery clause : spanOrQuery.getClauses()) { - builder.add(clause, Occur.SHOULD); - } - return booleanQuery().apply(builder.build(), version); - }; - } - - private static BiFunction spanNotQuery() { - return (query, version) -> { - Result result = analyze(((SpanNotQuery) query).getInclude(), version); - return new Result(false, result.extractions, result.minimumShouldMatch); - }; - } - - private static BiFunction spanFirstQuery() { - return (query, version) -> { - Result result = analyze(((SpanFirstQuery) query).getMatch(), version); - return new Result(false, result.extractions, result.minimumShouldMatch); - }; - } - - private static BiFunction booleanQuery() { - return (query, version) -> { - BooleanQuery bq = (BooleanQuery) query; - int minimumShouldMatch = bq.getMinimumNumberShouldMatch(); - List requiredClauses = new ArrayList<>(); - List optionalClauses = new ArrayList<>(); - boolean hasProhibitedClauses = false; - for (BooleanClause clause : bq.clauses()) { - if (clause.isRequired()) { - requiredClauses.add(clause.getQuery()); - } else if (clause.isProhibited()) { - hasProhibitedClauses = true; - } else { - assert clause.getOccur() == Occur.SHOULD; - optionalClauses.add(clause.getQuery()); - } - } - - if (minimumShouldMatch > optionalClauses.size() - || (requiredClauses.isEmpty() && optionalClauses.isEmpty())) { - return new Result(false, Collections.emptySet(), 0); - } - - if (requiredClauses.size() > 0) { - if (minimumShouldMatch > 0) { - // mix of required clauses and required optional clauses, we turn it into - // a pure conjunction by moving the optional clauses to a sub query to - // simplify logic - BooleanQuery.Builder minShouldMatchQuery = new BooleanQuery.Builder(); - minShouldMatchQuery.setMinimumNumberShouldMatch(minimumShouldMatch); - for (Query q : optionalClauses) { - minShouldMatchQuery.add(q, Occur.SHOULD); - } - requiredClauses.add(minShouldMatchQuery.build()); - optionalClauses.clear(); - minimumShouldMatch = 0; - } else { - optionalClauses.clear(); // only matter for scoring, not matching - } - } - - // Now we now have either a pure conjunction or a pure disjunction, with at least one clause Result result; - if (requiredClauses.size() > 0) { - assert optionalClauses.isEmpty(); - assert minimumShouldMatch == 0; - result = handleConjunctionQuery(requiredClauses, version); + if (partialResults.size() == 1) { + result = partialResults.get(0); } else { - assert requiredClauses.isEmpty(); - if (minimumShouldMatch == 0) { - // Lucene always requires one matching clause for disjunctions - minimumShouldMatch = 1; - } - result = handleDisjunctionQuery(optionalClauses, minimumShouldMatch, version); + result = conjunction ? 
handleConjunction(partialResults, version) + : handleDisjunction(partialResults, minimumShouldMatch, version); } - - if (hasProhibitedClauses) { + if (verified == false) { result = result.unverify(); } - return result; - }; + } + + @Override + public QueryVisitor getSubVisitor(Occur occur, Query parent) { + this.verified = isVerified(parent); + if (occur == Occur.MUST || occur == Occur.FILTER) { + ResultBuilder builder = new ResultBuilder(version, true); + children.add(builder); + return builder; + } + if (occur == Occur.MUST_NOT) { + this.verified = false; + return QueryVisitor.EMPTY_VISITOR; + } + int minimumShouldMatch = 0; + if (parent instanceof BooleanQuery) { + BooleanQuery bq = (BooleanQuery) parent; + if (bq.getMinimumNumberShouldMatch() == 0 + && bq.clauses().stream().anyMatch(c -> c.getOccur() == Occur.MUST || c.getOccur() == Occur.FILTER)) { + return QueryVisitor.EMPTY_VISITOR; + } + minimumShouldMatch = bq.getMinimumNumberShouldMatch(); + } + ResultBuilder child = new ResultBuilder(version, false); + child.minimumShouldMatch = minimumShouldMatch; + children.add(child); + return child; + } + + @Override + public void visitLeaf(Query query) { + if (query instanceof MatchAllDocsQuery) { + terms.add(new Result(true, true)); + } + else if (query instanceof MatchNoDocsQuery) { + terms.add(Result.MATCH_NONE); + } + else if (query instanceof PointRangeQuery) { + terms.add(pointRangeQuery((PointRangeQuery)query)); + } + else { + terms.add(Result.UNKNOWN); + } + } + + @Override + public void consumeTerms(Query query, Term... terms) { + boolean verified = isVerified(query); + Set qe = Arrays.stream(terms).map(QueryExtraction::new).collect(Collectors.toSet()); + if (qe.size() > 0) { + if (version.before(Version.V_6_1_0) && conjunction) { + Optional longest = qe.stream() + .filter(q -> q.term != null) + .max(Comparator.comparingInt(q -> q.term.bytes().length)); + if (longest.isPresent()) { + qe = Collections.singleton(longest.get()); + } + } + this.terms.add(new Result(verified, qe, conjunction ? qe.size() : 1)); + } + } + } - private static BiFunction disjunctionMaxQuery() { - return (query, version) -> { - List disjuncts = ((DisjunctionMaxQuery) query).getDisjuncts(); - if (disjuncts.isEmpty()) { - return new Result(false, Collections.emptySet(), 0); - } else { - return handleDisjunctionQuery(disjuncts, 1, version); - } - }; - } + private static Result pointRangeQuery(PointRangeQuery query) { + if (query.getNumDims() != 1) { + return Result.UNKNOWN; + } - private static BiFunction functionScoreQuery() { - return (query, version) -> { - FunctionScoreQuery functionScoreQuery = (FunctionScoreQuery) query; - Result result = analyze(functionScoreQuery.getSubQuery(), version); + byte[] lowerPoint = query.getLowerPoint(); + byte[] upperPoint = query.getUpperPoint(); - // If min_score is specified we can't guarantee upfront that this percolator query matches, - // so in that case we set verified to false. - // (if it matches with the percolator document matches with the extracted terms. - // Min score filters out docs, which is different than the functions, which just influences the score.) - boolean verified = result.verified && functionScoreQuery.getMinScore() == null; - if (result.matchAllDocs) { - return new Result(result.matchAllDocs, verified); - } else { - return new Result(verified, result.extractions, result.minimumShouldMatch); - } - }; - } + // Need to check whether upper is not smaller than lower, otherwise NumericUtils.subtract(...) 
fails IAE
+        // If upper is really smaller than lower then we deal with it like MatchNoDocsQuery. (verified and no extractions)
+        if (new BytesRef(lowerPoint).compareTo(new BytesRef(upperPoint)) > 0) {
+            return new Result(true, Collections.emptySet(), 0);
+        }

-    private static BiFunction<Query, Version, Result> pointRangeQuery() {
-        return (query, version) -> {
-            PointRangeQuery pointRangeQuery = (PointRangeQuery) query;
-            if (pointRangeQuery.getNumDims() != 1) {
-                throw new UnsupportedQueryException(query);
-            }
-
-            byte[] lowerPoint = pointRangeQuery.getLowerPoint();
-            byte[] upperPoint = pointRangeQuery.getUpperPoint();
-
-            // Need to check whether upper is not smaller than lower, otherwise NumericUtils.subtract(...) fails IAE
-            // If upper is really smaller than lower then we deal with like MatchNoDocsQuery. (verified and no extractions)
-            if (new BytesRef(lowerPoint).compareTo(new BytesRef(upperPoint)) > 0) {
-                return new Result(true, Collections.emptySet(), 0);
-            }
-
-            byte[] interval = new byte[16];
-            NumericUtils.subtract(16, 0, prepad(upperPoint), prepad(lowerPoint), interval);
-            return new Result(false, Collections.singleton(new QueryExtraction(
-                new Range(pointRangeQuery.getField(), lowerPoint, upperPoint, interval))), 1);
-        };
+        byte[] interval = new byte[16];
+        NumericUtils.subtract(16, 0, prepad(upperPoint), prepad(lowerPoint), interval);
+        return new Result(false, Collections.singleton(new QueryExtraction(
+            new Range(query.getField(), lowerPoint, upperPoint, interval))), 1);
     }

     private static byte[] prepad(byte[] original) {
@@ -430,58 +243,16 @@ final class QueryAnalyzer {
         return result;
     }

-    private static BiFunction<Query, Version, Result> indexOrDocValuesQuery() {
-        return (query, version) -> {
-            IndexOrDocValuesQuery indexOrDocValuesQuery = (IndexOrDocValuesQuery) query;
-            return analyze(indexOrDocValuesQuery.getIndexQuery(), version);
-        };
-    }
-
-    private static BiFunction<Query, Version, Result> toParentBlockJoinQuery() {
-        return (query, version) -> {
-            ESToParentBlockJoinQuery toParentBlockJoinQuery = (ESToParentBlockJoinQuery) query;
-            Result result = analyze(toParentBlockJoinQuery.getChildQuery(), version);
-            return new Result(false, result.extractions, result.minimumShouldMatch);
-        };
-    }
-
-    private static Result handleConjunctionQuery(List<Query> conjunctions, Version version) {
-        UnsupportedQueryException uqe = null;
-        List<Result> results = new ArrayList<>(conjunctions.size());
-        boolean success = false;
-        for (Query query : conjunctions) {
-            try {
-                Result subResult = analyze(query, version);
-                if (subResult.isMatchNoDocs()) {
-                    return subResult;
-                }
-                results.add(subResult);
-                success = true;
-            } catch (UnsupportedQueryException e) {
-                uqe = e;
-            }
-        }
-
-        if (success == false) {
-            // No clauses could be extracted
-            if (uqe != null) {
-
-                throw uqe;
-            } else {
-                // Empty conjunction
-                return new Result(true, Collections.emptySet(), 0);
-            }
-        }
-        Result result = handleConjunction(results, version);
-        if (uqe != null) {
-            result = result.unverify();
-        }
-        return result;
-    }
-
-    private static Result handleConjunction(List<Result> conjunctions, Version version) {
+    private static Result handleConjunction(List<Result> conjunctionsWithUnknowns, Version version) {
+        List<Result> conjunctions = conjunctionsWithUnknowns.stream().filter(r -> r.isUnknown() == false).collect(Collectors.toList());
         if (conjunctions.isEmpty()) {
-            throw new IllegalArgumentException("Must have at least on conjunction sub result");
+            if (conjunctionsWithUnknowns.isEmpty()) {
+                throw new IllegalArgumentException("Must have at least one conjunction sub result");
+            }
+            return conjunctionsWithUnknowns.get(0);
// all conjunctions are unknown, so just return the first one + } + if (conjunctionsWithUnknowns.size() == 1) { + return conjunctionsWithUnknowns.get(0); } if (version.onOrAfter(Version.V_6_1_0)) { for (Result subResult : conjunctions) { @@ -490,7 +261,7 @@ final class QueryAnalyzer { } } int msm = 0; - boolean verified = true; + boolean verified = conjunctionsWithUnknowns.size() == conjunctions.size(); boolean matchAllDocs = true; boolean hasDuplicateTerms = false; Set extractions = new HashSet<>(); @@ -523,21 +294,19 @@ final class QueryAnalyzer { resultMsm = 1; } else { resultMsm = 0; + verified = false; + break; } } - if (extractions.contains(queryExtraction)) { - - resultMsm = 0; + resultMsm = Math.max(0, resultMsm - 1); verified = false; - break; } } msm += resultMsm; - if (result.verified == false - // If some inner extractions are optional, the result can't be verified - || result.minimumShouldMatch < result.extractions.size()) { + // If some inner extractions are optional, the result can't be verified + || result.minimumShouldMatch < result.extractions.size()) { verified = false; } matchAllDocs &= result.matchAllDocs; @@ -548,6 +317,7 @@ final class QueryAnalyzer { } else { return new Result(verified, extractions, hasDuplicateTerms ? 1 : msm); } + } else { Result bestClause = null; for (Result result : conjunctions) { @@ -557,17 +327,13 @@ final class QueryAnalyzer { } } - private static Result handleDisjunctionQuery(List disjunctions, int requiredShouldClauses, Version version) { - List subResults = new ArrayList<>(); - for (Query query : disjunctions) { - // if either query fails extraction, we need to propagate as we could miss hits otherwise - Result subResult = analyze(query, version); - subResults.add(subResult); - } - return handleDisjunction(subResults, requiredShouldClauses, version); - } - private static Result handleDisjunction(List disjunctions, int requiredShouldClauses, Version version) { + if (disjunctions.stream().anyMatch(Result::isUnknown)) { + return Result.UNKNOWN; + } + if (disjunctions.size() == 1) { + return disjunctions.get(0); + } // Keep track of the msm for each clause: List clauses = new ArrayList<>(disjunctions.size()); boolean verified; @@ -764,7 +530,7 @@ final class QueryAnalyzer { final int minimumShouldMatch; final boolean matchAllDocs; - private Result(boolean matchAllDocs, boolean verified, Set extractions, int minimumShouldMatch) { + Result(boolean matchAllDocs, boolean verified, Set extractions, int minimumShouldMatch) { if (minimumShouldMatch > extractions.size()) { throw new IllegalArgumentException("minimumShouldMatch can't be greater than the number of extractions: " + minimumShouldMatch + " > " + extractions.size()); @@ -783,6 +549,11 @@ final class QueryAnalyzer { this(matchAllDocs, verified, Collections.emptySet(), 0); } + @Override + public String toString() { + return extractions.toString(); + } + Result unverify() { if (verified) { return new Result(matchAllDocs, false, extractions, minimumShouldMatch); @@ -791,9 +562,37 @@ final class QueryAnalyzer { } } + boolean isUnknown() { + return false; + } + boolean isMatchNoDocs() { return matchAllDocs == false && extractions.isEmpty(); } + + static final Result UNKNOWN = new Result(false, false, Collections.emptySet(), 0){ + @Override + boolean isUnknown() { + return true; + } + + @Override + boolean isMatchNoDocs() { + return false; + } + + @Override + public String toString() { + return "UNKNOWN"; + } + }; + + static final Result MATCH_NONE = new Result(false, true, 
Collections.emptySet(), 0) { + @Override + boolean isMatchNoDocs() { + return true; + } + }; } static class QueryExtraction { @@ -846,26 +645,6 @@ final class QueryAnalyzer { } } - /** - * Exception indicating that none or some query terms couldn't extracted from a percolator query. - */ - static class UnsupportedQueryException extends RuntimeException { - - private final Query unsupportedQuery; - - UnsupportedQueryException(Query unsupportedQuery) { - super(LoggerMessageFormat.format("no query terms can be extracted from query [{}]", unsupportedQuery)); - this.unsupportedQuery = unsupportedQuery; - } - - /** - * The actual Lucene query that was unsupported and caused this exception to be thrown. - */ - Query getUnsupportedQuery() { - return unsupportedQuery; - } - } - static class Range { final String fieldName; diff --git a/modules/percolator/src/test/java/org/elasticsearch/percolator/PercolatorFieldMapperTests.java b/modules/percolator/src/test/java/org/elasticsearch/percolator/PercolatorFieldMapperTests.java index 8db8a549c1e..e68db3872ae 100644 --- a/modules/percolator/src/test/java/org/elasticsearch/percolator/PercolatorFieldMapperTests.java +++ b/modules/percolator/src/test/java/org/elasticsearch/percolator/PercolatorFieldMapperTests.java @@ -892,7 +892,7 @@ public class PercolatorFieldMapperTests extends ESSingleNodeTestCase { assertThat(values.get(1), equalTo("field\0value2")); assertThat(values.get(2), equalTo("field\0value3")); int msm = doc.rootDoc().getFields(fieldType.minimumShouldMatchField.name())[0].numericValue().intValue(); - assertThat(msm, equalTo(2)); + assertThat(msm, equalTo(3)); qb = boolQuery() .must(boolQuery().must(termQuery("field", "value1")).must(termQuery("field", "value2"))) @@ -916,7 +916,7 @@ public class PercolatorFieldMapperTests extends ESSingleNodeTestCase { assertThat(values.get(3), equalTo("field\0value4")); assertThat(values.get(4), equalTo("field\0value5")); msm = doc.rootDoc().getFields(fieldType.minimumShouldMatchField.name())[0].numericValue().intValue(); - assertThat(msm, equalTo(2)); + assertThat(msm, equalTo(4)); qb = boolQuery() .minimumShouldMatch(3) diff --git a/modules/percolator/src/test/java/org/elasticsearch/percolator/QueryAnalyzerTests.java b/modules/percolator/src/test/java/org/elasticsearch/percolator/QueryAnalyzerTests.java index 712d5688827..b625d91cc64 100644 --- a/modules/percolator/src/test/java/org/elasticsearch/percolator/QueryAnalyzerTests.java +++ b/modules/percolator/src/test/java/org/elasticsearch/percolator/QueryAnalyzerTests.java @@ -29,6 +29,10 @@ import org.apache.lucene.document.SortedNumericDocValuesField; import org.apache.lucene.index.Term; import org.apache.lucene.queries.BlendedTermQuery; import org.apache.lucene.queries.CommonTermsQuery; +import org.apache.lucene.queries.XIntervals; +import org.apache.lucene.queries.intervals.IntervalQuery; +import org.apache.lucene.queries.intervals.Intervals; +import org.apache.lucene.queries.intervals.IntervalsSource; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.BoostQuery; @@ -73,7 +77,6 @@ import java.util.Set; import java.util.function.Consumer; import java.util.stream.Collectors; -import static org.elasticsearch.percolator.QueryAnalyzer.UnsupportedQueryException; import static org.elasticsearch.percolator.QueryAnalyzer.analyze; import static org.elasticsearch.percolator.QueryAnalyzer.selectBestResult; import static org.hamcrest.Matchers.equalTo; @@ -450,7 +453,7 @@ public class 
QueryAnalyzerTests extends ESTestCase { builder.add(termQuery2, BooleanClause.Occur.SHOULD); builder.add(termQuery3, BooleanClause.Occur.SHOULD); result = analyze(builder.build(), Version.CURRENT); - assertThat("Minimum match has not impact on whether the result is verified", result.verified, is(true)); + assertThat("Minimum match has no impact on whether the result is verified", result.verified, is(true)); assertThat("msm is at least two so result.minimumShouldMatch should 2 too", result.minimumShouldMatch, equalTo(msm)); builder = new BooleanQuery.Builder(); @@ -827,18 +830,14 @@ public class QueryAnalyzerTests extends ESTestCase { public void testExtractQueryMetadata_unsupportedQuery() { TermRangeQuery termRangeQuery = new TermRangeQuery("_field", null, null, true, false); - UnsupportedQueryException e = expectThrows(UnsupportedQueryException.class, - () -> analyze(termRangeQuery, Version.CURRENT)); - assertThat(e.getUnsupportedQuery(), sameInstance(termRangeQuery)); + assertEquals(Result.UNKNOWN, analyze(termRangeQuery, Version.CURRENT)); TermQuery termQuery1 = new TermQuery(new Term("_field", "_term")); BooleanQuery.Builder builder = new BooleanQuery.Builder(); builder.add(termQuery1, BooleanClause.Occur.SHOULD); builder.add(termRangeQuery, BooleanClause.Occur.SHOULD); BooleanQuery bq = builder.build(); - - e = expectThrows(UnsupportedQueryException.class, () -> analyze(bq, Version.CURRENT)); - assertThat(e.getUnsupportedQuery(), sameInstance(termRangeQuery)); + assertEquals(Result.UNKNOWN, analyze(bq, Version.CURRENT)); } public void testExtractQueryMetadata_unsupportedQueryInBoolQueryWithMustClauses() { @@ -870,8 +869,7 @@ public class QueryAnalyzerTests extends ESTestCase { builder.add(unsupportedQuery, BooleanClause.Occur.MUST); builder.add(unsupportedQuery, BooleanClause.Occur.MUST); BooleanQuery bq2 = builder.build(); - UnsupportedQueryException e = expectThrows(UnsupportedQueryException.class, () -> analyze(bq2, Version.CURRENT)); - assertThat(e.getUnsupportedQuery(), sameInstance(unsupportedQuery)); + assertEquals(Result.UNKNOWN, analyze(bq2, Version.CURRENT)); } public void testExtractQueryMetadata_disjunctionMaxQuery() { @@ -1173,10 +1171,10 @@ public class QueryAnalyzerTests extends ESTestCase { public void testTooManyPointDimensions() { // For now no extraction support for geo queries: Query query1 = LatLonPoint.newBoxQuery("_field", 0, 1, 0, 1); - expectThrows(UnsupportedQueryException.class, () -> analyze(query1, Version.CURRENT)); + assertEquals(Result.UNKNOWN, analyze(query1, Version.CURRENT)); Query query2 = LongPoint.newRangeQuery("_field", new long[]{0, 0, 0}, new long[]{1, 1, 1}); - expectThrows(UnsupportedQueryException.class, () -> analyze(query2, Version.CURRENT)); + assertEquals(Result.UNKNOWN, analyze(query2, Version.CURRENT)); } public void testPointRangeQuery_lowerUpperReversed() { @@ -1338,7 +1336,7 @@ public class QueryAnalyzerTests extends ESTestCase { Result result = analyze(builder.build(), Version.CURRENT); assertThat(result.verified, is(false)); assertThat(result.matchAllDocs, is(false)); - assertThat(result.minimumShouldMatch, equalTo(2)); + assertThat(result.minimumShouldMatch, equalTo(4)); assertTermsEqual(result.extractions, new Term("field", "value1"), new Term("field", "value2"), new Term("field", "value3"), new Term("field", "value4")); @@ -1375,16 +1373,10 @@ public class QueryAnalyzerTests extends ESTestCase { public void testEmptyQueries() { BooleanQuery.Builder builder = new BooleanQuery.Builder(); Result result = analyze(builder.build(), 
Version.CURRENT); - assertThat(result.verified, is(false)); - assertThat(result.matchAllDocs, is(false)); - assertThat(result.minimumShouldMatch, equalTo(0)); - assertThat(result.extractions.size(), equalTo(0)); + assertEquals(result, Result.MATCH_NONE); result = analyze(new DisjunctionMaxQuery(Collections.emptyList(), 0f), Version.CURRENT); - assertThat(result.verified, is(false)); - assertThat(result.matchAllDocs, is(false)); - assertThat(result.minimumShouldMatch, equalTo(0)); - assertThat(result.extractions.size(), equalTo(0)); + assertEquals(result, Result.MATCH_NONE); } private static void assertDimension(byte[] expected, Consumer consumer) { @@ -1410,4 +1402,108 @@ public class QueryAnalyzerTests extends ESTestCase { return queryExtractions; } + public void testIntervalQueries() { + IntervalsSource source = Intervals.or(Intervals.term("term1"), Intervals.term("term2")); + Result result = analyze(new IntervalQuery("field", source), Version.CURRENT); + assertThat(result.verified, is(false)); + assertThat(result.matchAllDocs, is(false)); + assertThat(result.minimumShouldMatch, equalTo(1)); + assertTermsEqual(result.extractions, new Term("field", "term1"), new Term("field", "term2")); + + source = Intervals.ordered(Intervals.term("term1"), Intervals.term("term2"), + Intervals.or(Intervals.term("term3"), Intervals.term("term4"))); + result = analyze(new IntervalQuery("field", source), Version.CURRENT); + assertThat(result.verified, is(false)); + assertThat(result.matchAllDocs, is(false)); + assertThat(result.minimumShouldMatch, equalTo(3)); + assertTermsEqual(result.extractions, new Term("field", "term1"), new Term("field", "term2"), + new Term("field", "term3"), new Term("field", "term4")); + + source = Intervals.ordered(Intervals.term("term1"), XIntervals.wildcard(new BytesRef("a*"))); + result = analyze(new IntervalQuery("field", source), Version.CURRENT); + assertThat(result.verified, is(false)); + assertThat(result.matchAllDocs, is(false)); + assertThat(result.minimumShouldMatch, equalTo(1)); + assertTermsEqual(result.extractions, new Term("field", "term1")); + + source = Intervals.ordered(XIntervals.wildcard(new BytesRef("a*"))); + result = analyze(new IntervalQuery("field", source), Version.CURRENT); + assertEquals(Result.UNKNOWN, result); + + source = Intervals.or(Intervals.term("b"), XIntervals.wildcard(new BytesRef("a*"))); + result = analyze(new IntervalQuery("field", source), Version.CURRENT); + assertEquals(Result.UNKNOWN, result); + + source = Intervals.ordered(Intervals.term("term1"), XIntervals.prefix(new BytesRef("a"))); + result = analyze(new IntervalQuery("field", source), Version.CURRENT); + assertThat(result.verified, is(false)); + assertThat(result.matchAllDocs, is(false)); + assertThat(result.minimumShouldMatch, equalTo(1)); + assertTermsEqual(result.extractions, new Term("field", "term1")); + + source = Intervals.ordered(XIntervals.prefix(new BytesRef("a"))); + result = analyze(new IntervalQuery("field", source), Version.CURRENT); + assertEquals(Result.UNKNOWN, result); + + source = Intervals.or(Intervals.term("b"), XIntervals.prefix(new BytesRef("a"))); + result = analyze(new IntervalQuery("field", source), Version.CURRENT); + assertEquals(Result.UNKNOWN, result); + + source = Intervals.containedBy(Intervals.term("a"), Intervals.ordered(Intervals.term("b"), Intervals.term("c"))); + result = analyze(new IntervalQuery("field", source), Version.CURRENT); + assertThat(result.verified, is(false)); + assertThat(result.matchAllDocs, is(false)); + 
assertThat(result.minimumShouldMatch, equalTo(3)); + assertTermsEqual(result.extractions, new Term("field", "a"), new Term("field", "b"), new Term("field", "c")); + + source = Intervals.containing(Intervals.term("a"), Intervals.ordered(Intervals.term("b"), Intervals.term("c"))); + result = analyze(new IntervalQuery("field", source), Version.CURRENT); + assertThat(result.verified, is(false)); + assertThat(result.matchAllDocs, is(false)); + assertThat(result.minimumShouldMatch, equalTo(3)); + assertTermsEqual(result.extractions, new Term("field", "a"), new Term("field", "b"), new Term("field", "c")); + + source = Intervals.overlapping(Intervals.term("a"), Intervals.ordered(Intervals.term("b"), Intervals.term("c"))); + result = analyze(new IntervalQuery("field", source), Version.CURRENT); + assertThat(result.verified, is(false)); + assertThat(result.matchAllDocs, is(false)); + assertThat(result.minimumShouldMatch, equalTo(3)); + assertTermsEqual(result.extractions, new Term("field", "a"), new Term("field", "b"), new Term("field", "c")); + + source = Intervals.within(Intervals.term("a"), 2, Intervals.ordered(Intervals.term("b"), Intervals.term("c"))); + result = analyze(new IntervalQuery("field", source), Version.CURRENT); + assertThat(result.verified, is(false)); + assertThat(result.matchAllDocs, is(false)); + assertThat(result.minimumShouldMatch, equalTo(3)); + assertTermsEqual(result.extractions, new Term("field", "a"), new Term("field", "b"), new Term("field", "c")); + + source = Intervals.notContainedBy(Intervals.term("a"), Intervals.ordered(Intervals.term("b"), Intervals.term("c"))); + result = analyze(new IntervalQuery("field", source), Version.CURRENT); + assertThat(result.verified, is(false)); + assertThat(result.matchAllDocs, is(false)); + assertThat(result.minimumShouldMatch, equalTo(1)); + assertTermsEqual(result.extractions, new Term("field", "a")); + + source = Intervals.notContaining(Intervals.term("a"), Intervals.ordered(Intervals.term("b"), Intervals.term("c"))); + result = analyze(new IntervalQuery("field", source), Version.CURRENT); + assertThat(result.verified, is(false)); + assertThat(result.matchAllDocs, is(false)); + assertThat(result.minimumShouldMatch, equalTo(1)); + assertTermsEqual(result.extractions, new Term("field", "a")); + + source = Intervals.nonOverlapping(Intervals.term("a"), Intervals.ordered(Intervals.term("b"), Intervals.term("c"))); + result = analyze(new IntervalQuery("field", source), Version.CURRENT); + assertThat(result.verified, is(false)); + assertThat(result.matchAllDocs, is(false)); + assertThat(result.minimumShouldMatch, equalTo(1)); + assertTermsEqual(result.extractions, new Term("field", "a")); + + source = Intervals.notWithin(Intervals.term("a"), 2, Intervals.ordered(Intervals.term("b"), Intervals.term("c"))); + result = analyze(new IntervalQuery("field", source), Version.CURRENT); + assertThat(result.verified, is(false)); + assertThat(result.matchAllDocs, is(false)); + assertThat(result.minimumShouldMatch, equalTo(1)); + assertTermsEqual(result.extractions, new Term("field", "a")); + } + } diff --git a/server/src/main/java/org/apache/lucene/queries/BlendedTermQuery.java b/server/src/main/java/org/apache/lucene/queries/BlendedTermQuery.java index 1c775b01bee..30fd9df9dee 100644 --- a/server/src/main/java/org/apache/lucene/queries/BlendedTermQuery.java +++ b/server/src/main/java/org/apache/lucene/queries/BlendedTermQuery.java @@ -30,6 +30,7 @@ import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.BoostQuery; import 
org.apache.lucene.search.DisjunctionMaxQuery; import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.search.TermQuery; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.InPlaceMergeSorter; @@ -39,6 +40,8 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; /** * BlendedTermQuery can be used to unify term statistics across @@ -245,6 +248,17 @@ public abstract class BlendedTermQuery extends Query { return builder.toString(); } + @Override + public void visit(QueryVisitor visitor) { + Set fields = Arrays.stream(terms).map(Term::field).collect(Collectors.toSet()); + for (String field : fields) { + if (visitor.acceptField(field) == false) { + return; + } + } + visitor.getSubVisitor(BooleanClause.Occur.SHOULD, this).consumeTerms(this, terms); + } + private class TermAndBoost implements Comparable { protected final Term term; protected float boost; diff --git a/server/src/main/java/org/apache/lucene/queries/XIntervals.java b/server/src/main/java/org/apache/lucene/queries/XIntervals.java new file mode 100644 index 00000000000..08ea7d4e0da --- /dev/null +++ b/server/src/main/java/org/apache/lucene/queries/XIntervals.java @@ -0,0 +1,797 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.lucene.queries; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.PostingsEnum; +import org.apache.lucene.index.Term; +import org.apache.lucene.index.Terms; +import org.apache.lucene.index.TermsEnum; +import org.apache.lucene.queries.intervals.IntervalIterator; +import org.apache.lucene.queries.intervals.IntervalQuery; +import org.apache.lucene.queries.intervals.Intervals; +import org.apache.lucene.queries.intervals.IntervalsSource; +import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.MatchesIterator; +import org.apache.lucene.search.MatchesUtils; +import org.apache.lucene.search.PrefixQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.search.WildcardQuery; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.PriorityQueue; +import org.apache.lucene.util.automaton.CompiledAutomaton; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Objects; + +/** + * Replacement for {@link Intervals#wildcard(BytesRef)} and {@link Intervals#prefix(BytesRef)} + * until LUCENE-9050 is merged + */ +public final class XIntervals { + + private XIntervals() {} + + public static IntervalsSource wildcard(BytesRef wildcard) { + CompiledAutomaton ca = new CompiledAutomaton(WildcardQuery.toAutomaton(new Term("", wildcard))); + return new MultiTermIntervalsSource(ca, 128, wildcard.utf8ToString()); + } + + public static IntervalsSource prefix(BytesRef prefix) { + CompiledAutomaton ca = new CompiledAutomaton(PrefixQuery.toAutomaton(prefix)); + return new MultiTermIntervalsSource(ca, 128, prefix.utf8ToString()); + } + + static class MultiTermIntervalsSource extends IntervalsSource { + + private final CompiledAutomaton automaton; + private final int maxExpansions; + private final String pattern; + + MultiTermIntervalsSource(CompiledAutomaton automaton, int maxExpansions, String pattern) { + this.automaton = automaton; + if (maxExpansions > BooleanQuery.getMaxClauseCount()) { + throw new IllegalArgumentException("maxExpansions [" + maxExpansions + + "] cannot be greater than BooleanQuery.getMaxClauseCount [" + BooleanQuery.getMaxClauseCount() + "]"); + } + this.maxExpansions = maxExpansions; + this.pattern = pattern; + } + + @Override + public IntervalIterator intervals(String field, LeafReaderContext ctx) throws IOException { + Terms terms = ctx.reader().terms(field); + if (terms == null) { + return null; + } + List subSources = new ArrayList<>(); + TermsEnum te = automaton.getTermsEnum(terms); + BytesRef term; + int count = 0; + while ((term = te.next()) != null) { + subSources.add(TermIntervalsSource.intervals(term, te)); + if (++count > maxExpansions) { + throw new IllegalStateException("Automaton [" + this.pattern + "] expanded to too many terms (limit " + + maxExpansions + ")"); + } + } + if (subSources.size() == 0) { + return null; + } + return new DisjunctionIntervalIterator(subSources); + } + + @Override + public MatchesIterator matches(String field, LeafReaderContext ctx, int doc) throws IOException { + Terms terms = ctx.reader().terms(field); + if (terms == null) { + return null; + } + List subMatches = new ArrayList<>(); + TermsEnum te = automaton.getTermsEnum(terms); + BytesRef term; + int count = 0; + while ((term = 
te.next()) != null) { + MatchesIterator mi = XIntervals.TermIntervalsSource.matches(te, doc); + if (mi != null) { + subMatches.add(mi); + if (count++ > maxExpansions) { + throw new IllegalStateException("Automaton " + term + " expanded to too many terms (limit " + maxExpansions + ")"); + } + } + } + return MatchesUtils.disjunction(subMatches); + } + + @Override + public void visit(String field, QueryVisitor visitor) { + visitor.visitLeaf(new IntervalQuery(field, this)); + } + + @Override + public int minExtent() { + return 1; + } + + @Override + public Collection pullUpDisjunctions() { + return Collections.singleton(this); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + MultiTermIntervalsSource that = (MultiTermIntervalsSource) o; + return maxExpansions == that.maxExpansions && + Objects.equals(automaton, that.automaton) && + Objects.equals(pattern, that.pattern); + } + + @Override + public int hashCode() { + return Objects.hash(automaton, maxExpansions, pattern); + } + + @Override + public String toString() { + return "MultiTerm(" + pattern + ")"; + } + } + + static class DisiWrapper { + + public final DocIdSetIterator iterator; + public final IntervalIterator intervals; + public final long cost; + public final float matchCost; // the match cost for two-phase iterators, 0 otherwise + public int doc; // the current doc, used for comparison + public DisiWrapper next; // reference to a next element, see #topList + + // An approximation of the iterator, or the iterator itself if it does not + // support two-phase iteration + public final DocIdSetIterator approximation; + + DisiWrapper(IntervalIterator iterator) { + this.intervals = iterator; + this.iterator = iterator; + this.cost = iterator.cost(); + this.doc = -1; + this.approximation = iterator; + this.matchCost = iterator.matchCost(); + } + + } + + static final class DisiPriorityQueue implements Iterable { + + static int leftNode(int node) { + return ((node + 1) << 1) - 1; + } + + static int rightNode(int leftNode) { + return leftNode + 1; + } + + static int parentNode(int node) { + return ((node + 1) >>> 1) - 1; + } + + private final DisiWrapper[] heap; + private int size; + + DisiPriorityQueue(int maxSize) { + heap = new DisiWrapper[maxSize]; + size = 0; + } + + public int size() { + return size; + } + + public DisiWrapper top() { + return heap[0]; + } + + /** Get the list of scorers which are on the current doc. 
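* The wrappers on that doc are chained together through their {@code next} references,
* so a caller can walk all sub-iterators positioned on the top doc without mutating the heap.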
*/ + DisiWrapper topList() { + final DisiWrapper[] heap = this.heap; + final int size = this.size; + DisiWrapper list = heap[0]; + list.next = null; + if (size >= 3) { + list = topList(list, heap, size, 1); + list = topList(list, heap, size, 2); + } else if (size == 2 && heap[1].doc == list.doc) { + list = prepend(heap[1], list); + } + return list; + } + + // prepend w1 (iterator) to w2 (list) + private DisiWrapper prepend(DisiWrapper w1, DisiWrapper w2) { + w1.next = w2; + return w1; + } + + private DisiWrapper topList(DisiWrapper list, DisiWrapper[] heap, + int size, int i) { + final DisiWrapper w = heap[i]; + if (w.doc == list.doc) { + list = prepend(w, list); + final int left = leftNode(i); + final int right = left + 1; + if (right < size) { + list = topList(list, heap, size, left); + list = topList(list, heap, size, right); + } else if (left < size && heap[left].doc == list.doc) { + list = prepend(heap[left], list); + } + } + return list; + } + + public DisiWrapper add(DisiWrapper entry) { + final DisiWrapper[] heap = this.heap; + final int size = this.size; + heap[size] = entry; + upHeap(size); + this.size = size + 1; + return heap[0]; + } + + public DisiWrapper pop() { + final DisiWrapper[] heap = this.heap; + final DisiWrapper result = heap[0]; + final int i = --size; + heap[0] = heap[i]; + heap[i] = null; + downHeap(i); + return result; + } + + DisiWrapper updateTop() { + downHeap(size); + return heap[0]; + } + + void upHeap(int i) { + final DisiWrapper node = heap[i]; + final int nodeDoc = node.doc; + int j = parentNode(i); + while (j >= 0 && nodeDoc < heap[j].doc) { + heap[i] = heap[j]; + i = j; + j = parentNode(j); + } + heap[i] = node; + } + + void downHeap(int size) { + int i = 0; + final DisiWrapper node = heap[0]; + int j = leftNode(i); + if (j < size) { + int k = rightNode(j); + if (k < size && heap[k].doc < heap[j].doc) { + j = k; + } + if (heap[j].doc < node.doc) { + do { + heap[i] = heap[j]; + i = j; + j = leftNode(i); + k = rightNode(j); + if (k < size && heap[k].doc < heap[j].doc) { + j = k; + } + } while (j < size && heap[j].doc < node.doc); + heap[i] = node; + } + } + } + + @Override + public Iterator iterator() { + return Arrays.asList(heap).subList(0, size).iterator(); + } + + } + + static class DisjunctionDISIApproximation extends DocIdSetIterator { + + final DisiPriorityQueue subIterators; + final long cost; + + DisjunctionDISIApproximation(DisiPriorityQueue subIterators) { + this.subIterators = subIterators; + long cost = 0; + for (DisiWrapper w : subIterators) { + cost += w.cost; + } + this.cost = cost; + } + + @Override + public long cost() { + return cost; + } + + @Override + public int docID() { + return subIterators.top().doc; + } + + @Override + public int nextDoc() throws IOException { + DisiWrapper top = subIterators.top(); + final int doc = top.doc; + do { + top.doc = top.approximation.nextDoc(); + top = subIterators.updateTop(); + } while (top.doc == doc); + + return top.doc; + } + + @Override + public int advance(int target) throws IOException { + DisiWrapper top = subIterators.top(); + do { + top.doc = top.approximation.advance(target); + top = subIterators.updateTop(); + } while (top.doc < target); + + return top.doc; + } + } + + static class DisjunctionIntervalIterator extends IntervalIterator { + + final DocIdSetIterator approximation; + final PriorityQueue intervalQueue; + final DisiPriorityQueue disiQueue; + final List iterators; + final float matchCost; + + IntervalIterator current = EMPTY; + + DisjunctionIntervalIterator(List iterators) { + 
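            // Two queues cooperate in this disjunction: disiQueue (built here) orders the wrapped
            // sub-iterators by current doc id for document-level iteration, while intervalQueue
            // (below) orders the iterators positioned on the current doc by earliest interval end.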
this.disiQueue = new DisiPriorityQueue(iterators.size()); + for (IntervalIterator it : iterators) { + disiQueue.add(new DisiWrapper(it)); + } + this.approximation = new DisjunctionDISIApproximation(disiQueue); + this.iterators = iterators; + this.intervalQueue = new PriorityQueue(iterators.size()) { + @Override + protected boolean lessThan(IntervalIterator a, IntervalIterator b) { + return a.end() < b.end() || (a.end() == b.end() && a.start() >= b.start()); + } + }; + float costsum = 0; + for (IntervalIterator it : iterators) { + costsum += it.cost(); + } + this.matchCost = costsum; + } + + @Override + public float matchCost() { + return matchCost; + } + + @Override + public int start() { + return current.start(); + } + + @Override + public int end() { + return current.end(); + } + + @Override + public int gaps() { + return current.gaps(); + } + + private void reset() throws IOException { + intervalQueue.clear(); + for (DisiWrapper dw = disiQueue.topList(); dw != null; dw = dw.next) { + dw.intervals.nextInterval(); + intervalQueue.add(dw.intervals); + } + current = EMPTY; + } + + @Override + public int nextInterval() throws IOException { + if (current == EMPTY || current == EXHAUSTED) { + if (intervalQueue.size() > 0) { + current = intervalQueue.top(); + } + return current.start(); + } + int start = current.start(), end = current.end(); + while (intervalQueue.size() > 0 && contains(intervalQueue.top(), start, end)) { + IntervalIterator it = intervalQueue.pop(); + if (it != null && it.nextInterval() != NO_MORE_INTERVALS) { + intervalQueue.add(it); + } + } + if (intervalQueue.size() == 0) { + current = EXHAUSTED; + return NO_MORE_INTERVALS; + } + current = intervalQueue.top(); + return current.start(); + } + + private boolean contains(IntervalIterator it, int start, int end) { + return start >= it.start() && start <= it.end() && end >= it.start() && end <= it.end(); + } + + @Override + public int docID() { + return approximation.docID(); + } + + @Override + public int nextDoc() throws IOException { + int doc = approximation.nextDoc(); + reset(); + return doc; + } + + @Override + public int advance(int target) throws IOException { + int doc = approximation.advance(target); + reset(); + return doc; + } + + @Override + public long cost() { + return approximation.cost(); + } + } + + private static final IntervalIterator EMPTY = new IntervalIterator() { + + @Override + public int docID() { + throw new UnsupportedOperationException(); + } + + @Override + public int nextDoc() { + throw new UnsupportedOperationException(); + } + + @Override + public int advance(int target) { + throw new UnsupportedOperationException(); + } + + @Override + public long cost() { + throw new UnsupportedOperationException(); + } + + @Override + public int start() { + return -1; + } + + @Override + public int end() { + return -1; + } + + @Override + public int gaps() { + throw new UnsupportedOperationException(); + } + + @Override + public int nextInterval() { + return NO_MORE_INTERVALS; + } + + @Override + public float matchCost() { + return 0; + } + }; + + private static final IntervalIterator EXHAUSTED = new IntervalIterator() { + + @Override + public int docID() { + throw new UnsupportedOperationException(); + } + + @Override + public int nextDoc() { + throw new UnsupportedOperationException(); + } + + @Override + public int advance(int target) { + throw new UnsupportedOperationException(); + } + + @Override + public long cost() { + throw new UnsupportedOperationException(); + } + + @Override + public int start() { + 
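            // Sentinel state: unlike EMPTY, which reports -1 from start()/end() before any
            // interval has been consumed, an exhausted iterator reports NO_MORE_INTERVALS from both.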
return NO_MORE_INTERVALS; + } + + @Override + public int end() { + return NO_MORE_INTERVALS; + } + + @Override + public int gaps() { + throw new UnsupportedOperationException(); + } + + @Override + public int nextInterval() { + return NO_MORE_INTERVALS; + } + + @Override + public float matchCost() { + return 0; + } + }; + + static class TermIntervalsSource extends IntervalsSource { + + final BytesRef term; + + TermIntervalsSource(BytesRef term) { + this.term = term; + } + + @Override + public IntervalIterator intervals(String field, LeafReaderContext ctx) throws IOException { + Terms terms = ctx.reader().terms(field); + if (terms == null) + return null; + if (terms.hasPositions() == false) { + throw new IllegalArgumentException("Cannot create an IntervalIterator over field " + field + + " because it has no indexed positions"); + } + TermsEnum te = terms.iterator(); + if (te.seekExact(term) == false) { + return null; + } + return intervals(term, te); + } + + static IntervalIterator intervals(BytesRef term, TermsEnum te) throws IOException { + PostingsEnum pe = te.postings(null, PostingsEnum.POSITIONS); + float cost = termPositionsCost(te); + return new IntervalIterator() { + + @Override + public int docID() { + return pe.docID(); + } + + @Override + public int nextDoc() throws IOException { + int doc = pe.nextDoc(); + reset(); + return doc; + } + + @Override + public int advance(int target) throws IOException { + int doc = pe.advance(target); + reset(); + return doc; + } + + @Override + public long cost() { + return pe.cost(); + } + + int pos = -1, upto; + + @Override + public int start() { + return pos; + } + + @Override + public int end() { + return pos; + } + + @Override + public int gaps() { + return 0; + } + + @Override + public int nextInterval() throws IOException { + if (upto <= 0) + return pos = NO_MORE_INTERVALS; + upto--; + return pos = pe.nextPosition(); + } + + @Override + public float matchCost() { + return cost; + } + + private void reset() throws IOException { + if (pe.docID() == NO_MORE_DOCS) { + upto = -1; + pos = NO_MORE_INTERVALS; + } + else { + upto = pe.freq(); + pos = -1; + } + } + + @Override + public String toString() { + return term.utf8ToString() + ":" + super.toString(); + } + }; + } + + @Override + public MatchesIterator matches(String field, LeafReaderContext ctx, int doc) throws IOException { + Terms terms = ctx.reader().terms(field); + if (terms == null) + return null; + if (terms.hasPositions() == false) { + throw new IllegalArgumentException("Cannot create an IntervalIterator over field " + field + + " because it has no indexed positions"); + } + TermsEnum te = terms.iterator(); + if (te.seekExact(term) == false) { + return null; + } + return matches(te, doc); + } + + static MatchesIterator matches(TermsEnum te, int doc) throws IOException { + PostingsEnum pe = te.postings(null, PostingsEnum.OFFSETS); + if (pe.advance(doc) != doc) { + return null; + } + return new MatchesIterator() { + + int upto = pe.freq(); + int pos = -1; + + @Override + public boolean next() throws IOException { + if (upto <= 0) { + pos = IntervalIterator.NO_MORE_INTERVALS; + return false; + } + upto--; + pos = pe.nextPosition(); + return true; + } + + @Override + public int startPosition() { + return pos; + } + + @Override + public int endPosition() { + return pos; + } + + @Override + public int startOffset() throws IOException { + return pe.startOffset(); + } + + @Override + public int endOffset() throws IOException { + return pe.endOffset(); + } + + @Override + public MatchesIterator 
+        @Override
+        public MatchesIterator matches(String field, LeafReaderContext ctx, int doc) throws IOException {
+            Terms terms = ctx.reader().terms(field);
+            if (terms == null)
+                return null;
+            if (terms.hasPositions() == false) {
+                throw new IllegalArgumentException("Cannot create an IntervalIterator over field " + field
+                    + " because it has no indexed positions");
+            }
+            TermsEnum te = terms.iterator();
+            if (te.seekExact(term) == false) {
+                return null;
+            }
+            return matches(te, doc);
+        }
+
+        static MatchesIterator matches(TermsEnum te, int doc) throws IOException {
+            PostingsEnum pe = te.postings(null, PostingsEnum.OFFSETS);
+            if (pe.advance(doc) != doc) {
+                return null;
+            }
+            return new MatchesIterator() {
+
+                int upto = pe.freq();
+                int pos = -1;
+
+                @Override
+                public boolean next() throws IOException {
+                    if (upto <= 0) {
+                        pos = IntervalIterator.NO_MORE_INTERVALS;
+                        return false;
+                    }
+                    upto--;
+                    pos = pe.nextPosition();
+                    return true;
+                }
+
+                @Override
+                public int startPosition() {
+                    return pos;
+                }
+
+                @Override
+                public int endPosition() {
+                    return pos;
+                }
+
+                @Override
+                public int startOffset() throws IOException {
+                    return pe.startOffset();
+                }
+
+                @Override
+                public int endOffset() throws IOException {
+                    return pe.endOffset();
+                }
+
+                @Override
+                public MatchesIterator getSubMatches() {
+                    return null;
+                }
+
+                @Override
+                public Query getQuery() {
+                    throw new UnsupportedOperationException();
+                }
+            };
+        }
+
+        @Override
+        public int minExtent() {
+            return 1;
+        }
+
+        @Override
+        public Collection<IntervalsSource> pullUpDisjunctions() {
+            return Collections.singleton(this);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(term);
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            TermIntervalsSource that = (TermIntervalsSource) o;
+            return Objects.equals(term, that.term);
+        }
+
+        @Override
+        public String toString() {
+            return term.utf8ToString();
+        }
+
+        @Override
+        public void visit(String field, QueryVisitor visitor) {
+            visitor.consumeTerms(new IntervalQuery(field, this), new Term(field, term));
+        }
+
+        private static final int TERM_POSNS_SEEK_OPS_PER_DOC = 128;
+
+        private static final int TERM_OPS_PER_POS = 7;
+
+        static float termPositionsCost(TermsEnum termsEnum) throws IOException {
+            int docFreq = termsEnum.docFreq();
+            assert docFreq > 0;
+            long totalTermFreq = termsEnum.totalTermFreq();
+            float expOccurrencesInMatchingDoc = totalTermFreq / (float) docFreq;
+            return TERM_POSNS_SEEK_OPS_PER_DOC + expOccurrencesInMatchingDoc * TERM_OPS_PER_POS;
+        }
+    }
+
+}
diff --git a/server/src/main/java/org/elasticsearch/index/mapper/TextFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/TextFieldMapper.java
index 3ed689d4878..4c73dbe340c 100644
--- a/server/src/main/java/org/elasticsearch/index/mapper/TextFieldMapper.java
+++ b/server/src/main/java/org/elasticsearch/index/mapper/TextFieldMapper.java
@@ -33,6 +33,7 @@ import org.apache.lucene.document.Field;
 import org.apache.lucene.index.IndexOptions;
 import org.apache.lucene.index.IndexableField;
 import org.apache.lucene.index.Term;
+import org.apache.lucene.queries.XIntervals;
 import org.apache.lucene.queries.intervals.Intervals;
 import org.apache.lucene.queries.intervals.IntervalsSource;
 import org.apache.lucene.search.AutomatonQuery;
@@ -426,7 +427,7 @@ public class TextFieldMapper extends FieldMapper {
 
         public IntervalsSource intervals(BytesRef term) {
             if (term.length > maxChars) {
-                return Intervals.prefix(term);
+                return XIntervals.prefix(term);
             }
             if (term.length >= minChars) {
                 return Intervals.fixField(name(), Intervals.term(term));
@@ -436,7 +437,7 @@
                 sb.append("?");
             }
             String wildcardTerm = sb.toString();
-            return Intervals.or(Intervals.fixField(name(), Intervals.wildcard(new BytesRef(wildcardTerm))), Intervals.term(term));
+            return Intervals.or(Intervals.fixField(name(), XIntervals.wildcard(new BytesRef(wildcardTerm))), Intervals.term(term));
         }
 
         @Override
@@ -680,7 +681,7 @@
             if (prefixFieldType != null) {
                 return prefixFieldType.intervals(normalizedTerm);
             }
-            return Intervals.prefix(normalizedTerm);
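+            // XIntervals clones the relevant Lucene intervals logic to work around
+            // the multi-term intervals bug (LUCENE-9050); switch back to
+            // Intervals.prefix() once the upstream fix is released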
+            return XIntervals.prefix(normalizedTerm);
         }
         IntervalBuilder builder = new IntervalBuilder(name(), analyzer == null ? searchAnalyzer() : analyzer);
         return builder.analyzeText(text, maxGaps, ordered);
diff --git a/server/src/main/java/org/elasticsearch/index/query/IntervalsSourceProvider.java b/server/src/main/java/org/elasticsearch/index/query/IntervalsSourceProvider.java
index b81206c7f87..81cc1524549 100644
--- a/server/src/main/java/org/elasticsearch/index/query/IntervalsSourceProvider.java
+++ b/server/src/main/java/org/elasticsearch/index/query/IntervalsSourceProvider.java
@@ -20,6 +20,7 @@ package org.elasticsearch.index.query;
 
 import org.apache.lucene.index.IndexOptions;
+import org.apache.lucene.queries.XIntervals;
 import org.apache.lucene.queries.intervals.FilteredIntervalsSource;
 import org.apache.lucene.queries.intervals.IntervalIterator;
 import org.apache.lucene.queries.intervals.Intervals;
@@ -585,12 +586,12 @@ public abstract class IntervalsSourceProvider implements NamedWriteable, ToXCont
             }
             BytesRef normalizedTerm = analyzer.normalize(useField, pattern);
             // TODO Intervals.wildcard() should take BytesRef
-            source = Intervals.fixField(useField, Intervals.wildcard(normalizedTerm));
+            source = Intervals.fixField(useField, XIntervals.wildcard(normalizedTerm));
         }
         else {
             checkPositions(fieldType);
             BytesRef normalizedTerm = analyzer.normalize(fieldType.name(), pattern);
-            source = Intervals.wildcard(normalizedTerm);
+            source = XIntervals.wildcard(normalizedTerm);
         }
         return source;
     }
diff --git a/server/src/test/java/org/elasticsearch/index/query/IntervalQueryBuilderTests.java b/server/src/test/java/org/elasticsearch/index/query/IntervalQueryBuilderTests.java
index 4f2d9d217f9..da1da5ce54b 100644
--- a/server/src/test/java/org/elasticsearch/index/query/IntervalQueryBuilderTests.java
+++ b/server/src/test/java/org/elasticsearch/index/query/IntervalQueryBuilderTests.java
@@ -19,11 +19,12 @@
 package org.elasticsearch.index.query;
 
+import org.apache.lucene.queries.XIntervals;
+import org.apache.lucene.queries.intervals.IntervalQuery;
+import org.apache.lucene.queries.intervals.Intervals;
 import org.apache.lucene.search.BoostQuery;
 import org.apache.lucene.search.MatchNoDocsQuery;
 import org.apache.lucene.search.Query;
-import org.apache.lucene.queries.intervals.IntervalQuery;
-import org.apache.lucene.queries.intervals.Intervals;
 import org.apache.lucene.util.BytesRef;
 import org.elasticsearch.common.ParsingException;
 import org.elasticsearch.common.Strings;
@@ -395,7 +396,7 @@ public class IntervalQueryBuilderTests extends AbstractQueryTestCase