mirror of https://github.com/apache/lucene.git
SOLR-11941: Add abstract contrib/ltr AdapterModel.
This commit is contained in:
parent
fa0aa34bdc
commit
2c8bbc8c18
|
@ -161,6 +161,9 @@ New Features
|
||||||
* SOLR-11925: Time Routed Aliases can have their oldest collections automatically deleted via the "router.autoDeleteAge"
|
* SOLR-11925: Time Routed Aliases can have their oldest collections automatically deleted via the "router.autoDeleteAge"
|
||||||
setting. (David Smiley)
|
setting. (David Smiley)
|
||||||
|
|
||||||
|
* SOLR-11941: Add abstract contrib/ltr AdapterModel to facilitate the development of scoring models that delegate
|
||||||
|
scoring to an opaque pre-trained model. (Christine Poerschke)
|
||||||
|
|
||||||
Bug Fixes
|
Bug Fixes
|
||||||
----------------------
|
----------------------
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,43 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package org.apache.solr.ltr.model;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
import org.apache.solr.core.SolrResourceLoader;
|
||||||
|
import org.apache.solr.ltr.feature.Feature;
|
||||||
|
import org.apache.solr.ltr.norm.Normalizer;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A scoring model whose initialization is completed via its
|
||||||
|
* {@link #init(SolrResourceLoader)} method.
|
||||||
|
*/
|
||||||
|
public abstract class AdapterModel extends LTRScoringModel {
|
||||||
|
|
||||||
|
protected SolrResourceLoader solrResourceLoader;
|
||||||
|
|
||||||
|
public AdapterModel(String name, List<Feature> features, List<Normalizer> norms, String featureStoreName,
|
||||||
|
List<Feature> allFeatures, Map<String,Object> params) {
|
||||||
|
super(name, features, norms, featureStoreName, allFeatures, params);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void init(SolrResourceLoader solrResourceLoader) throws ModelException {
|
||||||
|
this.solrResourceLoader = solrResourceLoader;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -23,7 +23,6 @@ import java.util.Map;
|
||||||
|
|
||||||
import org.apache.lucene.index.LeafReaderContext;
|
import org.apache.lucene.index.LeafReaderContext;
|
||||||
import org.apache.lucene.search.Explanation;
|
import org.apache.lucene.search.Explanation;
|
||||||
import org.apache.solr.core.SolrResourceLoader;
|
|
||||||
import org.apache.solr.ltr.feature.Feature;
|
import org.apache.solr.ltr.feature.Feature;
|
||||||
import org.apache.solr.ltr.norm.Normalizer;
|
import org.apache.solr.ltr.norm.Normalizer;
|
||||||
|
|
||||||
|
@ -51,9 +50,8 @@ import org.apache.solr.ltr.norm.Normalizer;
|
||||||
* Also note that if a "store" is configured for the wrapper
|
* Also note that if a "store" is configured for the wrapper
|
||||||
* model then it must match the "store" of the wrapped model.
|
* model then it must match the "store" of the wrapped model.
|
||||||
*/
|
*/
|
||||||
public abstract class WrapperModel extends LTRScoringModel {
|
public abstract class WrapperModel extends AdapterModel {
|
||||||
|
|
||||||
protected SolrResourceLoader solrResourceLoader;
|
|
||||||
protected LTRScoringModel model;
|
protected LTRScoringModel model;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -107,10 +105,6 @@ public abstract class WrapperModel extends LTRScoringModel {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setSolrResourceLoader(SolrResourceLoader solrResourceLoader) {
|
|
||||||
this.solrResourceLoader = solrResourceLoader;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void updateModel(LTRScoringModel model) {
|
public void updateModel(LTRScoringModel model) {
|
||||||
this.model = model;
|
this.model = model;
|
||||||
validate();
|
validate();
|
||||||
|
|
|
@ -27,6 +27,7 @@ import org.apache.solr.common.util.NamedList;
|
||||||
import org.apache.solr.core.SolrCore;
|
import org.apache.solr.core.SolrCore;
|
||||||
import org.apache.solr.core.SolrResourceLoader;
|
import org.apache.solr.core.SolrResourceLoader;
|
||||||
import org.apache.solr.ltr.feature.Feature;
|
import org.apache.solr.ltr.feature.Feature;
|
||||||
|
import org.apache.solr.ltr.model.AdapterModel;
|
||||||
import org.apache.solr.ltr.model.WrapperModel;
|
import org.apache.solr.ltr.model.WrapperModel;
|
||||||
import org.apache.solr.ltr.model.LTRScoringModel;
|
import org.apache.solr.ltr.model.LTRScoringModel;
|
||||||
import org.apache.solr.ltr.model.ModelException;
|
import org.apache.solr.ltr.model.ModelException;
|
||||||
|
@ -241,25 +242,29 @@ public class ManagedModelStore extends ManagedResource implements ManagedResourc
|
||||||
featureStore.getFeatures(),
|
featureStore.getFeatures(),
|
||||||
(Map<String,Object>) modelMap.get(PARAMS_KEY));
|
(Map<String,Object>) modelMap.get(PARAMS_KEY));
|
||||||
|
|
||||||
if (ltrScoringModel instanceof WrapperModel) {
|
if (ltrScoringModel instanceof AdapterModel) {
|
||||||
initWrapperModel(solrResourceLoader, (WrapperModel)ltrScoringModel, managedFeatureStore);
|
initAdapterModel(solrResourceLoader, (AdapterModel)ltrScoringModel, managedFeatureStore);
|
||||||
}
|
}
|
||||||
|
|
||||||
return ltrScoringModel;
|
return ltrScoringModel;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static void initAdapterModel(SolrResourceLoader solrResourceLoader,
|
||||||
|
AdapterModel adapterModel, ManagedFeatureStore managedFeatureStore) {
|
||||||
|
adapterModel.init(solrResourceLoader);
|
||||||
|
if (adapterModel instanceof WrapperModel) {
|
||||||
|
initWrapperModel(solrResourceLoader, (WrapperModel)adapterModel, managedFeatureStore);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private static void initWrapperModel(SolrResourceLoader solrResourceLoader,
|
private static void initWrapperModel(SolrResourceLoader solrResourceLoader,
|
||||||
WrapperModel wrapperModel, ManagedFeatureStore managedFeatureStore) {
|
WrapperModel wrapperModel, ManagedFeatureStore managedFeatureStore) {
|
||||||
wrapperModel.setSolrResourceLoader(solrResourceLoader);
|
|
||||||
final LTRScoringModel model = fromLTRScoringModelMap(
|
final LTRScoringModel model = fromLTRScoringModelMap(
|
||||||
solrResourceLoader,
|
solrResourceLoader,
|
||||||
wrapperModel.fetchModelMap(),
|
wrapperModel.fetchModelMap(),
|
||||||
managedFeatureStore);
|
managedFeatureStore);
|
||||||
if (model instanceof WrapperModel) {
|
if (model instanceof AdapterModel) {
|
||||||
log.warn("It is unusual for one WrapperModel ({}) to wrap another WrapperModel ({})",
|
initAdapterModel(solrResourceLoader, (AdapterModel)model, managedFeatureStore);
|
||||||
wrapperModel.getName(),
|
|
||||||
model.getName());
|
|
||||||
initWrapperModel(solrResourceLoader, (WrapperModel)model, managedFeatureStore);
|
|
||||||
}
|
}
|
||||||
wrapperModel.updateModel(model);
|
wrapperModel.updateModel(model);
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,143 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.solr.ltr.model;
|
||||||
|
|
||||||
|
import java.io.BufferedReader;
|
||||||
|
import java.io.BufferedWriter;
|
||||||
|
import java.io.File;
|
||||||
|
import java.io.FileOutputStream;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.InputStream;
|
||||||
|
import java.io.InputStreamReader;
|
||||||
|
import java.io.OutputStreamWriter;
|
||||||
|
import java.nio.charset.StandardCharsets;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
import org.apache.lucene.index.LeafReaderContext;
|
||||||
|
import org.apache.lucene.search.Explanation;
|
||||||
|
import org.apache.solr.client.solrj.SolrQuery;
|
||||||
|
import org.apache.solr.core.SolrResourceLoader;
|
||||||
|
import org.apache.solr.ltr.TestRerankBase;
|
||||||
|
import org.apache.solr.ltr.feature.Feature;
|
||||||
|
import org.apache.solr.ltr.feature.FieldValueFeature;
|
||||||
|
import org.apache.solr.ltr.norm.Normalizer;
|
||||||
|
import org.apache.solr.ltr.store.rest.ManagedModelStore;
|
||||||
|
import org.junit.BeforeClass;
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
|
public class TestAdapterModel extends TestRerankBase {
|
||||||
|
|
||||||
|
private static int numDocs = 0;
|
||||||
|
private static float scoreValue;
|
||||||
|
|
||||||
|
@BeforeClass
|
||||||
|
public static void setupBeforeClass() throws Exception {
|
||||||
|
|
||||||
|
setuptest(false);
|
||||||
|
|
||||||
|
for (int ii=1; ii<=random().nextInt(10); ++ii) {
|
||||||
|
String id = Integer.toString(ii);
|
||||||
|
assertU(adoc("id", id, "popularity", ii+"00"));
|
||||||
|
++numDocs;
|
||||||
|
}
|
||||||
|
assertU(commit());
|
||||||
|
|
||||||
|
loadFeature("popularity", FieldValueFeature.class.getName(), "test", "{\"field\":\"popularity\"}");
|
||||||
|
|
||||||
|
scoreValue = random().nextFloat();
|
||||||
|
final File scoreValueFile = new File(tmpConfDir, "scoreValue.txt");
|
||||||
|
try (BufferedWriter writer = new BufferedWriter(
|
||||||
|
new OutputStreamWriter(new FileOutputStream(scoreValueFile), StandardCharsets.UTF_8))) {
|
||||||
|
writer.write(Float.toString(scoreValue));
|
||||||
|
}
|
||||||
|
scoreValueFile.deleteOnExit();
|
||||||
|
|
||||||
|
final String modelJson = getModelInJson(
|
||||||
|
"answerModel",
|
||||||
|
CustomModel.class.getName(),
|
||||||
|
new String[] { "popularity" },
|
||||||
|
"test",
|
||||||
|
"{\"answerFileName\":\"" + scoreValueFile.getName() + "\"}");
|
||||||
|
assertJPut(ManagedModelStore.REST_END_POINT, modelJson, "/responseHeader/status==0");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void test() throws Exception {
|
||||||
|
final int rows = random().nextInt(numDocs+1); // 0..numDocs
|
||||||
|
final SolrQuery query = new SolrQuery("*:*");
|
||||||
|
query.setRows(rows);
|
||||||
|
query.setFields("*,score");
|
||||||
|
query.add("rq", "{!ltr model=answerModel}");
|
||||||
|
final String[] tests = new String[rows];
|
||||||
|
for (int ii=0; ii<rows; ++ii) {
|
||||||
|
tests[ii] = "/response/docs/["+ii+"]/score=="+scoreValue;
|
||||||
|
}
|
||||||
|
assertJQ("/query" + query.toQueryString(), tests);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class CustomModel extends AdapterModel {
|
||||||
|
|
||||||
|
private String answerFileName;
|
||||||
|
private float answerValue;
|
||||||
|
|
||||||
|
public CustomModel(String name, List<Feature> features, List<Normalizer> norms, String featureStoreName,
|
||||||
|
List<Feature> allFeatures, Map<String,Object> params) {
|
||||||
|
super(name, features, norms, featureStoreName, allFeatures, params);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setAnswerFileName(String answerFileName) {
|
||||||
|
this.answerFileName = answerFileName;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void validate() throws ModelException {
|
||||||
|
super.validate();
|
||||||
|
if (answerFileName == null) {
|
||||||
|
throw new ModelException("no answerFileName configured for model "+name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void init(SolrResourceLoader solrResourceLoader) throws ModelException {
|
||||||
|
super.init(solrResourceLoader);
|
||||||
|
try (
|
||||||
|
InputStream is = solrResourceLoader.openResource(answerFileName);
|
||||||
|
InputStreamReader isr = new InputStreamReader(is, StandardCharsets.UTF_8);
|
||||||
|
BufferedReader br = new BufferedReader(isr)
|
||||||
|
) {
|
||||||
|
answerValue = Float.parseFloat(br.readLine());
|
||||||
|
} catch (IOException e) {
|
||||||
|
throw new ModelException("Failed to get the answerValue from the given answerFileName (" + answerFileName + ")", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public float score(float[] modelFeatureValuesNormalized) {
|
||||||
|
return answerValue;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Explanation explain(LeafReaderContext context, int doc, float finalScore,
|
||||||
|
List<Explanation> featureExplanations) {
|
||||||
|
return Explanation.match(finalScore, toString()
|
||||||
|
+ " model, always returns "+Float.toString(answerValue)+".");
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -88,6 +88,7 @@ Feature selection and model training take place offline and outside Solr. The lt
|
||||||
|Linear |{solr-javadocs}/solr-ltr/org/apache/solr/ltr/model/LinearModel.html[LinearModel] |RankSVM, Pranking
|
|Linear |{solr-javadocs}/solr-ltr/org/apache/solr/ltr/model/LinearModel.html[LinearModel] |RankSVM, Pranking
|
||||||
|Multiple Additive Trees |{solr-javadocs}/solr-ltr/org/apache/solr/ltr/model/MultipleAdditiveTreesModel.html[MultipleAdditiveTreesModel] |LambdaMART, Gradient Boosted Regression Trees (GBRT)
|
|Multiple Additive Trees |{solr-javadocs}/solr-ltr/org/apache/solr/ltr/model/MultipleAdditiveTreesModel.html[MultipleAdditiveTreesModel] |LambdaMART, Gradient Boosted Regression Trees (GBRT)
|
||||||
|(wrapper) |{solr-javadocs}/solr-ltr/org/apache/solr/ltr/model/DefaultWrapperModel.html[DefaultWrapperModel] |(not applicable)
|
|(wrapper) |{solr-javadocs}/solr-ltr/org/apache/solr/ltr/model/DefaultWrapperModel.html[DefaultWrapperModel] |(not applicable)
|
||||||
|
|(custom) |(custom class extending {solr-javadocs}/solr-ltr/org/apache/solr/ltr/model/AdapterModel.html[AdapterModel]) |(not applicable)
|
||||||
|(custom) |(custom class extending {solr-javadocs}/solr-ltr/org/apache/solr/ltr/model/LTRScoringModel.html[LTRScoringModel]) |(not applicable)
|
|(custom) |(custom class extending {solr-javadocs}/solr-ltr/org/apache/solr/ltr/model/LTRScoringModel.html[LTRScoringModel]) |(not applicable)
|
||||||
|===
|
|===
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue