Parse EvluationMetrics as named Objects

This commit is contained in:
Christoph Büscher 2017-11-16 14:06:10 +01:00 committed by Christoph Büscher
parent 504462f617
commit 35fabdaf8a
6 changed files with 40 additions and 58 deletions

View File

@ -19,16 +19,12 @@
package org.elasticsearch.index.rankeval; package org.elasticsearch.index.rankeval;
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.common.io.stream.NamedWriteable; import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentParser.Token;
import org.elasticsearch.index.rankeval.RatedDocument.DocumentKey; import org.elasticsearch.index.rankeval.RatedDocument.DocumentKey;
import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits; import org.elasticsearch.search.SearchHits;
import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.List; import java.util.List;
@ -58,35 +54,6 @@ public interface EvaluationMetric extends ToXContent, NamedWriteable {
*/ */
EvalQueryQuality evaluate(String taskId, SearchHit[] hits, List<RatedDocument> ratedDocs); EvalQueryQuality evaluate(String taskId, SearchHit[] hits, List<RatedDocument> ratedDocs);
static EvaluationMetric fromXContent(XContentParser parser) throws IOException {
EvaluationMetric rc;
Token token = parser.nextToken();
if (token != XContentParser.Token.FIELD_NAME) {
throw new ParsingException(parser.getTokenLocation(), "[_na] missing required metric name");
}
String metricName = parser.currentName();
// TODO switch to using a plugable registry
switch (metricName) {
case PrecisionAtK.NAME:
rc = PrecisionAtK.fromXContent(parser);
break;
case MeanReciprocalRank.NAME:
rc = MeanReciprocalRank.fromXContent(parser);
break;
case DiscountedCumulativeGain.NAME:
rc = DiscountedCumulativeGain.fromXContent(parser);
break;
default:
throw new ParsingException(parser.getTokenLocation(), "[_na] unknown query metric name [{}]", metricName);
}
if (parser.currentToken() == XContentParser.Token.END_OBJECT) {
// if we are at END_OBJECT, move to the next one...
parser.nextToken();
}
return rc;
}
/** /**
* join hits with rated documents using the joint _index/_id document key * join hits with rated documents using the joint _index/_id document key
*/ */

View File

@ -23,11 +23,13 @@ import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionResponse; import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.node.DiscoveryNodes; import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.IndexScopedSettings; import org.elasticsearch.common.settings.IndexScopedSettings;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.settings.SettingsFilter; import org.elasticsearch.common.settings.SettingsFilter;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.plugins.ActionPlugin; import org.elasticsearch.plugins.ActionPlugin;
import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.rest.RestController; import org.elasticsearch.rest.RestController;
@ -64,4 +66,16 @@ public class RankEvalPlugin extends Plugin implements ActionPlugin {
.add(new NamedWriteableRegistry.Entry(MetricDetails.class, MeanReciprocalRank.NAME, MeanReciprocalRank.Breakdown::new)); .add(new NamedWriteableRegistry.Entry(MetricDetails.class, MeanReciprocalRank.NAME, MeanReciprocalRank.Breakdown::new));
return namedWriteables; return namedWriteables;
} }
@Override
public List<NamedXContentRegistry.Entry> getNamedXContent() {
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
namedXContent.add(
new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(PrecisionAtK.NAME), PrecisionAtK::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(MeanReciprocalRank.NAME),
MeanReciprocalRank::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(DiscountedCumulativeGain.NAME),
DiscountedCumulativeGain::fromXContent));
return namedXContent;
}
} }

View File

@ -29,6 +29,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentParserUtils;
import org.elasticsearch.script.Script; import org.elasticsearch.script.Script;
import java.io.IOException; import java.io.IOException;
@ -169,22 +170,21 @@ public class RankEvalSpec implements Writeable, ToXContentObject {
a -> new RankEvalSpec((List<RatedRequest>) a[0], (EvaluationMetric) a[1], (Collection<ScriptWithId>) a[2])); a -> new RankEvalSpec((List<RatedRequest>) a[0], (EvaluationMetric) a[1], (Collection<ScriptWithId>) a[2]));
static { static {
PARSER.declareObjectArray(ConstructingObjectParser.constructorArg(), (p, c) -> { PARSER.declareObjectArray(ConstructingObjectParser.constructorArg(), (p, c) -> RatedRequest.fromXContent(p), REQUESTS_FIELD);
return RatedRequest.fromXContent(p); PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> parseMetric(p), METRIC_FIELD);
}, REQUESTS_FIELD); PARSER.declareObjectArray(ConstructingObjectParser.optionalConstructorArg(), (p, c) -> ScriptWithId.fromXContent(p),
PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> { TEMPLATES_FIELD);
try {
return EvaluationMetric.fromXContent(p);
} catch (IOException ex) {
throw new ParsingException(p.getTokenLocation(), "error parsing rank request", ex);
}
}, METRIC_FIELD);
PARSER.declareObjectArray(ConstructingObjectParser.optionalConstructorArg(), (p, c) -> {
return ScriptWithId.fromXContent(p);
}, TEMPLATES_FIELD);
PARSER.declareInt(RankEvalSpec::setMaxConcurrentSearches, MAX_CONCURRENT_SEARCHES_FIELD); PARSER.declareInt(RankEvalSpec::setMaxConcurrentSearches, MAX_CONCURRENT_SEARCHES_FIELD);
} }
private static EvaluationMetric parseMetric(XContentParser parser) throws IOException {
XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser::getTokenLocation);
XContentParserUtils.ensureExpectedToken(XContentParser.Token.FIELD_NAME, parser.nextToken(), parser::getTokenLocation);
EvaluationMetric metric = parser.namedObject(EvaluationMetric.class, parser.currentName(), null);
XContentParserUtils.ensureExpectedToken(XContentParser.Token.END_OBJECT, parser.nextToken(), parser::getTokenLocation);
return metric;
}
public static RankEvalSpec parse(XContentParser parser) { public static RankEvalSpec parse(XContentParser parser) {
return PARSER.apply(parser, null); return PARSER.apply(parser, null);
} }

