Search: When sorting, allow to pass `track_scores` and set it to `true` to get scores/max_score back, closes #662.

This commit is contained in:
kimchy 2011-02-01 12:38:46 +02:00
parent 5da14a7ed1
commit cc6f65f8b8
8 changed files with 145 additions and 10 deletions

View File

@ -350,6 +350,15 @@ public class SearchRequestBuilder extends BaseRequestBuilder<SearchRequest, Sear
return this;
}
/**
* Applies when sorting, and controls if scores will be tracked as well. Defaults to
* <tt>false</tt>.
*/
public SearchRequestBuilder setTrackScores(boolean trackScores) {
sourceBuilder().trackScores(trackScores);
return this;
}
/**
* Adds the fields to load and return as part of the search request. If none are specified,
* the source of the document will be returned.

View File

@ -84,6 +84,8 @@ public class SearchSourceBuilder implements ToXContent {
private List<SortBuilder> sorts;
private boolean trackScores = false;
private List<String> fieldNames;
private List<ScriptField> scriptFields;
@ -219,6 +221,15 @@ public class SearchSourceBuilder implements ToXContent {
return this;
}
/**
* Applies when sorting, and controls if scores will be tracked as well. Defaults to
* <tt>false</tt>.
*/
public SearchSourceBuilder trackScores(boolean trackScores) {
this.trackScores = trackScores;
return this;
}
/**
* Add a facet to perform as part of the search.
*/
@ -453,6 +464,9 @@ public class SearchSourceBuilder implements ToXContent {
builder.endObject();
}
builder.endArray();
if (trackScores) {
builder.field("track_scores", trackScores);
}
}
if (indexBoost != null) {

View File

@ -19,6 +19,7 @@
package org.elasticsearch.search.internal;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.*;
import org.elasticsearch.common.collect.Lists;
import org.elasticsearch.common.collect.Maps;
@ -43,7 +44,9 @@ public class ContextIndexSearcher extends ExtendedIndexSearcher {
public static final String NA = "_na_";
}
private SearchContext searchContext;
private final SearchContext searchContext;
private final IndexReader reader;
private CachedDfSource dfSource;
@ -54,6 +57,7 @@ public class ContextIndexSearcher extends ExtendedIndexSearcher {
public ContextIndexSearcher(SearchContext searchContext, Engine.Searcher searcher) {
super(searcher.searcher());
this.searchContext = searchContext;
this.reader = searcher.searcher().getIndexReader();
}
public void dfSource(CachedDfSource dfSource) {
@ -116,6 +120,18 @@ public class ContextIndexSearcher extends ExtendedIndexSearcher {
return query.weight(dfSource);
}
// override from the Searcher to allow to control if scores will be tracked or not
@Override public TopFieldDocs search(Weight weight, Filter filter, int nDocs,
Sort sort, boolean fillFields) throws IOException {
nDocs = Math.min(nDocs, reader.maxDoc());
TopFieldCollector collector = TopFieldCollector.create(sort, nDocs,
fillFields, searchContext.trackScores(), searchContext.trackScores(), !weight.scoresDocsOutOfOrder());
search(weight, filter, collector);
return (TopFieldDocs) collector.topDocs();
}
@Override public void search(Weight weight, Filter filter, Collector collector) throws IOException {
if (searchContext.parsedFilter() != null) {
// this will only get applied to the actual search collector and not

View File

@ -110,6 +110,8 @@ public class SearchContext implements Releasable {
private Sort sort;
private boolean trackScores = false; // when sorting, track scores as well...
private String queryParserName;
private ParsedQuery originalQuery;
@ -294,6 +296,15 @@ public class SearchContext implements Releasable {
return this.sort;
}
public SearchContext trackScores(boolean trackScores) {
this.trackScores = trackScores;
return this;
}
public boolean trackScores() {
return this.trackScores;
}
public SearchContext parsedFilter(Filter filter) {
this.filter = filter;
return this;

View File

@ -33,6 +33,7 @@ import org.elasticsearch.search.internal.ContextIndexSearcher;
import org.elasticsearch.search.internal.ScopePhase;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.sort.SortParseElement;
import org.elasticsearch.search.sort.TrackScoresParseElement;
import java.util.Map;
@ -61,6 +62,8 @@ public class QueryPhase implements SearchPhase {
.put("filterBinary", new FilterBinaryParseElement())
.put("filter_binary", new FilterBinaryParseElement())
.put("sort", new SortParseElement())
.put("trackScores", new TrackScoresParseElement())
.put("track_scores", new TrackScoresParseElement())
.putAll(facetPhase.parseElements());
return parseElements.build();
}

View File

@ -0,0 +1,37 @@
/*
* Licensed to Elastic Search and Shay Banon under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. Elastic Search licenses this
* file to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.search.sort;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.SearchParseElement;
import org.elasticsearch.search.internal.SearchContext;
/**
* @author kimchy (shay.banon)
*/
public class TrackScoresParseElement implements SearchParseElement {
@Override public void parse(XContentParser parser, SearchContext context) throws Exception {
XContentParser.Token token = parser.currentToken();
if (token.isValue()) {
context.trackScores(parser.booleanValue());
}
}
}

View File

@ -22,12 +22,15 @@ package org.elasticsearch.test.integration.search.sort;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.ShardSearchFailure;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.test.integration.AbstractNodesTests;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;
import static org.elasticsearch.common.settings.ImmutableSettings.*;
import static org.elasticsearch.common.xcontent.XContentFactory.*;
import static org.elasticsearch.index.query.xcontent.QueryBuilders.*;
import static org.elasticsearch.search.sort.SortBuilders.*;
@ -42,8 +45,9 @@ public class SimpleSortTests extends AbstractNodesTests {
private Client client;
@BeforeClass public void createNodes() throws Exception {
startNode("server1");
startNode("server2");
Settings settings = settingsBuilder().put("number_of_shards", 3).put("number_of_replicas", 0).build();
startNode("server1", settings);
startNode("server2", settings);
client = getClient();
}
@ -56,6 +60,54 @@ public class SimpleSortTests extends AbstractNodesTests {
return client("server1");
}
@Test public void testTrackScores() throws Exception {
try {
client.admin().indices().prepareDelete("test").execute().actionGet();
} catch (Exception e) {
// ignore
}
client.admin().indices().prepareCreate("test").execute().actionGet();
client.admin().cluster().prepareHealth().setWaitForGreenStatus().execute().actionGet();
client.prepareIndex("test", "type1").setSource(jsonBuilder().startObject()
.field("id", "1")
.field("svalue", "aaa")
.field("ivalue", 100)
.field("dvalue", 0.1)
.endObject()).execute().actionGet();
client.prepareIndex("test", "type1").setSource(jsonBuilder().startObject()
.field("id", "2")
.field("svalue", "bbb")
.field("ivalue", 200)
.field("dvalue", 0.2)
.endObject()).execute().actionGet();
client.admin().indices().prepareFlush().setRefresh(true).execute().actionGet();
SearchResponse searchResponse = client.prepareSearch()
.setQuery(matchAllQuery())
.addSort("svalue", SortOrder.ASC)
.execute().actionGet();
assertThat(searchResponse.hits().getMaxScore(), equalTo(Float.NaN));
for (SearchHit hit : searchResponse.hits()) {
assertThat(hit.getScore(), equalTo(Float.NaN));
}
// now check with score tracking
searchResponse = client.prepareSearch()
.setQuery(matchAllQuery())
.addSort("svalue", SortOrder.ASC)
.setTrackScores(true)
.execute().actionGet();
assertThat(searchResponse.hits().getMaxScore(), not(equalTo(Float.NaN)));
for (SearchHit hit : searchResponse.hits()) {
assertThat(hit.getScore(), not(equalTo(Float.NaN)));
}
}
@Test public void testSimpleSorts() throws Exception {
try {
client.admin().indices().prepareDelete("test").execute().actionGet();

View File

@ -1,7 +0,0 @@
cluster:
routing:
schedule: 100ms
index:
number_of_shards: 3
number_of_replicas: 0
routing :