Fold MoreLikeThisFetchService into MoreLikeThisQueryParser

now that we have a Client on the Shard context we can fold
the doc fetching into the parser / builder.

Relates to #13488
This commit is contained in:
Simon Willnauer 2015-09-11 09:42:25 +02:00
parent 94a37d486f
commit 853b7fdb7c
5 changed files with 127 additions and 162 deletions

View File

@ -21,25 +21,23 @@ package org.elasticsearch.index.query;
import com.google.common.collect.Sets;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.index.Fields;
import org.apache.lucene.queries.TermsQuery;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.action.termvectors.MultiTermVectorsRequest;
import org.elasticsearch.action.termvectors.MultiTermVectorsResponse;
import org.elasticsearch.action.termvectors.TermVectorsRequest;
import org.elasticsearch.action.termvectors.*;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.lucene.search.MoreLikeThisQuery;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.analysis.Analysis;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.mapper.internal.UidFieldMapper;
import org.elasticsearch.index.query.MoreLikeThisQueryBuilder.Item;
import org.elasticsearch.index.search.morelikethis.MoreLikeThisFetchService;
import org.elasticsearch.search.internal.SearchContext;
import java.io.IOException;
@ -54,8 +52,6 @@ import static org.elasticsearch.index.mapper.Uid.createUidAsBytes;
*/
public class MoreLikeThisQueryParser extends BaseQueryParserTemp {
private MoreLikeThisFetchService fetchService = null;
public interface Field {
ParseField FIELDS = new ParseField("fields");
ParseField LIKE = new ParseField("like");
@ -77,15 +73,6 @@ public class MoreLikeThisQueryParser extends BaseQueryParserTemp {
ParseField FAIL_ON_UNSUPPORTED_FIELD = new ParseField("fail_on_unsupported_field");
}
public MoreLikeThisQueryParser() {
}
@Inject(optional = true)
public void setFetchService(@Nullable MoreLikeThisFetchService fetchService) {
this.fetchService = fetchService;
}
@Override
public String[] names() {
return new String[]{MoreLikeThisQueryBuilder.NAME, "more_like_this", "moreLikeThis"};
@ -251,7 +238,7 @@ public class MoreLikeThisQueryParser extends BaseQueryParserTemp {
// handle items
if (!likeItems.isEmpty()) {
return handleItems(parseContext, mltQuery, likeItems, unlikeItems, include, moreLikeFields, useDefaultField);
return handleItems(context, mltQuery, likeItems, unlikeItems, include, moreLikeFields, useDefaultField);
} else {
return mltQuery;
}
@ -282,8 +269,10 @@ public class MoreLikeThisQueryParser extends BaseQueryParserTemp {
return moreLikeFields;
}
private Query handleItems(QueryParseContext parseContext, MoreLikeThisQuery mltQuery, List<Item> likeItems, List<Item> unlikeItems,
private Query handleItems(QueryShardContext context, MoreLikeThisQuery mltQuery, List<Item> likeItems, List<Item> unlikeItems,
boolean include, List<String> moreLikeFields, boolean useDefaultField) throws IOException {
QueryParseContext parseContext = context.parseContext();
// set default index, type and fields if not specified
for (Item item : likeItems) {
setDefaultIndexTypeFields(parseContext, item, moreLikeFields, useDefaultField);
@ -293,14 +282,14 @@ public class MoreLikeThisQueryParser extends BaseQueryParserTemp {
}
// fetching the items with multi-termvectors API
MultiTermVectorsResponse responses = fetchService.fetchResponse(likeItems, unlikeItems, SearchContext.current());
MultiTermVectorsResponse responses = fetchResponse(context.getClient(), likeItems, unlikeItems, SearchContext.current());
// getting the Fields for liked items
mltQuery.setLikeText(MoreLikeThisFetchService.getFieldsFor(responses, likeItems));
mltQuery.setLikeText(getFieldsFor(responses, likeItems));
// getting the Fields for unliked items
if (!unlikeItems.isEmpty()) {
org.apache.lucene.index.Fields[] unlikeFields = MoreLikeThisFetchService.getFieldsFor(responses, unlikeItems);
org.apache.lucene.index.Fields[] unlikeFields = getFieldsFor(responses, unlikeItems);
if (unlikeFields.length > 0) {
mltQuery.setUnlikeText(unlikeFields);
}
@ -358,4 +347,47 @@ public class MoreLikeThisQueryParser extends BaseQueryParserTemp {
public MoreLikeThisQueryBuilder getBuilderPrototype() {
return MoreLikeThisQueryBuilder.PROTOTYPE;
}
private MultiTermVectorsResponse fetchResponse(Client client, List<Item> likeItems, @Nullable List<Item> unlikeItems,
SearchContext searchContext) throws IOException {
MultiTermVectorsRequest request = new MultiTermVectorsRequest();
for (Item item : likeItems) {
request.add(item.toTermVectorsRequest());
}
if (unlikeItems != null) {
for (Item item : unlikeItems) {
request.add(item.toTermVectorsRequest());
}
}
request.copyContextAndHeadersFrom(searchContext);
return client.multiTermVectors(request).actionGet();
}
private static Fields[] getFieldsFor(MultiTermVectorsResponse responses, List<Item> items) throws IOException {
List<Fields> likeFields = new ArrayList<>();
Set<Item> selectedItems = new HashSet<>();
for (Item request : items) {
selectedItems.add(new Item(request.index(), request.type(), request.id()));
}
for (MultiTermVectorsItemResponse response : responses) {
if (!hasResponseFromRequest(response, selectedItems)) {
continue;
}
if (response.isFailed()) {
continue;
}
TermVectorsResponse getResponse = response.getResponse();
if (!getResponse.isExists()) {
continue;
}
likeFields.add(getResponse.getFields());
}
return likeFields.toArray(Fields.EMPTY_ARRAY);
}
private static boolean hasResponseFromRequest(MultiTermVectorsItemResponse response, Set<Item> selectedItems) {
return selectedItems.contains(new Item(response.getIndex(), response.getType(), response.getId()));
}
}

View File

@ -1,100 +0,0 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.index.search.morelikethis;
import org.apache.lucene.index.Fields;
import org.elasticsearch.action.termvectors.MultiTermVectorsItemResponse;
import org.elasticsearch.action.termvectors.MultiTermVectorsRequest;
import org.elasticsearch.action.termvectors.MultiTermVectorsResponse;
import org.elasticsearch.action.termvectors.TermVectorsResponse;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.component.AbstractComponent;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.query.MoreLikeThisQueryBuilder.Item;
import org.elasticsearch.search.internal.SearchContext;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
/**
*
*/
public class MoreLikeThisFetchService extends AbstractComponent {
private final Client client;
@Inject
public MoreLikeThisFetchService(Client client, Settings settings) {
super(settings);
this.client = client;
}
public Fields[] fetch(List<Item> items) throws IOException {
return getFieldsFor(fetchResponse(items, null, SearchContext.current()), items);
}
public MultiTermVectorsResponse fetchResponse(List<Item> likeItems, @Nullable List<Item> unlikeItems,
SearchContext searchContext) throws IOException {
MultiTermVectorsRequest request = new MultiTermVectorsRequest();
for (Item item : likeItems) {
request.add(item.toTermVectorsRequest());
}
if (unlikeItems != null) {
for (Item item : unlikeItems) {
request.add(item.toTermVectorsRequest());
}
}
request.copyContextAndHeadersFrom(searchContext);
return client.multiTermVectors(request).actionGet();
}
public static Fields[] getFieldsFor(MultiTermVectorsResponse responses, List<Item> items) throws IOException {
List<Fields> likeFields = new ArrayList<>();
Set<Item> selectedItems = new HashSet<>();
for (Item request : items) {
selectedItems.add(new Item(request.index(), request.type(), request.id()));
}
for (MultiTermVectorsItemResponse response : responses) {
if (!hasResponseFromRequest(response, selectedItems)) {
continue;
}
if (response.isFailed()) {
continue;
}
TermVectorsResponse getResponse = response.getResponse();
if (!getResponse.isExists()) {
continue;
}
likeFields.add(getResponse.getFields());
}
return likeFields.toArray(Fields.EMPTY_ARRAY);
}
private static boolean hasResponseFromRequest(MultiTermVectorsItemResponse response, Set<Item> selectedItems) {
return selectedItems.contains(new Item(response.getIndex(), response.getType(), response.getId()));
}
}

View File

@ -24,7 +24,6 @@ import org.elasticsearch.common.inject.multibindings.Multibinder;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.query.functionscore.ScoreFunctionParser;
import org.elasticsearch.index.query.functionscore.ScoreFunctionParserMapper;
import org.elasticsearch.index.search.morelikethis.MoreLikeThisFetchService;
import org.elasticsearch.search.action.SearchServiceTransportAction;
import org.elasticsearch.search.aggregations.AggregationParseElement;
import org.elasticsearch.search.aggregations.AggregationPhase;
@ -339,8 +338,6 @@ public class SearchModule extends AbstractModule {
bind(SearchPhaseController.class).asEagerSingleton();
bind(FetchPhase.class).asEagerSingleton();
bind(SearchServiceTransportAction.class).asEagerSingleton();
bind(MoreLikeThisFetchService.class).asEagerSingleton();
if (searchServiceImpl == SearchService.class) {
bind(SearchService.class).asEagerSingleton();
} else {

View File

@ -28,6 +28,8 @@ import org.elasticsearch.action.admin.indices.mapping.put.PutMappingRequest;
import org.elasticsearch.action.get.GetRequest;
import org.elasticsearch.action.get.GetResponse;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.action.termvectors.MultiTermVectorsRequest;
import org.elasticsearch.action.termvectors.MultiTermVectorsResponse;
import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.ClusterService;
import org.elasticsearch.cluster.ClusterState;
@ -634,6 +636,13 @@ public abstract class AbstractQueryTestCase<QB extends AbstractQueryBuilder<QB>>
return delegate.executeGet((GetRequest) args[0]);
}
};
} else if (method.equals(Client.class.getDeclaredMethod("multiTermVectors", MultiTermVectorsRequest.class))) {
return new PlainActionFuture<MultiTermVectorsResponse>() {
@Override
public MultiTermVectorsResponse get() throws InterruptedException, ExecutionException {
return delegate.executeMultiTermVectors((MultiTermVectorsRequest) args[0]);
}
};
} else if (method.equals(Object.class.getDeclaredMethod("toString"))) {
return "MockClient";
}
@ -649,4 +658,11 @@ public abstract class AbstractQueryTestCase<QB extends AbstractQueryBuilder<QB>>
throw new UnsupportedOperationException("this test can't handle GET requests");
}
/**
* Override this to handle {@link Client#get(GetRequest)} calls from parsers / builders
*/
protected MultiTermVectorsResponse executeMultiTermVectors(MultiTermVectorsRequest mtvRequest) {
throw new UnsupportedOperationException("this test can't handle MultiTermVector requests");
}
}

View File

@ -35,10 +35,9 @@ import org.apache.lucene.util.BytesRefBuilder;
import org.apache.lucene.util.CharsRefBuilder;
import org.apache.lucene.util.NumericUtils;
import org.apache.lucene.util.automaton.TooComplexToDeterminizeException;
import org.elasticsearch.action.termvectors.MultiTermVectorsItemResponse;
import org.elasticsearch.action.termvectors.MultiTermVectorsResponse;
import org.elasticsearch.action.termvectors.TermVectorsRequest;
import org.elasticsearch.action.termvectors.TermVectorsResponse;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.action.termvectors.*;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.compress.CompressedXContent;
import org.elasticsearch.common.lucene.search.MoreLikeThisQuery;
@ -60,7 +59,6 @@ import org.elasticsearch.index.query.functionscore.ScoreFunctionBuilders;
import org.elasticsearch.index.search.geo.GeoDistanceRangeQuery;
import org.elasticsearch.index.search.geo.GeoPolygonQuery;
import org.elasticsearch.index.search.geo.InMemoryGeoBoundingBoxQuery;
import org.elasticsearch.index.search.morelikethis.MoreLikeThisFetchService;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.test.ESSingleNodeTestCase;
import org.hamcrest.Matchers;
@ -68,10 +66,14 @@ import org.junit.Before;
import org.junit.Test;
import java.io.IOException;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.HashSet;
import java.util.List;
import java.util.concurrent.ExecutionException;
import static org.elasticsearch.common.xcontent.XContentFactory.jsonBuilder;
import static org.elasticsearch.index.query.QueryBuilders.*;
@ -1224,7 +1226,7 @@ public class SimpleIndexQueryParserTests extends ESSingleNodeTestCase {
NumericRangeQuery<Long> expectedWrapped = NumericRangeQuery.newLongRange("age", NumberFieldMapper.Defaults.PRECISION_STEP_64_BIT, 7l, 17l, true, true);
expectedWrapped.setBoost(2.0f);
SpanMultiTermQueryWrapper<MultiTermQuery> wrapper = (SpanMultiTermQueryWrapper<MultiTermQuery>) parsedQuery;
assertThat(wrapper, equalTo(new SpanMultiTermQueryWrapper<MultiTermQuery>(expectedWrapped)));
assertThat(wrapper, equalTo(new SpanMultiTermQueryWrapper<>(expectedWrapped)));
}
@Test
@ -1236,7 +1238,7 @@ public class SimpleIndexQueryParserTests extends ESSingleNodeTestCase {
NumericRangeQuery<Long> expectedWrapped = NumericRangeQuery.newLongRange("age", NumberFieldMapper.Defaults.PRECISION_STEP_64_BIT, 10l, 20l, true, false);
expectedWrapped.setBoost(2.0f);
SpanMultiTermQueryWrapper<MultiTermQuery> wrapper = (SpanMultiTermQueryWrapper<MultiTermQuery>) parsedQuery;
assertThat(wrapper, equalTo(new SpanMultiTermQueryWrapper<MultiTermQuery>(expectedWrapped)));
assertThat(wrapper, equalTo(new SpanMultiTermQueryWrapper<>(expectedWrapped)));
}
@Test
@ -1248,7 +1250,7 @@ public class SimpleIndexQueryParserTests extends ESSingleNodeTestCase {
TermRangeQuery expectedWrapped = TermRangeQuery.newStringRange("user", "alice", "bob", true, false);
expectedWrapped.setBoost(2.0f);
SpanMultiTermQueryWrapper<MultiTermQuery> wrapper = (SpanMultiTermQueryWrapper<MultiTermQuery>) parsedQuery;
assertThat(wrapper, equalTo(new SpanMultiTermQueryWrapper<MultiTermQuery>(expectedWrapped)));
assertThat(wrapper, equalTo(new SpanMultiTermQueryWrapper<>(expectedWrapped)));
}
@Test
@ -1279,12 +1281,16 @@ public class SimpleIndexQueryParserTests extends ESSingleNodeTestCase {
@Test
public void testMoreLikeThisIds() throws Exception {
MoreLikeThisQueryParser parser = (MoreLikeThisQueryParser) queryParser.indicesQueriesRegistry().queryParsers().get("more_like_this");
parser.setFetchService(new MockMoreLikeThisFetchService());
IndexQueryParserService queryParser = queryParser();
final Client proxy = getMLTClientProxy();
String query = copyToStringFromClasspath("/org/elasticsearch/index/query/mlt-items.json");
Query parsedQuery = queryParser.parse(query).query();
QueryShardContext ctx = new QueryShardContext(queryParser.index(), queryParser) {
@Override
public Client getClient() {
return proxy;
}
};
Query parsedQuery = queryParser.parse(ctx, new BytesArray(query)).query();
assertThat(parsedQuery, instanceOf(BooleanQuery.class));
BooleanQuery booleanQuery = (BooleanQuery) parsedQuery;
assertThat(booleanQuery.getClauses().length, is(1));
@ -1302,16 +1308,51 @@ public class SimpleIndexQueryParserTests extends ESSingleNodeTestCase {
}
}
private Client getMLTClientProxy() {
return (Client) Proxy.newProxyInstance(
Client.class.getClassLoader(),
new Class[]{Client.class},
(proxy1, method, args) -> {
if (method.equals(Client.class.getDeclaredMethod("multiTermVectors", MultiTermVectorsRequest.class))) {
return new PlainActionFuture<MultiTermVectorsResponse>() {
@Override
public MultiTermVectorsResponse get() throws InterruptedException, ExecutionException {
try {
MultiTermVectorsRequest request = (MultiTermVectorsRequest) args[0];
MultiTermVectorsItemResponse[] responses = new MultiTermVectorsItemResponse[request.size()];
int i = 0;
for (TermVectorsRequest item : request) {
TermVectorsResponse response = new TermVectorsResponse(item.index(), item.type(), item.id());
response.setExists(true);
Fields generatedFields = generateFields(item.selectedFields().toArray(new String[0]), item.id());
EnumSet<TermVectorsRequest.Flag> flags = EnumSet.of(TermVectorsRequest.Flag.Positions, TermVectorsRequest.Flag.Offsets);
response.setFields(generatedFields, item.selectedFields(), flags, generatedFields);
responses[i++] = new MultiTermVectorsItemResponse(response, null);
}
return new MultiTermVectorsResponse(responses);
} catch (IOException ex) {
throw new ExecutionException(ex);
}
}
};
}
throw new UnsupportedOperationException("not supported");
});
}
@Test
public void testMLTMinimumShouldMatch() throws Exception {
// setup for mocking fetching items
MoreLikeThisQueryParser parser = (MoreLikeThisQueryParser) queryParser.indicesQueriesRegistry().queryParsers().get("more_like_this");
parser.setFetchService(new MockMoreLikeThisFetchService());
final Client proxy = getMLTClientProxy();
// parsing the ES query
IndexQueryParserService queryParser = queryParser();
String query = copyToStringFromClasspath("/org/elasticsearch/index/query/mlt-items.json");
BooleanQuery parsedQuery = (BooleanQuery) queryParser.parse(query).query();
QueryShardContext ctx = new QueryShardContext(queryParser.index(), queryParser) {
@Override
public Client getClient() {
return proxy;
}
};
BooleanQuery parsedQuery = (BooleanQuery) queryParser.parse(ctx, new BytesArray(query)).query();
// get MLT query, other clause is for include/exclude items
MoreLikeThisQuery mltQuery = (MoreLikeThisQuery) parsedQuery.getClauses()[0].getQuery();
@ -1339,27 +1380,6 @@ public class SimpleIndexQueryParserTests extends ESSingleNodeTestCase {
assertThat(minNumberShouldMatch, is(2));
}
private static class MockMoreLikeThisFetchService extends MoreLikeThisFetchService {
public MockMoreLikeThisFetchService() {
super(null, Settings.Builder.EMPTY_SETTINGS);
}
@Override
public MultiTermVectorsResponse fetchResponse(List<Item> items, List<Item> unlikeItems, SearchContext searchContext) throws IOException {
MultiTermVectorsItemResponse[] responses = new MultiTermVectorsItemResponse[items.size()];
int i = 0;
for (Item item : items) {
TermVectorsResponse response = new TermVectorsResponse(item.index(), item.type(), item.id());
response.setExists(true);
Fields generatedFields = generateFields(item.fields(), item.id());
EnumSet<TermVectorsRequest.Flag> flags = EnumSet.of(TermVectorsRequest.Flag.Positions, TermVectorsRequest.Flag.Offsets);
response.setFields(generatedFields, new HashSet<>(Arrays.asList(item.fields())), flags, generatedFields);
responses[i++] = new MultiTermVectorsItemResponse(response, null);
}
return new MultiTermVectorsResponse(responses);
}
}
private static Fields generateFields(String[] fieldNames, String text) throws IOException {
MemoryIndex index = new MemoryIndex();