Return the maxScore per search and score for each search hit, closes #205.

This commit is contained in:
kimchy 2010-06-20 00:23:27 +03:00
parent 0f2147aeec
commit 384f8a4f42
11 changed files with 75 additions and 7 deletions

View File

@ -35,6 +35,16 @@ import java.util.Map;
*/ */
public interface SearchHit extends Streamable, ToXContent, Iterable<SearchHitField> { public interface SearchHit extends Streamable, ToXContent, Iterable<SearchHitField> {
/**
* The score.
*/
float score();
/**
* The score.
*/
float getScore();
/** /**
* The index of the hit. * The index of the hit.
*/ */

View File

@ -39,11 +39,24 @@ public interface SearchHits extends Streamable, ToXContent, Iterable<SearchHit>
*/ */
long getTotalHits(); long getTotalHits();
/**
* The maximum score of this query.
*/
float maxScore();
/**
* The maximum score of this query.
*/
float getMaxScore();
/** /**
* The hits of the search request (based on the search type, and from / size provided). * The hits of the search request (based on the search type, and from / size provided).
*/ */
SearchHit[] hits(); SearchHit[] hits();
/**
* Return the hit as the provided position.
*/
SearchHit getAt(int position); SearchHit getAt(int position);
/** /**

View File

@ -169,8 +169,10 @@ public class SearchPhaseController {
// count the total (we use the query result provider here, since we might not get any hits (we scrolled past them)) // count the total (we use the query result provider here, since we might not get any hits (we scrolled past them))
long totalHits = 0; long totalHits = 0;
float maxScore = Float.NEGATIVE_INFINITY;
for (QuerySearchResultProvider queryResultProvider : queryResults.values()) { for (QuerySearchResultProvider queryResultProvider : queryResults.values()) {
totalHits += queryResultProvider.queryResult().topDocs().totalHits; totalHits += queryResultProvider.queryResult().topDocs().totalHits;
maxScore = Math.max(maxScore, queryResultProvider.queryResult().topDocs().getMaxScore());
} }
// clean the fetch counter // clean the fetch counter
@ -190,12 +192,13 @@ public class SearchPhaseController {
int index = fetchResult.counterGetAndIncrement(); int index = fetchResult.counterGetAndIncrement();
if (index < fetchResult.hits().internalHits().length) { if (index < fetchResult.hits().internalHits().length) {
InternalSearchHit searchHit = fetchResult.hits().internalHits()[index]; InternalSearchHit searchHit = fetchResult.hits().internalHits()[index];
searchHit.score(shardDoc.score());
searchHit.shard(fetchResult.shardTarget()); searchHit.shard(fetchResult.shardTarget());
hits.add(searchHit); hits.add(searchHit);
} }
} }
} }
InternalSearchHits searchHits = new InternalSearchHits(hits.toArray(new InternalSearchHit[hits.size()]), totalHits); InternalSearchHits searchHits = new InternalSearchHits(hits.toArray(new InternalSearchHit[hits.size()]), totalHits, maxScore);
return new InternalSearchResponse(searchHits, facets); return new InternalSearchResponse(searchHits, facets);
} }
} }

View File

@ -22,11 +22,13 @@ package org.elasticsearch.search.controller;
import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.search.SearchShardTarget;
/** /**
* @author kimchy (Shay Banon) * @author kimchy (shay.banon)
*/ */
public interface ShardDoc { public interface ShardDoc {
SearchShardTarget shardTarget(); SearchShardTarget shardTarget();
int docId(); int docId();
float score();
} }

View File

@ -23,7 +23,7 @@ import org.apache.lucene.search.FieldDoc;
import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.search.SearchShardTarget;
/** /**
* @author kimchy (Shay Banon) * @author kimchy (shay.banon)
*/ */
public class ShardFieldDoc extends FieldDoc implements ShardDoc { public class ShardFieldDoc extends FieldDoc implements ShardDoc {
@ -46,4 +46,8 @@ public class ShardFieldDoc extends FieldDoc implements ShardDoc {
@Override public int docId() { @Override public int docId() {
return this.doc; return this.doc;
} }
@Override public float score() {
return score;
}
} }

View File

@ -41,4 +41,8 @@ public class ShardScoreDoc extends ScoreDoc implements ShardDoc {
@Override public int docId() { @Override public int docId() {
return doc; return doc;
} }
@Override public float score() {
return score;
}
} }

View File

@ -150,7 +150,7 @@ public class FetchPhase implements SearchPhase {
doExplanation(context, docId, searchHit); doExplanation(context, docId, searchHit);
} }
context.fetchResult().hits(new InternalSearchHits(hits, context.queryResult().topDocs().totalHits)); context.fetchResult().hits(new InternalSearchHits(hits, context.queryResult().topDocs().totalHits, context.queryResult().topDocs().getMaxScore()));
highlightPhase.execute(context); highlightPhase.execute(context);
} }

View File

