Rank feature - unknown field linear (#983)

Signed-off-by: Yevhen Tienkaiev <hronom@gmail.com>
This commit is contained in:
Yevhen Tienkaiev 2021-07-29 21:28:29 +03:00 committed by GitHub
parent 3744f40c28
commit f46f6950ec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 70 additions and 2 deletions

View File

@ -493,4 +493,10 @@ public class QueryDSLDocumentationTests extends OpenSearchTestCase {
0.6f 0.6f
); );
} }
public void testRankFeatureLinear() {
RankFeatureQueryBuilders.linear(
"pagerank"
);
}
} }

View File

@ -39,6 +39,7 @@ import org.opensearch.common.ParseField;
import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.common.xcontent.ConstructingObjectParser; import org.opensearch.common.xcontent.ConstructingObjectParser;
import org.opensearch.common.xcontent.ObjectParser;
import org.opensearch.common.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.index.mapper.RankFeatureFieldMapper.RankFeatureFieldType; import org.opensearch.index.mapper.RankFeatureFieldMapper.RankFeatureFieldType;
import org.opensearch.index.mapper.RankFeatureMetaFieldMapper; import org.opensearch.index.mapper.RankFeatureMetaFieldMapper;
@ -257,6 +258,51 @@ public final class RankFeatureQueryBuilder extends AbstractQueryBuilder<RankFeat
return FeatureField.newSigmoidQuery(field, feature, DEFAULT_BOOST, pivot, exp); return FeatureField.newSigmoidQuery(field, feature, DEFAULT_BOOST, pivot, exp);
} }
} }
/**
* A scoring function that scores documents as simply {@code S}
* where S is the indexed value of the static feature.
*/
public static class Linear extends ScoreFunction {
private static final ObjectParser<Linear, Void> PARSER = new ObjectParser<>("linear", Linear::new);
public Linear() {
}
private Linear(StreamInput in) {
this();
}
@Override
public boolean equals(Object obj) {
if (obj == null || getClass() != obj.getClass()) {
return false;
}
return true;
}
@Override
public int hashCode() {
return getClass().hashCode();
}
@Override
void writeTo(StreamOutput out) throws IOException {
out.writeByte((byte) 3);
}
@Override
void doXContent(XContentBuilder builder) throws IOException {
builder.startObject("linear");
builder.endObject();
}
@Override
Query toQuery(String field, String feature, boolean positiveScoreImpact) throws IOException {
return FeatureField.newLinearQuery(field, feature, DEFAULT_BOOST);
}
}
} }
private static ScoreFunction readScoreFunction(StreamInput in) throws IOException { private static ScoreFunction readScoreFunction(StreamInput in) throws IOException {
@ -268,6 +314,8 @@ public final class RankFeatureQueryBuilder extends AbstractQueryBuilder<RankFeat
return new ScoreFunction.Saturation(in); return new ScoreFunction.Saturation(in);
case 2: case 2:
return new ScoreFunction.Sigmoid(in); return new ScoreFunction.Sigmoid(in);
case 3:
return new ScoreFunction.Linear(in);
default: default:
throw new IOException("Illegal score function id: " + b); throw new IOException("Illegal score function id: " + b);
} }
@ -281,7 +329,7 @@ public final class RankFeatureQueryBuilder extends AbstractQueryBuilder<RankFeat
long numNonNulls = Arrays.stream(args, 3, args.length).filter(Objects::nonNull).count(); long numNonNulls = Arrays.stream(args, 3, args.length).filter(Objects::nonNull).count();
final RankFeatureQueryBuilder query; final RankFeatureQueryBuilder query;
if (numNonNulls > 1) { if (numNonNulls > 1) {
throw new IllegalArgumentException("Can only specify one of [log], [saturation] and [sigmoid]"); throw new IllegalArgumentException("Can only specify one of [log], [saturation], [sigmoid] and [linear]");
} else if (numNonNulls == 0) { } else if (numNonNulls == 0) {
query = new RankFeatureQueryBuilder(field, new ScoreFunction.Saturation()); query = new RankFeatureQueryBuilder(field, new ScoreFunction.Saturation());
} else { } else {
@ -305,6 +353,8 @@ public final class RankFeatureQueryBuilder extends AbstractQueryBuilder<RankFeat
ScoreFunction.Saturation.PARSER, new ParseField("saturation")); ScoreFunction.Saturation.PARSER, new ParseField("saturation"));
PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(), PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(),
ScoreFunction.Sigmoid.PARSER, new ParseField("sigmoid")); ScoreFunction.Sigmoid.PARSER, new ParseField("sigmoid"));
PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(),
ScoreFunction.Linear.PARSER, new ParseField("linear"));
} }
public static final String NAME = "rank_feature"; public static final String NAME = "rank_feature";

View File

@ -77,4 +77,13 @@ public final class RankFeatureQueryBuilders {
return new RankFeatureQueryBuilder(fieldName, new RankFeatureQueryBuilder.ScoreFunction.Sigmoid(pivot, exp)); return new RankFeatureQueryBuilder(fieldName, new RankFeatureQueryBuilder.ScoreFunction.Sigmoid(pivot, exp));
} }
/**
* Return a new {@link RankFeatureQueryBuilder} that will score documents as
* {@code S)} where S is the indexed value of the static feature.
* @param fieldName field that stores features
*/
public static RankFeatureQueryBuilder linear(String fieldName) {
return new RankFeatureQueryBuilder(fieldName, new RankFeatureQueryBuilder.ScoreFunction.Linear());
}
} }

View File

@ -73,7 +73,7 @@ public class RankFeatureQueryBuilderTests extends AbstractQueryTestCase<RankFeat
protected RankFeatureQueryBuilder doCreateTestQueryBuilder() { protected RankFeatureQueryBuilder doCreateTestQueryBuilder() {
ScoreFunction function; ScoreFunction function;
boolean mayUseNegativeField = true; boolean mayUseNegativeField = true;
switch (random().nextInt(3)) { switch (random().nextInt(4)) {
case 0: case 0:
mayUseNegativeField = false; mayUseNegativeField = false;
function = new ScoreFunction.Log(1 + randomFloat()); function = new ScoreFunction.Log(1 + randomFloat());
@ -88,6 +88,9 @@ public class RankFeatureQueryBuilderTests extends AbstractQueryTestCase<RankFeat
case 2: case 2:
function = new ScoreFunction.Sigmoid(randomFloat(), randomFloat()); function = new ScoreFunction.Sigmoid(randomFloat(), randomFloat());
break; break;
case 3:
function = new ScoreFunction.Linear();
break;
default: default:
throw new AssertionError(); throw new AssertionError();
} }