Parse EvluationMetrics as named Objects

This commit is contained in:
Christoph Büscher 2017-11-16 14:06:10 +01:00 committed by Christoph Büscher
parent 504462f617
commit 35fabdaf8a
6 changed files with 40 additions and 58 deletions

View File

@ -19,16 +19,12 @@
package org.elasticsearch.index.rankeval;
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentParser.Token;
import org.elasticsearch.index.rankeval.RatedDocument.DocumentKey;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
@ -58,35 +54,6 @@ public interface EvaluationMetric extends ToXContent, NamedWriteable {
*/
EvalQueryQuality evaluate(String taskId, SearchHit[] hits, List<RatedDocument> ratedDocs);
static EvaluationMetric fromXContent(XContentParser parser) throws IOException {
EvaluationMetric rc;
Token token = parser.nextToken();
if (token != XContentParser.Token.FIELD_NAME) {
throw new ParsingException(parser.getTokenLocation(), "[_na] missing required metric name");
}
String metricName = parser.currentName();
// TODO switch to using a plugable registry
switch (metricName) {
case PrecisionAtK.NAME:
rc = PrecisionAtK.fromXContent(parser);
break;
case MeanReciprocalRank.NAME:
rc = MeanReciprocalRank.fromXContent(parser);
break;
case DiscountedCumulativeGain.NAME:
rc = DiscountedCumulativeGain.fromXContent(parser);
break;
default:
throw new ParsingException(parser.getTokenLocation(), "[_na] unknown query metric name [{}]", metricName);
}
if (parser.currentToken() == XContentParser.Token.END_OBJECT) {
// if we are at END_OBJECT, move to the next one...
parser.nextToken();
}
return rc;
}
/**
* join hits with rated documents using the joint _index/_id document key
*/

View File

@ -23,11 +23,13 @@ import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.IndexScopedSettings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.settings.SettingsFilter;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.plugins.ActionPlugin;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.rest.RestController;
@ -64,4 +66,16 @@ public class RankEvalPlugin extends Plugin implements ActionPlugin {
.add(new NamedWriteableRegistry.Entry(MetricDetails.class, MeanReciprocalRank.NAME, MeanReciprocalRank.Breakdown::new));
return namedWriteables;
}
@Override
public List<NamedXContentRegistry.Entry> getNamedXContent() {
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
namedXContent.add(
new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(PrecisionAtK.NAME), PrecisionAtK::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(MeanReciprocalRank.NAME),
MeanReciprocalRank::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(DiscountedCumulativeGain.NAME),
DiscountedCumulativeGain::fromXContent));
return namedXContent;
}
}

View File

@ -29,6 +29,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentParserUtils;
import org.elasticsearch.script.Script;
import java.io.IOException;
@ -169,22 +170,21 @@ public class RankEvalSpec implements Writeable, ToXContentObject {
a -> new RankEvalSpec((List<RatedRequest>) a[0], (EvaluationMetric) a[1], (Collection<ScriptWithId>) a[2]));
static {
PARSER.declareObjectArray(ConstructingObjectParser.constructorArg(), (p, c) -> {
return RatedRequest.fromXContent(p);
}, REQUESTS_FIELD);
PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> {
try {
return EvaluationMetric.fromXContent(p);
} catch (IOException ex) {
throw new ParsingException(p.getTokenLocation(), "error parsing rank request", ex);
}
}, METRIC_FIELD);
PARSER.declareObjectArray(ConstructingObjectParser.optionalConstructorArg(), (p, c) -> {
return ScriptWithId.fromXContent(p);
}, TEMPLATES_FIELD);
PARSER.declareObjectArray(ConstructingObjectParser.constructorArg(), (p, c) -> RatedRequest.fromXContent(p), REQUESTS_FIELD);
PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> parseMetric(p), METRIC_FIELD);
PARSER.declareObjectArray(ConstructingObjectParser.optionalConstructorArg(), (p, c) -> ScriptWithId.fromXContent(p),
TEMPLATES_FIELD);
PARSER.declareInt(RankEvalSpec::setMaxConcurrentSearches, MAX_CONCURRENT_SEARCHES_FIELD);
}
private static EvaluationMetric parseMetric(XContentParser parser) throws IOException {
XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser::getTokenLocation);
XContentParserUtils.ensureExpectedToken(XContentParser.Token.FIELD_NAME, parser.nextToken(), parser::getTokenLocation);
EvaluationMetric metric = parser.namedObject(EvaluationMetric.class, parser.currentName(), null);
XContentParserUtils.ensureExpectedToken(XContentParser.Token.END_OBJECT, parser.nextToken(), parser::getTokenLocation);
return metric;
}
public static RankEvalSpec parse(XContentParser parser) {
return PARSER.apply(parser, null);
}

View File

@ -166,7 +166,7 @@ public class MeanReciprocalRankTests extends ESTestCase {
return hits;
}
private static MeanReciprocalRank createTestItem() {
static MeanReciprocalRank createTestItem() {
return new MeanReciprocalRank(randomIntBetween(0, 20));
}

View File

@ -20,6 +20,7 @@
package org.elasticsearch.index.rankeval;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
@ -49,6 +50,12 @@ import static org.elasticsearch.test.EqualsHashCodeTestUtils.checkEqualsAndHashC
public class RankEvalSpecTests extends ESTestCase {
@SuppressWarnings("resource")
@Override
protected NamedXContentRegistry xContentRegistry() {
return new NamedXContentRegistry(new RankEvalPlugin().getNamedXContent());
}
private static <T> List<T> randomList(Supplier<T> randomSupplier) {
List<T> result = new ArrayList<>();
int size = randomIntBetween(1, 20);
@ -59,12 +66,10 @@ public class RankEvalSpecTests extends ESTestCase {
}
private static RankEvalSpec createTestItem() throws IOException {
EvaluationMetric metric;
if (randomBoolean()) {
metric = PrecisionAtKTests.createTestItem();
} else {
metric = DiscountedCumulativeGainTests.createTestItem();
}
Supplier<EvaluationMetric> metric = randomFrom(Arrays.asList(
() -> PrecisionAtKTests.createTestItem(),
() -> MeanReciprocalRankTests.createTestItem(),
() -> DiscountedCumulativeGainTests.createTestItem()));
List<RatedRequest> ratedRequests = null;
Collection<ScriptWithId> templates = null;
@ -92,7 +97,7 @@ public class RankEvalSpecTests extends ESTestCase {
new SearchSourceBuilder());
ratedRequests = Arrays.asList(ratedRequest);
}
RankEvalSpec spec = new RankEvalSpec(ratedRequests, metric, templates);
RankEvalSpec spec = new RankEvalSpec(ratedRequests, metric.get(), templates);
maybeSet(spec::setMaxConcurrentSearches, randomInt(100));
List<String> indices = new ArrayList<>();
int size = randomIntBetween(0, 20);
@ -105,7 +110,6 @@ public class RankEvalSpecTests extends ESTestCase {
public void testXContentRoundtrip() throws IOException {
RankEvalSpec testItem = createTestItem();
XContentBuilder shuffled = shuffleXContent(testItem.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS));
try (XContentParser parser = createParser(JsonXContent.jsonXContent, shuffled.bytes())) {

View File

@ -53,9 +53,6 @@ public class RatedRequestsTests extends ESTestCase {
private static NamedXContentRegistry xContentRegistry;
/**
* setup for the whole base test class
*/
@BeforeClass
public static void init() {
xContentRegistry = new NamedXContentRegistry(