diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalResponse.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalResponse.java index 20760b28f1e..9d58d47847f 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalResponse.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalResponse.java @@ -26,8 +26,11 @@ import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.XContentBuilder; import java.io.IOException; +import java.util.ArrayList; import java.util.Collection; +import java.util.HashMap; import java.util.Map; +import java.util.Objects; /** * For each qa specification identified by its id this response returns the respective @@ -80,16 +83,33 @@ public class RankEvalResponse extends ActionResponse implements ToXContent { super.writeTo(out); out.writeString(specId); out.writeDouble(qualityLevel); - out.writeGenericValue(getUnknownDocs()); + out.writeVInt(unknownDocs.size()); + for (String queryId : unknownDocs.keySet()) { + out.writeString(queryId); + Collection collection = unknownDocs.get(queryId); + out.writeVInt(collection.size()); + for (RatedDocumentKey key : collection) { + key.writeTo(out); + } + } } @Override - @SuppressWarnings("unchecked") public void readFrom(StreamInput in) throws IOException { super.readFrom(in); this.specId = in.readString(); this.qualityLevel = in.readDouble(); - this.unknownDocs = (Map>) in.readGenericValue(); + int unknownDocumentSets = in.readVInt(); + this.unknownDocs = new HashMap<>(unknownDocumentSets); + for (int i = 0; i < unknownDocumentSets; i++) { + String queryId = in.readString(); + int numberUnknownDocs = in.readVInt(); + Collection collection = new ArrayList<>(numberUnknownDocs); + for (int d = 0; d < numberUnknownDocs; d++) { + collection.add(new RatedDocumentKey(in)); + } + this.unknownDocs.put(queryId, collection); + } } @Override @@ -107,4 +127,23 @@ public class RankEvalResponse extends ActionResponse implements ToXContent { builder.endObject(); return builder; } + + @Override + public final boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + RankEvalResponse other = (RankEvalResponse) obj; + return Objects.equals(specId, other.specId) && + Objects.equals(qualityLevel, other.qualityLevel) && + Objects.equals(unknownDocs, other.unknownDocs); + } + + @Override + public final int hashCode() { + return Objects.hash(getClass(), specId, qualityLevel, unknownDocs); + } } diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/TransportRankEvalAction.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/TransportRankEvalAction.java index 0de68b2c1b9..3e4a8d0ea19 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/TransportRankEvalAction.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/TransportRankEvalAction.java @@ -34,7 +34,6 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; import java.util.Collection; -import java.util.HashMap; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; @@ -64,7 +63,7 @@ public class TransportRankEvalAction extends HandledTransportAction listener) { RankEvalSpec qualityTask = request.getRankEvalSpec(); - Map> unknownDocs = new HashMap<>(); + Map> unknownDocs = new ConcurrentHashMap<>(); Collection specifications = qualityTask.getSpecifications(); AtomicInteger responseCounter = new AtomicInteger(specifications.size()); Map partialResults = new ConcurrentHashMap<>(specifications.size()); diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalResponseTests.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalResponseTests.java new file mode 100644 index 00000000000..89e63b0b6f0 --- /dev/null +++ b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalResponseTests.java @@ -0,0 +1,64 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.index.rankeval; + +import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.test.ESTestCase; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class RankEvalResponseTests extends ESTestCase { + + private static RankEvalResponse createRandomResponse() { + Map> unknownDocs = new HashMap<>(); + int numberOfSets = randomIntBetween(0, 5); + for (int i = 0; i < numberOfSets; i++) { + List ids = new ArrayList<>(); + int numberOfUnknownDocs = randomIntBetween(0, 5); + for (int d = 0; d < numberOfUnknownDocs; d++) { + ids.add(new RatedDocumentKey(randomAsciiOfLength(5), randomAsciiOfLength(5), randomAsciiOfLength(5))); + } + unknownDocs.put(randomAsciiOfLength(5), ids); + } + return new RankEvalResponse(randomAsciiOfLengthBetween(1, 10), randomDouble(), unknownDocs ); + } + + public void testSerialization() throws IOException { + RankEvalResponse randomResponse = createRandomResponse(); + try (BytesStreamOutput output = new BytesStreamOutput()) { + randomResponse.writeTo(output); + try (StreamInput in = output.bytes().streamInput()) { + RankEvalResponse deserializedResponse = new RankEvalResponse(); + deserializedResponse.readFrom(in); + assertEquals(randomResponse, deserializedResponse); + assertEquals(randomResponse.hashCode(), deserializedResponse.hashCode()); + assertNotSame(randomResponse, deserializedResponse); + assertEquals(-1, in.read()); + } + } + } + +}