View File

@ -166,7 +166,7 @@ public class MeanReciprocalRankTests extends ESTestCase {
return hits; return hits;
} }
private static MeanReciprocalRank createTestItem() { static MeanReciprocalRank createTestItem() {
return new MeanReciprocalRank(randomIntBetween(0, 20)); return new MeanReciprocalRank(randomIntBetween(0, 20));
} }

View File

@ -20,6 +20,7 @@
package org.elasticsearch.index.rankeval; package org.elasticsearch.index.rankeval;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.common.xcontent.XContentFactory;
@ -49,6 +50,12 @@ import static org.elasticsearch.test.EqualsHashCodeTestUtils.checkEqualsAndHashC
public class RankEvalSpecTests extends ESTestCase { public class RankEvalSpecTests extends ESTestCase {
@SuppressWarnings("resource")
@Override
protected NamedXContentRegistry xContentRegistry() {
return new NamedXContentRegistry(new RankEvalPlugin().getNamedXContent());
}
private static <T> List<T> randomList(Supplier<T> randomSupplier) { private static <T> List<T> randomList(Supplier<T> randomSupplier) {
List<T> result = new ArrayList<>(); List<T> result = new ArrayList<>();
int size = randomIntBetween(1, 20); int size = randomIntBetween(1, 20);
@ -59,12 +66,10 @@ public class RankEvalSpecTests extends ESTestCase {
} }
private static RankEvalSpec createTestItem() throws IOException { private static RankEvalSpec createTestItem() throws IOException {
EvaluationMetric metric; Supplier<EvaluationMetric> metric = randomFrom(Arrays.asList(
if (randomBoolean()) { () -> PrecisionAtKTests.createTestItem(),
metric = PrecisionAtKTests.createTestItem(); () -> MeanReciprocalRankTests.createTestItem(),
} else { () -> DiscountedCumulativeGainTests.createTestItem()));
metric = DiscountedCumulativeGainTests.createTestItem();
}
List<RatedRequest> ratedRequests = null; List<RatedRequest> ratedRequests = null;
Collection<ScriptWithId> templates = null; Collection<ScriptWithId> templates = null;
@ -92,7 +97,7 @@ public class RankEvalSpecTests extends ESTestCase {
new SearchSourceBuilder()); new SearchSourceBuilder());
ratedRequests = Arrays.asList(ratedRequest); ratedRequests = Arrays.asList(ratedRequest);
} }
RankEvalSpec spec = new RankEvalSpec(ratedRequests, metric, templates); RankEvalSpec spec = new RankEvalSpec(ratedRequests, metric.get(), templates);
maybeSet(spec::setMaxConcurrentSearches, randomInt(100)); maybeSet(spec::setMaxConcurrentSearches, randomInt(100));
List<String> indices = new ArrayList<>(); List<String> indices = new ArrayList<>();
int size = randomIntBetween(0, 20); int size = randomIntBetween(0, 20);
@ -105,7 +110,6 @@ public class RankEvalSpecTests extends ESTestCase {
public void testXContentRoundtrip() throws IOException { public void testXContentRoundtrip() throws IOException {
RankEvalSpec testItem = createTestItem(); RankEvalSpec testItem = createTestItem();
XContentBuilder shuffled = shuffleXContent(testItem.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)); XContentBuilder shuffled = shuffleXContent(testItem.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS));
try (XContentParser parser = createParser(JsonXContent.jsonXContent, shuffled.bytes())) { try (XContentParser parser = createParser(JsonXContent.jsonXContent, shuffled.bytes())) {

View File

@ -53,9 +53,6 @@ public class RatedRequestsTests extends ESTestCase {
private static NamedXContentRegistry xContentRegistry; private static NamedXContentRegistry xContentRegistry;
/**
* setup for the whole base test class
*/
@BeforeClass @BeforeClass
public static void init() { public static void init() {
xContentRegistry = new NamedXContentRegistry( xContentRegistry = new NamedXContentRegistry(