Query DSL: Fix `minimum should match` in `simple_query_string` for single term and multiple fields

Currently a `simple_query_string` query with one term and multiple fields
gets parsed to a BooleanQuery where the number of clauses is determined
by the number of fields, which lead to wrong calculation of `minimum_should_match`.

This PR adds checks to detect this case and wrap the resulting BooleanQuery into
another BooleanQuery with just one should-clause, so `minimum_should_match`
calculation is corrected.

In order to differentiate between the case where one term is queried across
multiple fields and the case where multiple terms are queried on one field,
we override a simplification step in Lucenes SimpleQueryParser that reduces
a one-clause BooleanQuery to the clause itself.

Closes #13884
This commit is contained in:
Christoph Büscher 2015-10-16 12:44:49 +02:00
parent df4c7c7aee
commit fd3a46a1a5
5 changed files with 137 additions and 48 deletions

View File

@ -70,7 +70,7 @@ public class SimpleQueryParser extends org.apache.lucene.queryparser.simple.Simp
rethrowUnlessLenient(e); rethrowUnlessLenient(e);
} }
} }
return super.simplify(bq.build()); return simplify(bq.build());
} }
/** /**
@ -93,7 +93,7 @@ public class SimpleQueryParser extends org.apache.lucene.queryparser.simple.Simp
rethrowUnlessLenient(e); rethrowUnlessLenient(e);
} }
} }
return super.simplify(bq.build()); return simplify(bq.build());
} }
@Override @Override
@ -111,7 +111,7 @@ public class SimpleQueryParser extends org.apache.lucene.queryparser.simple.Simp
rethrowUnlessLenient(e); rethrowUnlessLenient(e);
} }
} }
return super.simplify(bq.build()); return simplify(bq.build());
} }
/** /**
@ -140,7 +140,19 @@ public class SimpleQueryParser extends org.apache.lucene.queryparser.simple.Simp
return rethrowUnlessLenient(e); return rethrowUnlessLenient(e);
} }
} }
return super.simplify(bq.build()); return simplify(bq.build());
}
/**
* Override of lucenes SimpleQueryParser that doesn't simplify for the 1-clause case.
*/
@Override
protected Query simplify(BooleanQuery bq) {
if (bq.clauses().isEmpty()) {
return null;
} else {
return bq;
}
} }
/** /**

View File

@ -20,6 +20,8 @@
package org.elasticsearch.index.query; package org.elasticsearch.index.query;
import org.apache.lucene.analysis.Analyzer; import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanClause.Occur;
import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.Query; import org.apache.lucene.search.Query;
import org.elasticsearch.common.Strings; import org.elasticsearch.common.Strings;
@ -286,7 +288,16 @@ public class SimpleQueryStringBuilder extends AbstractQueryBuilder<SimpleQuerySt
Query query = sqp.parse(queryText); Query query = sqp.parse(queryText);
if (minimumShouldMatch != null && query instanceof BooleanQuery) { if (minimumShouldMatch != null && query instanceof BooleanQuery) {
query = Queries.applyMinimumShouldMatch((BooleanQuery) query, minimumShouldMatch); BooleanQuery booleanQuery = (BooleanQuery) query;
// treat special case for one term query and more than one field
// we need to wrap this in additional BooleanQuery so minimum_should_match is applied correctly
if (booleanQuery.clauses().size() > 1
&& ((booleanQuery.clauses().iterator().next().getQuery() instanceof BooleanQuery) == false)) {
BooleanQuery.Builder builder = new BooleanQuery.Builder();
builder.add(new BooleanClause(booleanQuery, Occur.SHOULD));
booleanQuery = builder.build();
}
query = Queries.applyMinimumShouldMatch(booleanQuery, minimumShouldMatch);
} }
return query; return query;
} }

View File

@ -456,7 +456,7 @@ public abstract class AbstractQueryTestCase<QB extends AbstractQueryBuilder<QB>>
testQuery.writeTo(output); testQuery.writeTo(output);
try (StreamInput in = new NamedWriteableAwareStreamInput(StreamInput.wrap(output.bytes()), namedWriteableRegistry)) { try (StreamInput in = new NamedWriteableAwareStreamInput(StreamInput.wrap(output.bytes()), namedWriteableRegistry)) {
QueryBuilder<?> prototype = queryParser(testQuery.getName()).getBuilderPrototype(); QueryBuilder<?> prototype = queryParser(testQuery.getName()).getBuilderPrototype();
QueryBuilder deserializedQuery = prototype.readFrom(in); QueryBuilder<?> deserializedQuery = prototype.readFrom(in);
assertEquals(deserializedQuery, testQuery); assertEquals(deserializedQuery, testQuery);
assertEquals(deserializedQuery.hashCode(), testQuery.hashCode()); assertEquals(deserializedQuery.hashCode(), testQuery.hashCode());
assertNotSame(deserializedQuery, testQuery); assertNotSame(deserializedQuery, testQuery);

View File

@ -27,28 +27,38 @@ import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery; import org.apache.lucene.search.TermQuery;
import org.elasticsearch.Version; import org.elasticsearch.Version;
import org.elasticsearch.cluster.metadata.MetaData; import org.elasticsearch.cluster.metadata.MetaData;
import org.elasticsearch.common.lucene.search.Queries;
import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException; import java.io.IOException;
import java.util.HashMap; import java.util.HashMap;
import java.util.HashSet; import java.util.HashSet;
import java.util.Iterator; import java.util.Iterator;
import java.util.Locale; import java.util.Locale;
import java.util.Map; import java.util.Map;
import java.util.Map.Entry;
import java.util.Set; import java.util.Set;
import java.util.TreeMap;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.notNullValue; import static org.hamcrest.Matchers.notNullValue;
public class SimpleQueryStringBuilderTests extends AbstractQueryTestCase<SimpleQueryStringBuilder> { public class SimpleQueryStringBuilderTests extends AbstractQueryTestCase<SimpleQueryStringBuilder> {
private String[] queryTerms;
@Override @Override
protected SimpleQueryStringBuilder doCreateTestQueryBuilder() { protected SimpleQueryStringBuilder doCreateTestQueryBuilder() {
SimpleQueryStringBuilder result = new SimpleQueryStringBuilder(randomAsciiOfLengthBetween(1, 10)); int numberOfTerms = randomIntBetween(1, 5);
queryTerms = new String[numberOfTerms];
StringBuilder queryString = new StringBuilder();
for (int i = 0; i < numberOfTerms; i++) {
queryTerms[i] = randomAsciiOfLengthBetween(1, 10);
queryString.append(queryTerms[i] + " ");
}
SimpleQueryStringBuilder result = new SimpleQueryStringBuilder(queryString.toString().trim());
if (randomBoolean()) { if (randomBoolean()) {
result.analyzeWildcard(randomBoolean()); result.analyzeWildcard(randomBoolean());
} }
@ -72,9 +82,13 @@ public class SimpleQueryStringBuilderTests extends AbstractQueryTestCase<SimpleQ
} }
if (randomBoolean()) { if (randomBoolean()) {
Set<SimpleQueryStringFlag> flagSet = new HashSet<>(); Set<SimpleQueryStringFlag> flagSet = new HashSet<>();
if (numberOfTerms > 1) {
flagSet.add(SimpleQueryStringFlag.WHITESPACE);
}
int size = randomIntBetween(0, SimpleQueryStringFlag.values().length); int size = randomIntBetween(0, SimpleQueryStringFlag.values().length);
for (int i = 0; i < size; i++) { for (int i = 0; i < size; i++) {
flagSet.add(randomFrom(SimpleQueryStringFlag.values())); SimpleQueryStringFlag randomFlag = randomFrom(SimpleQueryStringFlag.values());
flagSet.add(randomFlag);
} }
if (flagSet.size() > 0) { if (flagSet.size() > 0) {
result.flags(flagSet.toArray(new SimpleQueryStringFlag[flagSet.size()])); result.flags(flagSet.toArray(new SimpleQueryStringFlag[flagSet.size()]));
@ -85,13 +99,12 @@ public class SimpleQueryStringBuilderTests extends AbstractQueryTestCase<SimpleQ
Map<String, Float> fields = new HashMap<>(); Map<String, Float> fields = new HashMap<>();
for (int i = 0; i < fieldCount; i++) { for (int i = 0; i < fieldCount; i++) {
if (randomBoolean()) { if (randomBoolean()) {
fields.put(randomAsciiOfLengthBetween(1, 10), AbstractQueryBuilder.DEFAULT_BOOST); fields.put("f" + i + "_" + randomAsciiOfLengthBetween(1, 10), AbstractQueryBuilder.DEFAULT_BOOST);
} else { } else {
fields.put(randomBoolean() ? STRING_FIELD_NAME : randomAsciiOfLengthBetween(1, 10), 2.0f / randomIntBetween(1, 20)); fields.put(randomBoolean() ? STRING_FIELD_NAME : "f" + i + "_" + randomAsciiOfLengthBetween(1, 10), 2.0f / randomIntBetween(1, 20));
} }
} }
result.fields(fields); result.fields(fields);
return result; return result;
} }
@ -256,8 +269,8 @@ public class SimpleQueryStringBuilderTests extends AbstractQueryTestCase<SimpleQ
// no strict field resolution (version before V_1_4_0_Beta1) // no strict field resolution (version before V_1_4_0_Beta1)
if (getCurrentTypes().length > 0 || shardContext.indexQueryParserService().getIndexCreatedVersion().before(Version.V_1_4_0_Beta1)) { if (getCurrentTypes().length > 0 || shardContext.indexQueryParserService().getIndexCreatedVersion().before(Version.V_1_4_0_Beta1)) {
Query luceneQuery = queryBuilder.toQuery(shardContext); Query luceneQuery = queryBuilder.toQuery(shardContext);
assertThat(luceneQuery, instanceOf(TermQuery.class)); assertThat(luceneQuery, instanceOf(BooleanQuery.class));
TermQuery termQuery = (TermQuery) luceneQuery; TermQuery termQuery = (TermQuery) ((BooleanQuery) luceneQuery).clauses().get(0).getQuery();
assertThat(termQuery.getTerm(), equalTo(new Term(MetaData.ALL, query))); assertThat(termQuery.getTerm(), equalTo(new Term(MetaData.ALL, query)));
} }
} }
@ -275,7 +288,7 @@ public class SimpleQueryStringBuilderTests extends AbstractQueryTestCase<SimpleQ
if ("".equals(queryBuilder.value())) { if ("".equals(queryBuilder.value())) {
assertTrue("Query should have been MatchNoDocsQuery but was " + query.getClass().getName(), query instanceof MatchNoDocsQuery); assertTrue("Query should have been MatchNoDocsQuery but was " + query.getClass().getName(), query instanceof MatchNoDocsQuery);
} else if (queryBuilder.fields().size() > 1) { } else {
assertTrue("Query should have been BooleanQuery but was " + query.getClass().getName(), query instanceof BooleanQuery); assertTrue("Query should have been BooleanQuery but was " + query.getClass().getName(), query instanceof BooleanQuery);
BooleanQuery boolQuery = (BooleanQuery) query; BooleanQuery boolQuery = (BooleanQuery) query;
@ -288,32 +301,42 @@ public class SimpleQueryStringBuilderTests extends AbstractQueryTestCase<SimpleQ
} }
} }
assertThat(boolQuery.clauses().size(), equalTo(queryBuilder.fields().size())); assertThat(boolQuery.clauses().size(), equalTo(queryTerms.length));
Iterator<String> fields = queryBuilder.fields().keySet().iterator(); Map<String, Float> expectedFields = new TreeMap<String, Float>(queryBuilder.fields());
for (BooleanClause booleanClause : boolQuery) { if (expectedFields.size() == 0) {
expectedFields.put(MetaData.ALL, AbstractQueryBuilder.DEFAULT_BOOST);
}
for (int i = 0; i < queryTerms.length; i++) {
BooleanClause booleanClause = boolQuery.clauses().get(i);
Iterator<Entry<String, Float>> fieldsIter = expectedFields.entrySet().iterator();
if (queryTerms.length == 1 && expectedFields.size() == 1) {
assertThat(booleanClause.getQuery(), instanceOf(TermQuery.class)); assertThat(booleanClause.getQuery(), instanceOf(TermQuery.class));
TermQuery termQuery = (TermQuery) booleanClause.getQuery(); TermQuery termQuery = (TermQuery) booleanClause.getQuery();
assertThat(termQuery.getTerm().field(), equalTo(fields.next())); Entry<String, Float> entry = fieldsIter.next();
assertThat(termQuery.getTerm().text().toLowerCase(Locale.ROOT), equalTo(queryBuilder.value().toLowerCase(Locale.ROOT))); assertThat(termQuery.getTerm().field(), equalTo(entry.getKey()));
assertThat(termQuery.getBoost(), equalTo(entry.getValue()));
assertThat(termQuery.getTerm().text().toLowerCase(Locale.ROOT), equalTo(queryTerms[i].toLowerCase(Locale.ROOT)));
} else {
assertThat(booleanClause.getQuery(), instanceOf(BooleanQuery.class));
for (BooleanClause clause : ((BooleanQuery) booleanClause.getQuery()).clauses()) {
TermQuery termQuery = (TermQuery) clause.getQuery();
Entry<String, Float> entry = fieldsIter.next();
assertThat(termQuery.getTerm().field(), equalTo(entry.getKey()));
assertThat(termQuery.getBoost(), equalTo(entry.getValue()));
assertThat(termQuery.getTerm().text().toLowerCase(Locale.ROOT), equalTo(queryTerms[i].toLowerCase(Locale.ROOT)));
}
}
} }
if (queryBuilder.minimumShouldMatch() != null) { if (queryBuilder.minimumShouldMatch() != null) {
assertThat(boolQuery.getMinimumNumberShouldMatch(), greaterThan(0)); int optionalClauses = queryTerms.length;
if (queryBuilder.defaultOperator().equals(Operator.AND) && queryTerms.length > 1) {
optionalClauses = 0;
} }
} else if (queryBuilder.fields().size() <= 1) { int expectedMinimumShouldMatch = Queries.calculateMinShouldMatch(optionalClauses, queryBuilder.minimumShouldMatch());
assertTrue("Query should have been TermQuery but was " + query.getClass().getName(), query instanceof TermQuery); assertEquals(expectedMinimumShouldMatch, boolQuery.getMinimumNumberShouldMatch());
TermQuery termQuery = (TermQuery) query;
String field;
if (queryBuilder.fields().size() == 0) {
field = MetaData.ALL;
} else {
field = queryBuilder.fields().keySet().iterator().next();
} }
assertThat(termQuery.getTerm().field(), equalTo(field));
assertThat(termQuery.getTerm().text().toLowerCase(Locale.ROOT), equalTo(queryBuilder.value().toLowerCase(Locale.ROOT)));
} else {
fail("Encountered lucene query type we do not have a validation implementation for in our " + SimpleQueryStringBuilderTests.class.getSimpleName());
} }
} }
@ -339,15 +362,18 @@ public class SimpleQueryStringBuilderTests extends AbstractQueryTestCase<SimpleQ
SimpleQueryStringBuilder simpleQueryStringBuilder = new SimpleQueryStringBuilder("test"); SimpleQueryStringBuilder simpleQueryStringBuilder = new SimpleQueryStringBuilder("test");
simpleQueryStringBuilder.field(STRING_FIELD_NAME, 5); simpleQueryStringBuilder.field(STRING_FIELD_NAME, 5);
Query query = simpleQueryStringBuilder.toQuery(shardContext); Query query = simpleQueryStringBuilder.toQuery(shardContext);
assertThat(query, instanceOf(TermQuery.class)); assertThat(query, instanceOf(BooleanQuery.class));
assertThat(query.getBoost(), equalTo(5f)); TermQuery wrappedQuery = (TermQuery) ((BooleanQuery) query).clauses().get(0).getQuery();
assertThat(wrappedQuery.getBoost(), equalTo(5f));
simpleQueryStringBuilder = new SimpleQueryStringBuilder("test"); simpleQueryStringBuilder = new SimpleQueryStringBuilder("test");
simpleQueryStringBuilder.field(STRING_FIELD_NAME, 5); simpleQueryStringBuilder.field(STRING_FIELD_NAME, 5);
simpleQueryStringBuilder.boost(2); simpleQueryStringBuilder.boost(2);
query = simpleQueryStringBuilder.toQuery(shardContext); query = simpleQueryStringBuilder.toQuery(shardContext);
assertThat(query, instanceOf(TermQuery.class)); assertThat(query.getBoost(), equalTo(2f));
assertThat(query.getBoost(), equalTo(10f)); assertThat(query, instanceOf(BooleanQuery.class));
wrappedQuery = (TermQuery) ((BooleanQuery) query).clauses().get(0).getQuery();
assertThat(wrappedQuery.getBoost(), equalTo(5f));
} }
public void testNegativeFlags() throws IOException { public void testNegativeFlags() throws IOException {
@ -359,4 +385,39 @@ public class SimpleQueryStringBuilderTests extends AbstractQueryTestCase<SimpleQ
otherBuilder.flags(-1); otherBuilder.flags(-1);
assertThat(builder, equalTo(otherBuilder)); assertThat(builder, equalTo(otherBuilder));
} }
public void testMinimumShouldMatch() throws IOException {
QueryShardContext shardContext = createShardContext();
int numberOfTerms = randomIntBetween(1, 4);
int numberOfFields = randomIntBetween(1, 4);
StringBuilder queryString = new StringBuilder();
for (int i = 0; i < numberOfTerms; i++) {
queryString.append("t" + i + " ");
}
SimpleQueryStringBuilder simpleQueryStringBuilder = new SimpleQueryStringBuilder(queryString.toString().trim());
if (randomBoolean()) {
simpleQueryStringBuilder.defaultOperator(Operator.AND);
}
for (int i = 0; i < numberOfFields; i++) {
simpleQueryStringBuilder.field("f" + i);
}
int percent = randomIntBetween(1, 100);
simpleQueryStringBuilder.minimumShouldMatch(percent + "%");
BooleanQuery query = (BooleanQuery) simpleQueryStringBuilder.toQuery(shardContext);
assertEquals("query should have one should clause per term", numberOfTerms, query.clauses().size());
int expectedMinimumShouldMatch = numberOfTerms * percent / 100;
if (simpleQueryStringBuilder.defaultOperator().equals(Operator.AND) && numberOfTerms > 1) {
expectedMinimumShouldMatch = 0;
}
assertEquals(expectedMinimumShouldMatch, query.getMinimumNumberShouldMatch());
for (BooleanClause clause : query.clauses()) {
if (numberOfFields == 1 && numberOfTerms == 1) {
assertTrue(clause.getQuery() instanceof TermQuery);
} else {
assertEquals(numberOfFields, ((BooleanQuery) clause.getQuery()).clauses().size());
}
}
}
} }

View File

@ -109,7 +109,6 @@ public class SimpleQueryStringIT extends ESIntegTestCase {
client().prepareIndex("test", "type1", "3").setSource("body", "foo bar"), client().prepareIndex("test", "type1", "3").setSource("body", "foo bar"),
client().prepareIndex("test", "type1", "4").setSource("body", "foo baz bar")); client().prepareIndex("test", "type1", "4").setSource("body", "foo baz bar"));
logger.info("--> query 1"); logger.info("--> query 1");
SearchResponse searchResponse = client().prepareSearch().setQuery(simpleQueryStringQuery("foo bar").minimumShouldMatch("2")).get(); SearchResponse searchResponse = client().prepareSearch().setQuery(simpleQueryStringQuery("foo bar").minimumShouldMatch("2")).get();
assertHitCount(searchResponse, 2l); assertHitCount(searchResponse, 2l);
@ -120,7 +119,13 @@ public class SimpleQueryStringIT extends ESIntegTestCase {
assertHitCount(searchResponse, 2l); assertHitCount(searchResponse, 2l);
assertSearchHits(searchResponse, "3", "4"); assertSearchHits(searchResponse, "3", "4");
logger.info("--> query 3"); logger.info("--> query 3"); // test case from #13884
searchResponse = client().prepareSearch().setQuery(simpleQueryStringQuery("foo")
.field("body").field("body2").field("body3").minimumShouldMatch("-50%")).get();
assertHitCount(searchResponse, 3l);
assertSearchHits(searchResponse, "1", "3", "4");
logger.info("--> query 4");
searchResponse = client().prepareSearch().setQuery(simpleQueryStringQuery("foo bar baz").field("body").field("body2").minimumShouldMatch("70%")).get(); searchResponse = client().prepareSearch().setQuery(simpleQueryStringQuery("foo bar baz").field("body").field("body2").minimumShouldMatch("70%")).get();
assertHitCount(searchResponse, 2l); assertHitCount(searchResponse, 2l);
assertSearchHits(searchResponse, "3", "4"); assertSearchHits(searchResponse, "3", "4");
@ -131,17 +136,17 @@ public class SimpleQueryStringIT extends ESIntegTestCase {
client().prepareIndex("test", "type1", "7").setSource("body2", "foo bar", "other", "foo"), client().prepareIndex("test", "type1", "7").setSource("body2", "foo bar", "other", "foo"),
client().prepareIndex("test", "type1", "8").setSource("body2", "foo baz bar", "other", "foo")); client().prepareIndex("test", "type1", "8").setSource("body2", "foo baz bar", "other", "foo"));
logger.info("--> query 4"); logger.info("--> query 5");
searchResponse = client().prepareSearch().setQuery(simpleQueryStringQuery("foo bar").field("body").field("body2").minimumShouldMatch("2")).get(); searchResponse = client().prepareSearch().setQuery(simpleQueryStringQuery("foo bar").field("body").field("body2").minimumShouldMatch("2")).get();
assertHitCount(searchResponse, 4l); assertHitCount(searchResponse, 4l);
assertSearchHits(searchResponse, "3", "4", "7", "8"); assertSearchHits(searchResponse, "3", "4", "7", "8");
logger.info("--> query 5"); logger.info("--> query 6");
searchResponse = client().prepareSearch().setQuery(simpleQueryStringQuery("foo bar").minimumShouldMatch("2")).get(); searchResponse = client().prepareSearch().setQuery(simpleQueryStringQuery("foo bar").minimumShouldMatch("2")).get();
assertHitCount(searchResponse, 5l); assertHitCount(searchResponse, 5l);
assertSearchHits(searchResponse, "3", "4", "6", "7", "8"); assertSearchHits(searchResponse, "3", "4", "6", "7", "8");
logger.info("--> query 6"); logger.info("--> query 7");
searchResponse = client().prepareSearch().setQuery(simpleQueryStringQuery("foo bar baz").field("body2").field("other").minimumShouldMatch("70%")).get(); searchResponse = client().prepareSearch().setQuery(simpleQueryStringQuery("foo bar baz").field("body2").field("other").minimumShouldMatch("70%")).get();
assertHitCount(searchResponse, 3l); assertHitCount(searchResponse, 3l);
assertSearchHits(searchResponse, "6", "7", "8"); assertSearchHits(searchResponse, "6", "7", "8");