[ML] Partition-wise maximum scores (#32748)

Added infrastructure to pass the 'person field value' through to
the normalizer process. The normalizer requires this value to retrieve
the maximum scores for individual partitions.
This commit is contained in:
Ed Savage 2018-08-13 10:31:17 +01:00 committed by GitHub
parent 4d20e69b83
commit d147cd72cc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 74 additions and 5 deletions

View File

@ -46,6 +46,11 @@ class BucketInfluencerNormalizable extends AbstractLeafNormalizable {
return bucketInfluencer.getInfluencerFieldName();
}
/**
 * Returns the person field value for this bucket influencer normalizable.
 *
 * @return always {@code null}; this implementation supplies no person field value
 */
@Override
public String getPersonFieldValue() {
return null;
}
@Override
public String getFunctionName() {
return null;

View File

@ -64,6 +64,11 @@ public class BucketNormalizable extends Normalizable {
return null;
}
/**
 * Returns the person field value for this bucket normalizable.
 *
 * @return always {@code null}; this implementation supplies no person field value
 */
@Override
public String getPersonFieldValue() {
return null;
}
@Override
public String getFunctionName() {
return null;

View File

@ -44,6 +44,11 @@ class InfluencerNormalizable extends AbstractLeafNormalizable {
return influencer.getInfluencerFieldName();
}
/**
 * Returns the person field value for this influencer normalizable.
 *
 * @return the wrapped influencer's influencer field value
 */
@Override
public String getPersonFieldValue() {
return influencer.getInfluencerFieldValue();
}
@Override
public String getFunctionName() {
return null;

View File

@ -63,10 +63,11 @@ public class MultiplyingNormalizerProcess implements NormalizerProcess {
result.setPartitionFieldName(record[1]);
result.setPartitionFieldValue(record[2]);
result.setPersonFieldName(record[3]);
result.setFunctionName(record[4]);
result.setValueFieldName(record[5]);
result.setProbability(Double.parseDouble(record[6]));
result.setNormalizedScore(factor * Double.parseDouble(record[7]));
result.setPersonFieldValue(record[4]);
result.setFunctionName(record[5]);
result.setValueFieldName(record[6]);
result.setProbability(Double.parseDouble(record[7]));
result.setNormalizedScore(factor * Double.parseDouble(record[8]));
} catch (NumberFormatException | ArrayIndexOutOfBoundsException e) {
throw new IOException("Unable to write to no-op normalizer", e);
}

View File

@ -44,6 +44,8 @@ public abstract class Normalizable implements ToXContentObject {
abstract String getPersonFieldName();
/**
 * Returns the person field value of this normalizable.
 * Implementations may return {@code null} when no value applies.
 */
abstract String getPersonFieldValue();
abstract String getFunctionName();
abstract String getValueFieldName();

View File

@ -70,6 +70,7 @@ public class Normalizer {
NormalizerResult.PARTITION_FIELD_NAME_FIELD.getPreferredName(),
NormalizerResult.PARTITION_FIELD_VALUE_FIELD.getPreferredName(),
NormalizerResult.PERSON_FIELD_NAME_FIELD.getPreferredName(),
NormalizerResult.PERSON_FIELD_VALUE_FIELD.getPreferredName(),
NormalizerResult.FUNCTION_NAME_FIELD.getPreferredName(),
NormalizerResult.VALUE_FIELD_NAME_FIELD.getPreferredName(),
NormalizerResult.PROBABILITY_FIELD.getPreferredName(),
@ -108,6 +109,7 @@ public class Normalizer {
Strings.coalesceToEmpty(normalizable.getPartitionFieldName()),
Strings.coalesceToEmpty(normalizable.getPartitionFieldValue()),
Strings.coalesceToEmpty(normalizable.getPersonFieldName()),
Strings.coalesceToEmpty(normalizable.getPersonFieldValue()),
Strings.coalesceToEmpty(normalizable.getFunctionName()),
Strings.coalesceToEmpty(normalizable.getValueFieldName()),
Double.toString(normalizable.getProbability()),

View File

@ -5,6 +5,7 @@
*/
package org.elasticsearch.xpack.ml.job.process.normalizer;
import org.elasticsearch.Version;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
@ -26,6 +27,7 @@ public class NormalizerResult implements ToXContentObject, Writeable {
static final ParseField PARTITION_FIELD_NAME_FIELD = new ParseField("partition_field_name");
static final ParseField PARTITION_FIELD_VALUE_FIELD = new ParseField("partition_field_value");
static final ParseField PERSON_FIELD_NAME_FIELD = new ParseField("person_field_name");
static final ParseField PERSON_FIELD_VALUE_FIELD = new ParseField("person_field_value");
static final ParseField FUNCTION_NAME_FIELD = new ParseField("function_name");
static final ParseField VALUE_FIELD_NAME_FIELD = new ParseField("value_field_name");
static final ParseField PROBABILITY_FIELD = new ParseField("probability");
@ -39,6 +41,7 @@ public class NormalizerResult implements ToXContentObject, Writeable {
PARSER.declareString(NormalizerResult::setPartitionFieldName, PARTITION_FIELD_NAME_FIELD);
PARSER.declareString(NormalizerResult::setPartitionFieldValue, PARTITION_FIELD_VALUE_FIELD);
PARSER.declareString(NormalizerResult::setPersonFieldName, PERSON_FIELD_NAME_FIELD);
PARSER.declareString(NormalizerResult::setPersonFieldValue, PERSON_FIELD_VALUE_FIELD);
PARSER.declareString(NormalizerResult::setFunctionName, FUNCTION_NAME_FIELD);
PARSER.declareString(NormalizerResult::setValueFieldName, VALUE_FIELD_NAME_FIELD);
PARSER.declareDouble(NormalizerResult::setProbability, PROBABILITY_FIELD);
@ -49,6 +52,7 @@ public class NormalizerResult implements ToXContentObject, Writeable {
private String partitionFieldName;
private String partitionFieldValue;
private String personFieldName;
private String personFieldValue;
private String functionName;
private String valueFieldName;
private double probability;
@ -62,6 +66,9 @@ public class NormalizerResult implements ToXContentObject, Writeable {
partitionFieldName = in.readOptionalString();
partitionFieldValue = in.readOptionalString();
personFieldName = in.readOptionalString();
if (in.getVersion().onOrAfter(Version.V_6_5_0)) {
personFieldValue = in.readOptionalString();
}
functionName = in.readOptionalString();
valueFieldName = in.readOptionalString();
probability = in.readDouble();
@ -74,6 +81,9 @@ public class NormalizerResult implements ToXContentObject, Writeable {
out.writeOptionalString(partitionFieldName);
out.writeOptionalString(partitionFieldValue);
out.writeOptionalString(personFieldName);
if (out.getVersion().onOrAfter(Version.V_6_5_0)) {
out.writeOptionalString(personFieldValue);
}
out.writeOptionalString(functionName);
out.writeOptionalString(valueFieldName);
out.writeDouble(probability);
@ -87,6 +97,7 @@ public class NormalizerResult implements ToXContentObject, Writeable {
builder.field(PARTITION_FIELD_NAME_FIELD.getPreferredName(), partitionFieldName);
builder.field(PARTITION_FIELD_VALUE_FIELD.getPreferredName(), partitionFieldValue);
builder.field(PERSON_FIELD_NAME_FIELD.getPreferredName(), personFieldName);
builder.field(PERSON_FIELD_VALUE_FIELD.getPreferredName(), personFieldValue);
builder.field(FUNCTION_NAME_FIELD.getPreferredName(), functionName);
builder.field(VALUE_FIELD_NAME_FIELD.getPreferredName(), valueFieldName);
builder.field(PROBABILITY_FIELD.getPreferredName(), probability);
@ -127,6 +138,14 @@ public class NormalizerResult implements ToXContentObject, Writeable {
this.personFieldName = personFieldName;
}
/**
 * Returns the person field value held by this result.
 *
 * @return the current person field value, or {@code null} if none has been set
 */
public String getPersonFieldValue() {
return personFieldValue;
}
/**
 * Sets the person field value held by this result.
 *
 * @param personFieldValue the value to store; stored as given, with no validation
 */
public void setPersonFieldValue(String personFieldValue) {
this.personFieldValue = personFieldValue;
}
/**
 * Returns the function name held by this result.
 *
 * @return the current function name, or {@code null} if none has been set
 */
public String getFunctionName() {
return functionName;
}
@ -161,7 +180,7 @@ public class NormalizerResult implements ToXContentObject, Writeable {
@Override
public int hashCode() {
return Objects.hash(level, partitionFieldName, partitionFieldValue, personFieldName,
return Objects.hash(level, partitionFieldName, partitionFieldValue, personFieldName, personFieldValue,
functionName, valueFieldName, probability, normalizedScore);
}
@ -184,6 +203,7 @@ public class NormalizerResult implements ToXContentObject, Writeable {
&& Objects.equals(this.partitionFieldName, that.partitionFieldName)
&& Objects.equals(this.partitionFieldValue, that.partitionFieldValue)
&& Objects.equals(this.personFieldName, that.personFieldName)
&& Objects.equals(this.personFieldValue, that.personFieldValue)
&& Objects.equals(this.functionName, that.functionName)
&& Objects.equals(this.valueFieldName, that.valueFieldName)
&& this.probability == that.probability

View File

@ -45,6 +45,11 @@ public class PartitionScoreNormalizable extends AbstractLeafNormalizable {
return null;
}
/**
 * Returns the person field value for this partition score normalizable.
 *
 * @return always {@code null}; this implementation supplies no person field value
 */
@Override
public String getPersonFieldValue() {
return null;
}
@Override
public String getFunctionName() {
return null;

View File

@ -46,6 +46,12 @@ class RecordNormalizable extends AbstractLeafNormalizable {
return over != null ? over : record.getByFieldName();
}
/**
 * Returns the person field value of the wrapped record.
 * The over field value takes precedence; the by field value is used
 * only when the over field value is {@code null}.
 *
 * @return the record's over field value if non-null, otherwise its by field value
 */
@Override
public String getPersonFieldValue() {
    final String overFieldValue = record.getOverFieldValue();
    if (overFieldValue == null) {
        return record.getByFieldValue();
    }
    return overFieldValue;
}
@Override
public String getFunctionName() {
return record.getFunction();

View File

@ -43,10 +43,18 @@ public class BucketInfluencerNormalizableTests extends ESTestCase {
assertNull(new BucketInfluencerNormalizable(bucketInfluencer, INDEX_NAME).getPartitionFieldName());
}
/** Checks that a bucket influencer normalizable reports a {@code null} partition field value. */
public void testGetPartitionFieldValue() {
    BucketInfluencerNormalizable normalizable = new BucketInfluencerNormalizable(bucketInfluencer, INDEX_NAME);
    assertNull(normalizable.getPartitionFieldValue());
}
/** Checks that the person field name is "airline" for the test's bucket influencer fixture. */
public void testGetPersonFieldName() {
    BucketInfluencerNormalizable normalizable = new BucketInfluencerNormalizable(bucketInfluencer, INDEX_NAME);
    assertEquals("airline", normalizable.getPersonFieldName());
}
/** Checks that a bucket influencer normalizable reports a {@code null} person field value. */
public void testGetPersonFieldValue() {
    BucketInfluencerNormalizable normalizable = new BucketInfluencerNormalizable(bucketInfluencer, INDEX_NAME);
    assertNull(normalizable.getPersonFieldValue());
}
/** Checks that a bucket influencer normalizable reports a {@code null} function name. */
public void testGetFunctionName() {
    BucketInfluencerNormalizable normalizable = new BucketInfluencerNormalizable(bucketInfluencer, INDEX_NAME);
    assertNull(normalizable.getFunctionName());
}

View File

@ -73,6 +73,10 @@ public class BucketNormalizableTests extends ESTestCase {
assertNull(new BucketNormalizable(bucket, INDEX_NAME).getPersonFieldName());
}
/** Checks that a bucket normalizable reports a {@code null} person field value. */
public void testGetPersonFieldValue() {
    BucketNormalizable normalizable = new BucketNormalizable(bucket, INDEX_NAME);
    assertNull(normalizable.getPersonFieldValue());
}
/** Checks that a bucket normalizable reports a {@code null} function name. */
public void testGetFunctionName() {
    BucketNormalizable normalizable = new BucketNormalizable(bucket, INDEX_NAME);
    assertNull(normalizable.getFunctionName());
}

View File

@ -44,6 +44,10 @@ public class InfluencerNormalizableTests extends ESTestCase {
assertEquals("airline", new InfluencerNormalizable(influencer, INDEX_NAME).getPersonFieldName());
}
/** Checks that the person field value is "AAL" for the test's influencer fixture. */
public void testGetPersonFieldValue() {
    InfluencerNormalizable normalizable = new InfluencerNormalizable(influencer, INDEX_NAME);
    assertEquals("AAL", normalizable.getPersonFieldValue());
}
/** Checks that an influencer normalizable reports a {@code null} function name. */
public void testGetFunctionName() {
    InfluencerNormalizable normalizable = new InfluencerNormalizable(influencer, INDEX_NAME);
    assertNull(normalizable.getFunctionName());
}

View File

@ -19,6 +19,7 @@ public class NormalizerResultTests extends AbstractSerializingTestCase<Normalize
assertNull(msg.getPartitionFieldName());
assertNull(msg.getPartitionFieldValue());
assertNull(msg.getPersonFieldName());
assertNull(msg.getPersonFieldValue());
assertNull(msg.getFunctionName());
assertNull(msg.getValueFieldName());
assertEquals(0.0, msg.getProbability(), EPSILON);
@ -32,6 +33,7 @@ public class NormalizerResultTests extends AbstractSerializingTestCase<Normalize
msg.setPartitionFieldName("part");
msg.setPartitionFieldValue("something");
msg.setPersonFieldName("person");
msg.setPersonFieldValue("fred");
msg.setFunctionName("mean");
msg.setValueFieldName("value");
msg.setProbability(0.005);