diff --git a/modules/percolator/src/main/java/org/elasticsearch/percolator/ExtractQueryTermsService.java b/modules/percolator/src/main/java/org/elasticsearch/percolator/ExtractQueryTermsService.java index f9f1aa575b7..c451a245754 100644 --- a/modules/percolator/src/main/java/org/elasticsearch/percolator/ExtractQueryTermsService.java +++ b/modules/percolator/src/main/java/org/elasticsearch/percolator/ExtractQueryTermsService.java @@ -34,6 +34,7 @@ import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.BoostQuery; import org.apache.lucene.search.ConstantScoreQuery; +import org.apache.lucene.search.DisjunctionMaxQuery; import org.apache.lucene.search.PhraseQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.TermQuery; @@ -192,6 +193,13 @@ public final class ExtractQueryTermsService { } else if (query instanceof BlendedTermQuery) { List terms = ((BlendedTermQuery) query).getTerms(); return new HashSet<>(terms); + } else if (query instanceof DisjunctionMaxQuery) { + List disjuncts = ((DisjunctionMaxQuery) query).getDisjuncts(); + Set terms = new HashSet<>(); + for (Query disjunct : disjuncts) { + terms.addAll(extractQueryTerms(disjunct)); + } + return terms; } else if (query instanceof SpanTermQuery) { return Collections.singleton(((SpanTermQuery) query).getTerm()); } else if (query instanceof SpanNearQuery) { diff --git a/modules/percolator/src/test/java/org/elasticsearch/percolator/ExtractQueryTermsServiceTests.java b/modules/percolator/src/test/java/org/elasticsearch/percolator/ExtractQueryTermsServiceTests.java index 444e47d90ce..b9430a32425 100644 --- a/modules/percolator/src/test/java/org/elasticsearch/percolator/ExtractQueryTermsServiceTests.java +++ b/modules/percolator/src/test/java/org/elasticsearch/percolator/ExtractQueryTermsServiceTests.java @@ -32,6 +32,7 @@ import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.BoostQuery; import org.apache.lucene.search.ConstantScoreQuery; +import org.apache.lucene.search.DisjunctionMaxQuery; import org.apache.lucene.search.PhraseQuery; import org.apache.lucene.search.TermQuery; import org.apache.lucene.search.TermRangeQuery; @@ -362,6 +363,28 @@ public class ExtractQueryTermsServiceTests extends ESTestCase { assertThat(e.getUnsupportedQuery(), sameInstance(unsupportedQuery)); } + public void testExtractQueryMetadata_disjunctionMaxQuery() { + TermQuery termQuery1 = new TermQuery(new Term("_field", "_term1")); + TermQuery termQuery2 = new TermQuery(new Term("_field", "_term2")); + TermQuery termQuery3 = new TermQuery(new Term("_field", "_term3")); + TermQuery termQuery4 = new TermQuery(new Term("_field", "_term4")); + DisjunctionMaxQuery disjunctionMaxQuery = new DisjunctionMaxQuery( + Arrays.asList(termQuery1, termQuery2, termQuery3, termQuery4), 0.1f + ); + + List terms = new ArrayList<>(extractQueryTerms(disjunctionMaxQuery)); + Collections.sort(terms); + assertThat(terms.size(), equalTo(4)); + assertThat(terms.get(0).field(), equalTo(termQuery1.getTerm().field())); + assertThat(terms.get(0).bytes(), equalTo(termQuery1.getTerm().bytes())); + assertThat(terms.get(1).field(), equalTo(termQuery2.getTerm().field())); + assertThat(terms.get(1).bytes(), equalTo(termQuery2.getTerm().bytes())); + assertThat(terms.get(2).field(), equalTo(termQuery3.getTerm().field())); + assertThat(terms.get(2).bytes(), equalTo(termQuery3.getTerm().bytes())); + assertThat(terms.get(3).field(), equalTo(termQuery4.getTerm().field())); + assertThat(terms.get(3).bytes(), equalTo(termQuery4.getTerm().bytes())); + } + public void testCreateQueryMetadataQuery() throws Exception { MemoryIndex memoryIndex = new MemoryIndex(false); memoryIndex.addField("field1", "the quick brown fox jumps over the lazy dog", new WhitespaceAnalyzer());