Parse EvluationMetrics as named Objects
This commit is contained in:
parent
504462f617
commit
35fabdaf8a
|
@ -19,16 +19,12 @@
|
||||||
|
|
||||||
package org.elasticsearch.index.rankeval;
|
package org.elasticsearch.index.rankeval;
|
||||||
|
|
||||||
import org.elasticsearch.common.ParsingException;
|
|
||||||
import org.elasticsearch.common.io.stream.NamedWriteable;
|
import org.elasticsearch.common.io.stream.NamedWriteable;
|
||||||
import org.elasticsearch.common.xcontent.ToXContent;
|
import org.elasticsearch.common.xcontent.ToXContent;
|
||||||
import org.elasticsearch.common.xcontent.XContentParser;
|
|
||||||
import org.elasticsearch.common.xcontent.XContentParser.Token;
|
|
||||||
import org.elasticsearch.index.rankeval.RatedDocument.DocumentKey;
|
import org.elasticsearch.index.rankeval.RatedDocument.DocumentKey;
|
||||||
import org.elasticsearch.search.SearchHit;
|
import org.elasticsearch.search.SearchHit;
|
||||||
import org.elasticsearch.search.SearchHits;
|
import org.elasticsearch.search.SearchHits;
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Collection;
|
import java.util.Collection;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
@ -58,35 +54,6 @@ public interface EvaluationMetric extends ToXContent, NamedWriteable {
|
||||||
*/
|
*/
|
||||||
EvalQueryQuality evaluate(String taskId, SearchHit[] hits, List<RatedDocument> ratedDocs);
|
EvalQueryQuality evaluate(String taskId, SearchHit[] hits, List<RatedDocument> ratedDocs);
|
||||||
|
|
||||||
static EvaluationMetric fromXContent(XContentParser parser) throws IOException {
|
|
||||||
EvaluationMetric rc;
|
|
||||||
Token token = parser.nextToken();
|
|
||||||
if (token != XContentParser.Token.FIELD_NAME) {
|
|
||||||
throw new ParsingException(parser.getTokenLocation(), "[_na] missing required metric name");
|
|
||||||
}
|
|
||||||
String metricName = parser.currentName();
|
|
||||||
|
|
||||||
// TODO switch to using a plugable registry
|
|
||||||
switch (metricName) {
|
|
||||||
case PrecisionAtK.NAME:
|
|
||||||
rc = PrecisionAtK.fromXContent(parser);
|
|
||||||
break;
|
|
||||||
case MeanReciprocalRank.NAME:
|
|
||||||
rc = MeanReciprocalRank.fromXContent(parser);
|
|
||||||
break;
|
|
||||||
case DiscountedCumulativeGain.NAME:
|
|
||||||
rc = DiscountedCumulativeGain.fromXContent(parser);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
throw new ParsingException(parser.getTokenLocation(), "[_na] unknown query metric name [{}]", metricName);
|
|
||||||
}
|
|
||||||
if (parser.currentToken() == XContentParser.Token.END_OBJECT) {
|
|
||||||
// if we are at END_OBJECT, move to the next one...
|
|
||||||
parser.nextToken();
|
|
||||||
}
|
|
||||||
return rc;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* join hits with rated documents using the joint _index/_id document key
|
* join hits with rated documents using the joint _index/_id document key
|
||||||
*/
|
*/
|
||||||
|
|
|
@ -23,11 +23,13 @@ import org.elasticsearch.action.ActionRequest;
|
||||||
import org.elasticsearch.action.ActionResponse;
|
import org.elasticsearch.action.ActionResponse;
|
||||||
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
|
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
|
||||||
import org.elasticsearch.cluster.node.DiscoveryNodes;
|
import org.elasticsearch.cluster.node.DiscoveryNodes;
|
||||||
|
import org.elasticsearch.common.ParseField;
|
||||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||||
import org.elasticsearch.common.settings.ClusterSettings;
|
import org.elasticsearch.common.settings.ClusterSettings;
|
||||||
import org.elasticsearch.common.settings.IndexScopedSettings;
|
import org.elasticsearch.common.settings.IndexScopedSettings;
|
||||||
import org.elasticsearch.common.settings.Settings;
|
import org.elasticsearch.common.settings.Settings;
|
||||||
import org.elasticsearch.common.settings.SettingsFilter;
|
import org.elasticsearch.common.settings.SettingsFilter;
|
||||||
|
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||||
import org.elasticsearch.plugins.ActionPlugin;
|
import org.elasticsearch.plugins.ActionPlugin;
|
||||||
import org.elasticsearch.plugins.Plugin;
|
import org.elasticsearch.plugins.Plugin;
|
||||||
import org.elasticsearch.rest.RestController;
|
import org.elasticsearch.rest.RestController;
|
||||||
|
@ -64,4 +66,16 @@ public class RankEvalPlugin extends Plugin implements ActionPlugin {
|
||||||
.add(new NamedWriteableRegistry.Entry(MetricDetails.class, MeanReciprocalRank.NAME, MeanReciprocalRank.Breakdown::new));
|
.add(new NamedWriteableRegistry.Entry(MetricDetails.class, MeanReciprocalRank.NAME, MeanReciprocalRank.Breakdown::new));
|
||||||
return namedWriteables;
|
return namedWriteables;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<NamedXContentRegistry.Entry> getNamedXContent() {
|
||||||
|
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
|
||||||
|
namedXContent.add(
|
||||||
|
new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(PrecisionAtK.NAME), PrecisionAtK::fromXContent));
|
||||||
|
namedXContent.add(new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(MeanReciprocalRank.NAME),
|
||||||
|
MeanReciprocalRank::fromXContent));
|
||||||
|
namedXContent.add(new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(DiscountedCumulativeGain.NAME),
|
||||||
|
DiscountedCumulativeGain::fromXContent));
|
||||||
|
return namedXContent;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,6 +29,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||||
import org.elasticsearch.common.xcontent.XContentParser;
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
|
import org.elasticsearch.common.xcontent.XContentParserUtils;
|
||||||
import org.elasticsearch.script.Script;
|
import org.elasticsearch.script.Script;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
@ -169,22 +170,21 @@ public class RankEvalSpec implements Writeable, ToXContentObject {
|
||||||
a -> new RankEvalSpec((List<RatedRequest>) a[0], (EvaluationMetric) a[1], (Collection<ScriptWithId>) a[2]));
|
a -> new RankEvalSpec((List<RatedRequest>) a[0], (EvaluationMetric) a[1], (Collection<ScriptWithId>) a[2]));
|
||||||
|
|
||||||
static {
|
static {
|
||||||
PARSER.declareObjectArray(ConstructingObjectParser.constructorArg(), (p, c) -> {
|
PARSER.declareObjectArray(ConstructingObjectParser.constructorArg(), (p, c) -> RatedRequest.fromXContent(p), REQUESTS_FIELD);
|
||||||
return RatedRequest.fromXContent(p);
|
PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> parseMetric(p), METRIC_FIELD);
|
||||||
}, REQUESTS_FIELD);
|
PARSER.declareObjectArray(ConstructingObjectParser.optionalConstructorArg(), (p, c) -> ScriptWithId.fromXContent(p),
|
||||||
PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> {
|
TEMPLATES_FIELD);
|
||||||
try {
|
|
||||||
return EvaluationMetric.fromXContent(p);
|
|
||||||
} catch (IOException ex) {
|
|
||||||
throw new ParsingException(p.getTokenLocation(), "error parsing rank request", ex);
|
|
||||||
}
|
|
||||||
}, METRIC_FIELD);
|
|
||||||
PARSER.declareObjectArray(ConstructingObjectParser.optionalConstructorArg(), (p, c) -> {
|
|
||||||
return ScriptWithId.fromXContent(p);
|
|
||||||
}, TEMPLATES_FIELD);
|
|
||||||
PARSER.declareInt(RankEvalSpec::setMaxConcurrentSearches, MAX_CONCURRENT_SEARCHES_FIELD);
|
PARSER.declareInt(RankEvalSpec::setMaxConcurrentSearches, MAX_CONCURRENT_SEARCHES_FIELD);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static EvaluationMetric parseMetric(XContentParser parser) throws IOException {
|
||||||
|
XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser::getTokenLocation);
|
||||||
|
XContentParserUtils.ensureExpectedToken(XContentParser.Token.FIELD_NAME, parser.nextToken(), parser::getTokenLocation);
|
||||||
|
EvaluationMetric metric = parser.namedObject(EvaluationMetric.class, parser.currentName(), null);
|
||||||
|
XContentParserUtils.ensureExpectedToken(XContentParser.Token.END_OBJECT, parser.nextToken(), parser::getTokenLocation);
|
||||||
|
return metric;
|
||||||
|
}
|
||||||
|
|
||||||
public static RankEvalSpec parse(XContentParser parser) {
|
public static RankEvalSpec parse(XContentParser parser) {
|
||||||
return PARSER.apply(parser, null);
|
return PARSER.apply(parser, null);
|
||||||
}
|
}
|
||||||
|
|
|
@ -166,7 +166,7 @@ public class MeanReciprocalRankTests extends ESTestCase {
|
||||||
return hits;
|
return hits;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static MeanReciprocalRank createTestItem() {
|
static MeanReciprocalRank createTestItem() {
|
||||||
return new MeanReciprocalRank(randomIntBetween(0, 20));
|
return new MeanReciprocalRank(randomIntBetween(0, 20));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
package org.elasticsearch.index.rankeval;
|
package org.elasticsearch.index.rankeval;
|
||||||
|
|
||||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||||
|
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||||
import org.elasticsearch.common.xcontent.ToXContent;
|
import org.elasticsearch.common.xcontent.ToXContent;
|
||||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||||
import org.elasticsearch.common.xcontent.XContentFactory;
|
import org.elasticsearch.common.xcontent.XContentFactory;
|
||||||
|
@ -49,6 +50,12 @@ import static org.elasticsearch.test.EqualsHashCodeTestUtils.checkEqualsAndHashC
|
||||||
|
|
||||||
public class RankEvalSpecTests extends ESTestCase {
|
public class RankEvalSpecTests extends ESTestCase {
|
||||||
|
|
||||||
|
@SuppressWarnings("resource")
|
||||||
|
@Override
|
||||||
|
protected NamedXContentRegistry xContentRegistry() {
|
||||||
|
return new NamedXContentRegistry(new RankEvalPlugin().getNamedXContent());
|
||||||
|
}
|
||||||
|
|
||||||
private static <T> List<T> randomList(Supplier<T> randomSupplier) {
|
private static <T> List<T> randomList(Supplier<T> randomSupplier) {
|
||||||
List<T> result = new ArrayList<>();
|
List<T> result = new ArrayList<>();
|
||||||
int size = randomIntBetween(1, 20);
|
int size = randomIntBetween(1, 20);
|
||||||
|
@ -59,12 +66,10 @@ public class RankEvalSpecTests extends ESTestCase {
|
||||||
}
|
}
|
||||||
|
|
||||||
private static RankEvalSpec createTestItem() throws IOException {
|
private static RankEvalSpec createTestItem() throws IOException {
|
||||||
EvaluationMetric metric;
|
Supplier<EvaluationMetric> metric = randomFrom(Arrays.asList(
|
||||||
if (randomBoolean()) {
|
() -> PrecisionAtKTests.createTestItem(),
|
||||||
metric = PrecisionAtKTests.createTestItem();
|
() -> MeanReciprocalRankTests.createTestItem(),
|
||||||
} else {
|
() -> DiscountedCumulativeGainTests.createTestItem()));
|
||||||
metric = DiscountedCumulativeGainTests.createTestItem();
|
|
||||||
}
|
|
||||||
|
|
||||||
List<RatedRequest> ratedRequests = null;
|
List<RatedRequest> ratedRequests = null;
|
||||||
Collection<ScriptWithId> templates = null;
|
Collection<ScriptWithId> templates = null;
|
||||||
|
@ -92,7 +97,7 @@ public class RankEvalSpecTests extends ESTestCase {
|
||||||
new SearchSourceBuilder());
|
new SearchSourceBuilder());
|
||||||
ratedRequests = Arrays.asList(ratedRequest);
|
ratedRequests = Arrays.asList(ratedRequest);
|
||||||
}
|
}
|
||||||
RankEvalSpec spec = new RankEvalSpec(ratedRequests, metric, templates);
|
RankEvalSpec spec = new RankEvalSpec(ratedRequests, metric.get(), templates);
|
||||||
maybeSet(spec::setMaxConcurrentSearches, randomInt(100));
|
maybeSet(spec::setMaxConcurrentSearches, randomInt(100));
|
||||||
List<String> indices = new ArrayList<>();
|
List<String> indices = new ArrayList<>();
|
||||||
int size = randomIntBetween(0, 20);
|
int size = randomIntBetween(0, 20);
|
||||||
|
@ -105,7 +110,6 @@ public class RankEvalSpecTests extends ESTestCase {
|
||||||
|
|
||||||
public void testXContentRoundtrip() throws IOException {
|
public void testXContentRoundtrip() throws IOException {
|
||||||
RankEvalSpec testItem = createTestItem();
|
RankEvalSpec testItem = createTestItem();
|
||||||
|
|
||||||
XContentBuilder shuffled = shuffleXContent(testItem.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS));
|
XContentBuilder shuffled = shuffleXContent(testItem.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS));
|
||||||
try (XContentParser parser = createParser(JsonXContent.jsonXContent, shuffled.bytes())) {
|
try (XContentParser parser = createParser(JsonXContent.jsonXContent, shuffled.bytes())) {
|
||||||
|
|
||||||
|
|
|
@ -53,9 +53,6 @@ public class RatedRequestsTests extends ESTestCase {
|
||||||
|
|
||||||
private static NamedXContentRegistry xContentRegistry;
|
private static NamedXContentRegistry xContentRegistry;
|
||||||
|
|
||||||
/**
|
|
||||||
* setup for the whole base test class
|
|
||||||
*/
|
|
||||||
@BeforeClass
|
@BeforeClass
|
||||||
public static void init() {
|
public static void init() {
|
||||||
xContentRegistry = new NamedXContentRegistry(
|
xContentRegistry = new NamedXContentRegistry(
|
||||||
|
|
Loading…
Reference in New Issue