Parse EvluationMetrics as named Objects
This commit is contained in:
parent
504462f617
commit
35fabdaf8a
|
@ -19,16 +19,12 @@
|
|||
|
||||
package org.elasticsearch.index.rankeval;
|
||||
|
||||
import org.elasticsearch.common.ParsingException;
|
||||
import org.elasticsearch.common.io.stream.NamedWriteable;
|
||||
import org.elasticsearch.common.xcontent.ToXContent;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.common.xcontent.XContentParser.Token;
|
||||
import org.elasticsearch.index.rankeval.RatedDocument.DocumentKey;
|
||||
import org.elasticsearch.search.SearchHit;
|
||||
import org.elasticsearch.search.SearchHits;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
|
@ -58,35 +54,6 @@ public interface EvaluationMetric extends ToXContent, NamedWriteable {
|
|||
*/
|
||||
EvalQueryQuality evaluate(String taskId, SearchHit[] hits, List<RatedDocument> ratedDocs);
|
||||
|
||||
static EvaluationMetric fromXContent(XContentParser parser) throws IOException {
|
||||
EvaluationMetric rc;
|
||||
Token token = parser.nextToken();
|
||||
if (token != XContentParser.Token.FIELD_NAME) {
|
||||
throw new ParsingException(parser.getTokenLocation(), "[_na] missing required metric name");
|
||||
}
|
||||
String metricName = parser.currentName();
|
||||
|
||||
// TODO switch to using a plugable registry
|
||||
switch (metricName) {
|
||||
case PrecisionAtK.NAME:
|
||||
rc = PrecisionAtK.fromXContent(parser);
|
||||
break;
|
||||
case MeanReciprocalRank.NAME:
|
||||
rc = MeanReciprocalRank.fromXContent(parser);
|
||||
break;
|
||||
case DiscountedCumulativeGain.NAME:
|
||||
rc = DiscountedCumulativeGain.fromXContent(parser);
|
||||
break;
|
||||
default:
|
||||
throw new ParsingException(parser.getTokenLocation(), "[_na] unknown query metric name [{}]", metricName);
|
||||
}
|
||||
if (parser.currentToken() == XContentParser.Token.END_OBJECT) {
|
||||
// if we are at END_OBJECT, move to the next one...
|
||||
parser.nextToken();
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
|
||||
/**
|
||||
* join hits with rated documents using the joint _index/_id document key
|
||||
*/
|
||||
|
|
|
@ -23,11 +23,13 @@ import org.elasticsearch.action.ActionRequest;
|
|||
import org.elasticsearch.action.ActionResponse;
|
||||
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
|
||||
import org.elasticsearch.cluster.node.DiscoveryNodes;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.settings.ClusterSettings;
|
||||
import org.elasticsearch.common.settings.IndexScopedSettings;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.common.settings.SettingsFilter;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.plugins.ActionPlugin;
|
||||
import org.elasticsearch.plugins.Plugin;
|
||||
import org.elasticsearch.rest.RestController;
|
||||
|
@ -64,4 +66,16 @@ public class RankEvalPlugin extends Plugin implements ActionPlugin {
|
|||
.add(new NamedWriteableRegistry.Entry(MetricDetails.class, MeanReciprocalRank.NAME, MeanReciprocalRank.Breakdown::new));
|
||||
return namedWriteables;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<NamedXContentRegistry.Entry> getNamedXContent() {
|
||||
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
|
||||
namedXContent.add(
|
||||
new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(PrecisionAtK.NAME), PrecisionAtK::fromXContent));
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(MeanReciprocalRank.NAME),
|
||||
MeanReciprocalRank::fromXContent));
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(DiscountedCumulativeGain.NAME),
|
||||
DiscountedCumulativeGain::fromXContent));
|
||||
return namedXContent;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -29,6 +29,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
|||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.common.xcontent.XContentParserUtils;
|
||||
import org.elasticsearch.script.Script;
|
||||
|
||||
import java.io.IOException;
|
||||
|
@ -169,22 +170,21 @@ public class RankEvalSpec implements Writeable, ToXContentObject {
|
|||
a -> new RankEvalSpec((List<RatedRequest>) a[0], (EvaluationMetric) a[1], (Collection<ScriptWithId>) a[2]));
|
||||
|
||||
static {
|
||||
PARSER.declareObjectArray(ConstructingObjectParser.constructorArg(), (p, c) -> {
|
||||
return RatedRequest.fromXContent(p);
|
||||
}, REQUESTS_FIELD);
|
||||
PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> {
|
||||
try {
|
||||
return EvaluationMetric.fromXContent(p);
|
||||
} catch (IOException ex) {
|
||||
throw new ParsingException(p.getTokenLocation(), "error parsing rank request", ex);
|
||||
}
|
||||
}, METRIC_FIELD);
|
||||
PARSER.declareObjectArray(ConstructingObjectParser.optionalConstructorArg(), (p, c) -> {
|
||||
return ScriptWithId.fromXContent(p);
|
||||
}, TEMPLATES_FIELD);
|
||||
PARSER.declareObjectArray(ConstructingObjectParser.constructorArg(), (p, c) -> RatedRequest.fromXContent(p), REQUESTS_FIELD);
|
||||
PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> parseMetric(p), METRIC_FIELD);
|
||||
PARSER.declareObjectArray(ConstructingObjectParser.optionalConstructorArg(), (p, c) -> ScriptWithId.fromXContent(p),
|
||||
TEMPLATES_FIELD);
|
||||
PARSER.declareInt(RankEvalSpec::setMaxConcurrentSearches, MAX_CONCURRENT_SEARCHES_FIELD);
|
||||
}
|
||||
|
||||
private static EvaluationMetric parseMetric(XContentParser parser) throws IOException {
|
||||
XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser::getTokenLocation);
|
||||
XContentParserUtils.ensureExpectedToken(XContentParser.Token.FIELD_NAME, parser.nextToken(), parser::getTokenLocation);
|
||||
EvaluationMetric metric = parser.namedObject(EvaluationMetric.class, parser.currentName(), null);
|
||||
XContentParserUtils.ensureExpectedToken(XContentParser.Token.END_OBJECT, parser.nextToken(), parser::getTokenLocation);
|
||||
return metric;
|
||||
}
|
||||
|
||||
public static RankEvalSpec parse(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
|
|
@ -166,7 +166,7 @@ public class MeanReciprocalRankTests extends ESTestCase {
|
|||
return hits;
|
||||
}
|
||||
|
||||
private static MeanReciprocalRank createTestItem() {
|
||||
static MeanReciprocalRank createTestItem() {
|
||||
return new MeanReciprocalRank(randomIntBetween(0, 20));
|
||||
}
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
package org.elasticsearch.index.rankeval;
|
||||
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.common.xcontent.ToXContent;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentFactory;
|
||||
|
@ -49,6 +50,12 @@ import static org.elasticsearch.test.EqualsHashCodeTestUtils.checkEqualsAndHashC
|
|||
|
||||
public class RankEvalSpecTests extends ESTestCase {
|
||||
|
||||
@SuppressWarnings("resource")
|
||||
@Override
|
||||
protected NamedXContentRegistry xContentRegistry() {
|
||||
return new NamedXContentRegistry(new RankEvalPlugin().getNamedXContent());
|
||||
}
|
||||
|
||||
private static <T> List<T> randomList(Supplier<T> randomSupplier) {
|
||||
List<T> result = new ArrayList<>();
|
||||
int size = randomIntBetween(1, 20);
|
||||
|
@ -59,12 +66,10 @@ public class RankEvalSpecTests extends ESTestCase {
|
|||
}
|
||||
|
||||
private static RankEvalSpec createTestItem() throws IOException {
|
||||
EvaluationMetric metric;
|
||||
if (randomBoolean()) {
|
||||
metric = PrecisionAtKTests.createTestItem();
|
||||
} else {
|
||||
metric = DiscountedCumulativeGainTests.createTestItem();
|
||||
}
|
||||
Supplier<EvaluationMetric> metric = randomFrom(Arrays.asList(
|
||||
() -> PrecisionAtKTests.createTestItem(),
|
||||
() -> MeanReciprocalRankTests.createTestItem(),
|
||||
() -> DiscountedCumulativeGainTests.createTestItem()));
|
||||
|
||||
List<RatedRequest> ratedRequests = null;
|
||||
Collection<ScriptWithId> templates = null;
|
||||
|
@ -92,7 +97,7 @@ public class RankEvalSpecTests extends ESTestCase {
|
|||
new SearchSourceBuilder());
|
||||
ratedRequests = Arrays.asList(ratedRequest);
|
||||
}
|
||||
RankEvalSpec spec = new RankEvalSpec(ratedRequests, metric, templates);
|
||||
RankEvalSpec spec = new RankEvalSpec(ratedRequests, metric.get(), templates);
|
||||
maybeSet(spec::setMaxConcurrentSearches, randomInt(100));
|
||||
List<String> indices = new ArrayList<>();
|
||||
int size = randomIntBetween(0, 20);
|
||||
|
@ -105,7 +110,6 @@ public class RankEvalSpecTests extends ESTestCase {
|
|||
|
||||
public void testXContentRoundtrip() throws IOException {
|
||||
RankEvalSpec testItem = createTestItem();
|
||||
|
||||
XContentBuilder shuffled = shuffleXContent(testItem.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS));
|
||||
try (XContentParser parser = createParser(JsonXContent.jsonXContent, shuffled.bytes())) {
|
||||
|
||||
|
|
|
@ -53,9 +53,6 @@ public class RatedRequestsTests extends ESTestCase {
|
|||
|
||||
private static NamedXContentRegistry xContentRegistry;
|
||||
|
||||
/**
|
||||
* setup for the whole base test class
|
||||
*/
|
||||
@BeforeClass
|
||||
public static void init() {
|
||||
xContentRegistry = new NamedXContentRegistry(
|
||||
|
|
Loading…
Reference in New Issue