@ -51,6 +51,8 @@ public class InternalSearchHit implements SearchHit {
private transient int docId; private transient int docId;
private float score = Float.NEGATIVE_INFINITY;
private String id; private String id;
private String type; private String type;
@ -83,6 +85,18 @@ public class InternalSearchHit implements SearchHit {
return this.docId; return this.docId;
} }
public void score(float score) {
this.score = score;
}
@Override public float score() {
return this.score;
}
@Override public float getScore() {
return score();
}
@Override public String index() { @Override public String index() {
return shard.index(); return shard.index();
} }
@ -208,6 +222,7 @@ public class InternalSearchHit implements SearchHit {
// builder.field("_node", shard.nodeId()); // builder.field("_node", shard.nodeId());
builder.field("_type", type()); builder.field("_type", type());
builder.field("_id", id()); builder.field("_id", id());
builder.field("_score", score);
if (source() != null) { if (source() != null) {
if (XContentFactory.xContentType(source()) == builder.contentType()) { if (XContentFactory.xContentType(source()) == builder.contentType()) {
builder.rawField("_source", source()); builder.rawField("_source", source());
@ -290,6 +305,7 @@ public class InternalSearchHit implements SearchHit {
} }
public void readFrom(StreamInput in, @Nullable TIntObjectHashMap<SearchShardTarget> shardLookupMap) throws IOException { public void readFrom(StreamInput in, @Nullable TIntObjectHashMap<SearchShardTarget> shardLookupMap) throws IOException {
score = in.readFloat();
id = in.readUTF(); id = in.readUTF();
type = in.readUTF(); type = in.readUTF();
int size = in.readVInt(); int size = in.readVInt();
@ -384,6 +400,7 @@ public class InternalSearchHit implements SearchHit {
} }
public void writeTo(StreamOutput out, @Nullable Map<SearchShardTarget, Integer> shardLookupMap) throws IOException { public void writeTo(StreamOutput out, @Nullable Map<SearchShardTarget, Integer> shardLookupMap) throws IOException {
out.writeFloat(score);
out.writeUTF(id); out.writeUTF(id);
out.writeUTF(type); out.writeUTF(type);
if (source == null) { if (source == null) {

View File

@ -47,13 +47,16 @@ public class InternalSearchHits implements SearchHits {
private long totalHits; private long totalHits;
private float maxScore;
InternalSearchHits() { InternalSearchHits() {
} }
public InternalSearchHits(InternalSearchHit[] hits, long totalHits) { public InternalSearchHits(InternalSearchHit[] hits, long totalHits, float maxScore) {
this.hits = hits; this.hits = hits;
this.totalHits = totalHits; this.totalHits = totalHits;
this.maxScore = maxScore;
} }
public long totalHits() { public long totalHits() {
@ -64,6 +67,14 @@ public class InternalSearchHits implements SearchHits {
return totalHits(); return totalHits();
} }
@Override public float maxScore() {
return this.maxScore;
}
@Override public float getMaxScore() {
return maxScore();
}
public SearchHit[] hits() { public SearchHit[] hits() {
return this.hits; return this.hits;
} }
@ -87,6 +98,7 @@ public class InternalSearchHits implements SearchHits {
@Override public void toXContent(XContentBuilder builder, Params params) throws IOException { @Override public void toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject("hits"); builder.startObject("hits");
builder.field("total", totalHits); builder.field("total", totalHits);
builder.field("max_score", maxScore);
builder.field("hits"); builder.field("hits");
builder.startArray(); builder.startArray();
for (SearchHit hit : hits) { for (SearchHit hit : hits) {
@ -104,6 +116,7 @@ public class InternalSearchHits implements SearchHits {
@Override public void readFrom(StreamInput in) throws IOException { @Override public void readFrom(StreamInput in) throws IOException {
totalHits = in.readVLong(); totalHits = in.readVLong();
maxScore = in.readFloat();
int size = in.readVInt(); int size = in.readVInt();
if (size == 0) { if (size == 0) {
hits = EMPTY; hits = EMPTY;
@ -127,6 +140,7 @@ public class InternalSearchHits implements SearchHits {
@Override public void writeTo(StreamOutput out) throws IOException { @Override public void writeTo(StreamOutput out) throws IOException {
out.writeVLong(totalHits); out.writeVLong(totalHits);
out.writeFloat(maxScore);
out.writeVInt(hits.length); out.writeVInt(hits.length);
if (hits.length > 0) { if (hits.length > 0) {
// write the header search shard targets (we assume identity equality) // write the header search shard targets (we assume identity equality)

View File

@ -33,7 +33,7 @@ import java.io.IOException;
import static org.elasticsearch.search.internal.InternalSearchHits.*; import static org.elasticsearch.search.internal.InternalSearchHits.*;
/** /**
* @author kimchy (Shay Banon) * @author kimchy (shay.banon)
*/ */
public class InternalSearchResponse implements Streamable, ToXContent { public class InternalSearchResponse implements Streamable, ToXContent {

View File

@ -88,9 +88,10 @@ public class TransportTwoServersSearchTests extends AbstractNodesTests {
assertThat(searchResponse.hits().totalHits(), equalTo(100l)); assertThat(searchResponse.hits().totalHits(), equalTo(100l));
assertThat(searchResponse.hits().hits().length, equalTo(60)); assertThat(searchResponse.hits().hits().length, equalTo(60));
// System.out.println("max_score: " + searchResponse.hits().maxScore());
for (int i = 0; i < 60; i++) { for (int i = 0; i < 60; i++) {
SearchHit hit = searchResponse.hits().hits()[i]; SearchHit hit = searchResponse.hits().hits()[i];
// System.out.println(hit.shard() + ": " + hit.explanation()); // System.out.println(hit.shard() + ": " + hit.score() + ":" + hit.explanation());
assertThat(hit.explanation(), notNullValue()); assertThat(hit.explanation(), notNullValue());
assertThat("id[" + hit.id() + "]", hit.id(), equalTo(Integer.toString(100 - i - 1))); assertThat("id[" + hit.id() + "]", hit.id(), equalTo(Integer.toString(100 - i - 1)));
} }