mirror of https://github.com/apache/lucene.git
SOLR-11250: A new DefaultWrapperModel class for loading of large and/or externally stored LTRScoringModel definitions. (Yuki Yano, shalin, Christine Poerschke)
This commit is contained in:
parent
a06e685642
commit
64b3a5bb4b
|
@ -99,6 +99,9 @@ New Features
|
|||
|
||||
* SOLR-11202: Implement a set-property command for AutoScaling API. (ab, shalin)
|
||||
|
||||
* SOLR-11250: A new DefaultWrapperModel class for loading of large and/or externally stored
|
||||
LTRScoringModel definitions. (Yuki Yano, shalin, Christine Poerschke)
|
||||
|
||||
Bug Fixes
|
||||
----------------------
|
||||
|
||||
|
|
|
@ -25,6 +25,11 @@
|
|||
|
||||
<import file="../contrib-build.xml"/>
|
||||
|
||||
<path id="test.classpath">
|
||||
<path refid="solr.test.base.classpath"/>
|
||||
<fileset dir="${test.lib.dir}" includes="*.jar"/>
|
||||
</path>
|
||||
|
||||
<target name="compile-core" depends=" solr-contrib-build.compile-core"/>
|
||||
|
||||
</project>
|
||||
|
|
|
@ -24,9 +24,10 @@
|
|||
</configurations>
|
||||
|
||||
<dependencies>
|
||||
|
||||
|
||||
<dependency org="org.slf4j" name="jcl-over-slf4j" rev="${/org.slf4j/jcl-over-slf4j}" conf="test"/>
|
||||
<dependency org="org.mockito" name="mockito-core" rev="${/org.mockito/mockito-core}" conf="test"/>
|
||||
<dependency org="net.bytebuddy" name="byte-buddy" rev="${/net.bytebuddy/byte-buddy}" conf="test"/>
|
||||
<dependency org="org.objenesis" name="objenesis" rev="${/org.objenesis/objenesis}" conf="test"/>
|
||||
<exclude org="*" ext="*" matcher="regexp" type="${ivy.exclude.types}"/>
|
||||
</dependencies>
|
||||
</ivy-module>
|
||||
|
|
|
@ -0,0 +1,105 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.solr.ltr.model;
|
||||
|
||||
import java.io.BufferedReader;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.InputStreamReader;
|
||||
import java.io.Reader;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import org.apache.solr.core.SolrResourceLoader;
|
||||
import org.apache.solr.ltr.feature.Feature;
|
||||
import org.apache.solr.ltr.norm.Normalizer;
|
||||
import org.noggit.JSONParser;
|
||||
import org.noggit.ObjectBuilder;
|
||||
|
||||
/**
|
||||
* A scoring model that fetches the wrapped model from {@link SolrResourceLoader}.
|
||||
*
|
||||
* <p>This model uses {@link SolrResourceLoader#openResource(String)} for fetching the wrapped model.
|
||||
* If you give a relative path for {@code params/resource}, this model will try to load the wrapped model from
|
||||
* the instance directory (i.e. ${solr.solr.home}). Otherwise, seek through classpaths.
|
||||
*
|
||||
* <p>Example configuration:
|
||||
* <pre>{
|
||||
"class": "org.apache.solr.ltr.model.DefaultWrapperModel",
|
||||
"name": "myWrapperModelName",
|
||||
"params": {
|
||||
"resource": "models/myModel.json"
|
||||
}
|
||||
}</pre>
|
||||
*
|
||||
* @see SolrResourceLoader#openResource(String)
|
||||
*/
|
||||
public class DefaultWrapperModel extends WrapperModel {
|
||||
|
||||
private String resource;
|
||||
|
||||
public DefaultWrapperModel(String name, List<Feature> features, List<Normalizer> norms, String featureStoreName,
|
||||
List<Feature> allFeatures, Map<String, Object> params) {
|
||||
super(name, features, norms, featureStoreName, allFeatures, params);
|
||||
}
|
||||
|
||||
public void setResource(String resource) {
|
||||
this.resource = resource;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void validate() throws ModelException {
|
||||
super.validate();
|
||||
if (resource == null) {
|
||||
throw new ModelException("no resource configured for model "+name);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map<String, Object> fetchModelMap() throws ModelException {
|
||||
Map<String, Object> modelMapObj;
|
||||
try (InputStream in = openInputStream()) {
|
||||
modelMapObj = parseInputStream(in);
|
||||
} catch (IOException e) {
|
||||
throw new ModelException("Failed to fetch the wrapper model from given resource (" + resource + ")", e);
|
||||
}
|
||||
return modelMapObj;
|
||||
}
|
||||
|
||||
protected InputStream openInputStream() throws IOException {
|
||||
return solrResourceLoader.openResource(resource);
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
protected Map<String, Object> parseInputStream(InputStream in) throws IOException {
|
||||
try (Reader reader = new BufferedReader(new InputStreamReader(in, StandardCharsets.UTF_8))) {
|
||||
return (Map<String, Object>) new ObjectBuilder(new JSONParser(reader)).getVal();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
final StringBuilder sb = new StringBuilder(getClass().getSimpleName());
|
||||
sb.append("(name=").append(getName());
|
||||
sb.append(",resource=").append(resource);
|
||||
sb.append(",model=(").append(model.toString()).append(")");
|
||||
|
||||
return sb.toString();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,169 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.solr.ltr.model;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.search.Explanation;
|
||||
import org.apache.solr.core.SolrResourceLoader;
|
||||
import org.apache.solr.ltr.feature.Feature;
|
||||
import org.apache.solr.ltr.norm.Normalizer;
|
||||
|
||||
/**
|
||||
* A scoring model that wraps the other model.
|
||||
*
|
||||
* <p>This model loads a model from an external resource during the initialization.
|
||||
* The way of fetching the wrapped model is depended on
|
||||
* the implementation of {@link WrapperModel#fetchModelMap()}.
|
||||
*
|
||||
* <p>This model doesn't hold the actual parameters of the wrapped model,
|
||||
* thus it can manage large models which are difficult to upload to ZooKeeper.
|
||||
*
|
||||
* <p>Example configuration:
|
||||
* <pre>{
|
||||
"class": "...",
|
||||
"name": "myModelName",
|
||||
"params": {
|
||||
...
|
||||
}
|
||||
}</pre>
|
||||
*
|
||||
* <p>NOTE: no "features" are configured in the wrapper model
|
||||
* because the wrapped model's features will be used instead.
|
||||
* Also note that if a "store" is configured for the wrapper
|
||||
* model then it must match the "store" of the wrapped model.
|
||||
*/
|
||||
public abstract class WrapperModel extends LTRScoringModel {
|
||||
|
||||
protected SolrResourceLoader solrResourceLoader;
|
||||
protected LTRScoringModel model;
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
final int prime = 31;
|
||||
int result = super.hashCode();
|
||||
result = prime * result + ((model == null) ? 0 : model.hashCode());
|
||||
result = prime * result + ((solrResourceLoader == null) ? 0 : solrResourceLoader.hashCode());
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object obj) {
|
||||
if (this == obj) return true;
|
||||
if (!super.equals(obj)) return false;
|
||||
if (getClass() != obj.getClass()) return false;
|
||||
WrapperModel other = (WrapperModel) obj;
|
||||
if (model == null) {
|
||||
if (other.model != null) return false;
|
||||
} else if (!model.equals(other.model)) return false;
|
||||
if (solrResourceLoader == null) {
|
||||
if (other.solrResourceLoader != null) return false;
|
||||
} else if (!solrResourceLoader.equals(other.solrResourceLoader)) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
public WrapperModel(String name, List<Feature> features, List<Normalizer> norms, String featureStoreName,
|
||||
List<Feature> allFeatures, Map<String, Object> params) {
|
||||
super(name, features, norms, featureStoreName, allFeatures, params);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void validate() throws ModelException {
|
||||
if (!features.isEmpty()) {
|
||||
throw new ModelException("features must be empty for the wrapper model " + name);
|
||||
}
|
||||
if (!norms.isEmpty()) {
|
||||
throw new ModelException("norms must be empty for the wrapper model " + name);
|
||||
}
|
||||
|
||||
if (model != null) {
|
||||
super.validate();
|
||||
model.validate();
|
||||
// check feature store names match
|
||||
final String wrappedFeatureStoreName = model.getFeatureStoreName();
|
||||
if (wrappedFeatureStoreName == null || !wrappedFeatureStoreName.equals(this.getFeatureStoreName())) {
|
||||
throw new ModelException("wrapper feature store name ("+this.getFeatureStoreName() +")"
|
||||
+ " must match the "
|
||||
+ "wrapped feature store name ("+wrappedFeatureStoreName+")");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void setSolrResourceLoader(SolrResourceLoader solrResourceLoader) {
|
||||
this.solrResourceLoader = solrResourceLoader;
|
||||
}
|
||||
|
||||
public void updateModel(LTRScoringModel model) {
|
||||
this.model = model;
|
||||
validate();
|
||||
}
|
||||
|
||||
/*
|
||||
* The child classes must implement how to fetch the definition of the wrapped model.
|
||||
*/
|
||||
public abstract Map<String, Object> fetchModelMap() throws ModelException;
|
||||
|
||||
@Override
|
||||
public List<Normalizer> getNorms() {
|
||||
return model.getNorms();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Feature> getFeatures() {
|
||||
return model.getFeatures();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Collection<Feature> getAllFeatures() {
|
||||
return model.getAllFeatures();
|
||||
}
|
||||
|
||||
@Override
|
||||
public float score(float[] modelFeatureValuesNormalized) {
|
||||
return model.score(modelFeatureValuesNormalized);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Explanation explain(LeafReaderContext context, int doc, float finalScore,
|
||||
List<Explanation> featureExplanations) {
|
||||
return model.explain(context, doc, finalScore, featureExplanations);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void normalizeFeaturesInPlace(float[] modelFeatureValues) {
|
||||
model.normalizeFeaturesInPlace(modelFeatureValues);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Explanation getNormalizerExplanation(Explanation e, int idx) {
|
||||
return model.getNormalizerExplanation(e, idx);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
final StringBuilder sb = new StringBuilder(getClass().getSimpleName());
|
||||
sb.append("(name=").append(getName());
|
||||
sb.append(",model=(").append(model.toString()).append(")");
|
||||
|
||||
return sb.toString();
|
||||
}
|
||||
|
||||
}
|
|
@ -27,6 +27,7 @@ import org.apache.solr.common.util.NamedList;
|
|||
import org.apache.solr.core.SolrCore;
|
||||
import org.apache.solr.core.SolrResourceLoader;
|
||||
import org.apache.solr.ltr.feature.Feature;
|
||||
import org.apache.solr.ltr.model.WrapperModel;
|
||||
import org.apache.solr.ltr.model.LTRScoringModel;
|
||||
import org.apache.solr.ltr.model.ModelException;
|
||||
import org.apache.solr.ltr.norm.IdentityNormalizer;
|
||||
|
@ -231,7 +232,7 @@ public class ManagedModelStore extends ManagedResource implements ManagedResourc
|
|||
}
|
||||
}
|
||||
|
||||
return LTRScoringModel.getInstance(solrResourceLoader,
|
||||
final LTRScoringModel ltrScoringModel = LTRScoringModel.getInstance(solrResourceLoader,
|
||||
(String) modelMap.get(CLASS_KEY), // modelClassName
|
||||
(String) modelMap.get(NAME_KEY), // modelName
|
||||
features,
|
||||
|
@ -239,6 +240,28 @@ public class ManagedModelStore extends ManagedResource implements ManagedResourc
|
|||
featureStore.getName(),
|
||||
featureStore.getFeatures(),
|
||||
(Map<String,Object>) modelMap.get(PARAMS_KEY));
|
||||
|
||||
if (ltrScoringModel instanceof WrapperModel) {
|
||||
initWrapperModel(solrResourceLoader, (WrapperModel)ltrScoringModel, managedFeatureStore);
|
||||
}
|
||||
|
||||
return ltrScoringModel;
|
||||
}
|
||||
|
||||
private static void initWrapperModel(SolrResourceLoader solrResourceLoader,
|
||||
WrapperModel wrapperModel, ManagedFeatureStore managedFeatureStore) {
|
||||
wrapperModel.setSolrResourceLoader(solrResourceLoader);
|
||||
final LTRScoringModel model = fromLTRScoringModelMap(
|
||||
solrResourceLoader,
|
||||
wrapperModel.fetchModelMap(),
|
||||
managedFeatureStore);
|
||||
if (model instanceof WrapperModel) {
|
||||
log.warn("It is unusual for one WrapperModel ({}) to wrap another WrapperModel ({})",
|
||||
wrapperModel.getName(),
|
||||
model.getName());
|
||||
initWrapperModel(solrResourceLoader, (WrapperModel)model, managedFeatureStore);
|
||||
}
|
||||
wrapperModel.updateModel(model);
|
||||
}
|
||||
|
||||
private static LinkedHashMap<String,Object> toLTRScoringModelMap(LTRScoringModel model) {
|
||||
|
@ -249,10 +272,12 @@ public class ManagedModelStore extends ManagedResource implements ManagedResourc
|
|||
modelMap.put(STORE_KEY, model.getFeatureStoreName());
|
||||
|
||||
final List<Map<String,Object>> features = new ArrayList<>();
|
||||
final List<Feature> featuresList = model.getFeatures();
|
||||
final List<Normalizer> normsList = model.getNorms();
|
||||
for (int ii=0; ii<featuresList.size(); ++ii) {
|
||||
features.add(toFeatureMap(featuresList.get(ii), normsList.get(ii)));
|
||||
if (!(model instanceof WrapperModel)) {
|
||||
final List<Feature> featuresList = model.getFeatures();
|
||||
final List<Normalizer> normsList = model.getNorms();
|
||||
for (int ii = 0; ii < featuresList.size(); ++ii) {
|
||||
features.add(toFeatureMap(featuresList.get(ii), normsList.get(ii)));
|
||||
}
|
||||
}
|
||||
modelMap.put(FEATURES_KEY, features);
|
||||
|
||||
|
|
|
@ -16,6 +16,9 @@
|
|||
<directoryFactory name="DirectoryFactory"
|
||||
class="${solr.directoryFactory:solr.RAMDirectoryFactory}" />
|
||||
|
||||
<!-- for use with the DefaultWrapperModel class -->
|
||||
<lib dir="${solr.solr.home:.}/models" />
|
||||
|
||||
<schemaFactory class="ClassicIndexSchemaFactory" />
|
||||
|
||||
<requestDispatcher>
|
||||
|
|
|
@ -0,0 +1,145 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.solr.ltr.model;
|
||||
|
||||
import java.io.BufferedWriter;
|
||||
import java.io.File;
|
||||
import java.io.FileOutputStream;
|
||||
import java.io.OutputStreamWriter;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import org.apache.solr.client.solrj.SolrQuery;
|
||||
import org.apache.solr.ltr.TestRerankBase;
|
||||
import org.apache.solr.ltr.feature.Feature;
|
||||
import org.apache.solr.ltr.feature.FieldValueFeature;
|
||||
import org.apache.solr.ltr.feature.ValueFeature;
|
||||
import org.apache.solr.ltr.store.rest.ManagedModelStore;
|
||||
import org.junit.BeforeClass;
|
||||
import org.junit.Test;
|
||||
|
||||
public class TestDefaultWrapperModel extends TestRerankBase {
|
||||
|
||||
final private static String featureStoreName = "test";
|
||||
private static String baseModelJson = null;
|
||||
private static File baseModelFile = null;
|
||||
|
||||
static List<Feature> features = null;
|
||||
|
||||
@BeforeClass
|
||||
public static void setupBeforeClass() throws Exception {
|
||||
setuptest(false);
|
||||
assertU(adoc("id", "1", "title", "w1", "description", "w1", "popularity", "1"));
|
||||
assertU(adoc("id", "2", "title", "w2", "description", "w2", "popularity", "2"));
|
||||
assertU(adoc("id", "3", "title", "w3", "description", "w3", "popularity", "3"));
|
||||
assertU(adoc("id", "4", "title", "w4", "description", "w4", "popularity", "4"));
|
||||
assertU(adoc("id", "5", "title", "w5", "description", "w5", "popularity", "5"));
|
||||
assertU(commit());
|
||||
|
||||
loadFeature("popularity", FieldValueFeature.class.getCanonicalName(), "test", "{\"field\":\"popularity\"}");
|
||||
loadFeature("const", ValueFeature.class.getCanonicalName(), "test", "{\"value\":5}");
|
||||
features = new ArrayList<>();
|
||||
features.add(getManagedFeatureStore().getFeatureStore("test").get("popularity"));
|
||||
features.add(getManagedFeatureStore().getFeatureStore("test").get("const"));
|
||||
|
||||
baseModelJson = getModelInJson("linear", LinearModel.class.getCanonicalName(),
|
||||
new String[] {"popularity", "const"},
|
||||
featureStoreName,
|
||||
"{\"weights\":{\"popularity\":-1.0, \"const\":1.0}}");
|
||||
// prepare the base model as a file resource
|
||||
baseModelFile = new File(tmpConfDir, "baseModel.json");
|
||||
try (BufferedWriter writer = new BufferedWriter(
|
||||
new OutputStreamWriter(new FileOutputStream(baseModelFile), StandardCharsets.UTF_8))) {
|
||||
writer.write(baseModelJson);
|
||||
}
|
||||
baseModelFile.deleteOnExit();
|
||||
}
|
||||
|
||||
private static String getDefaultWrapperModelInJson(String wrapperModelName, String[] features, String params) {
|
||||
return getModelInJson(wrapperModelName, DefaultWrapperModel.class.getCanonicalName(),
|
||||
features, featureStoreName, params);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLoadModelFromResource() throws Exception {
|
||||
String wrapperModelJson = getDefaultWrapperModelInJson("fileWrapper",
|
||||
new String[0],
|
||||
"{\"resource\":\"" + baseModelFile.getName() + "\"}");
|
||||
assertJPut(ManagedModelStore.REST_END_POINT, wrapperModelJson, "/responseHeader/status==0");
|
||||
|
||||
final SolrQuery query = new SolrQuery();
|
||||
query.setQuery("{!func}pow(popularity,2)");
|
||||
query.add("rows", "3");
|
||||
query.add("fl", "*,score");
|
||||
query.add("rq", "{!ltr reRankDocs=3 model=fileWrapper}");
|
||||
assertJQ("/query" + query.toQueryString(),
|
||||
"/response/docs/[0]/id==\"3\"", "/response/docs/[0]/score==2.0",
|
||||
"/response/docs/[1]/id==\"4\"", "/response/docs/[1]/score==1.0",
|
||||
"/response/docs/[2]/id==\"5\"", "/response/docs/[2]/score==0.0");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLoadNestedWrapperModel() throws Exception {
|
||||
String otherWrapperModelJson = getDefaultWrapperModelInJson("otherNestedWrapper",
|
||||
new String[0],
|
||||
"{\"resource\":\"" + baseModelFile.getName() + "\"}");
|
||||
File otherWrapperModelFile = new File(tmpConfDir, "nestedWrapperModel.json");
|
||||
try (BufferedWriter writer = new BufferedWriter(
|
||||
new OutputStreamWriter(new FileOutputStream(otherWrapperModelFile), StandardCharsets.UTF_8))) {
|
||||
writer.write(otherWrapperModelJson);
|
||||
}
|
||||
|
||||
String wrapperModelJson = getDefaultWrapperModelInJson("nestedWrapper",
|
||||
new String[0],
|
||||
"{\"resource\":\"" + otherWrapperModelFile.getName() + "\"}");
|
||||
assertJPut(ManagedModelStore.REST_END_POINT, wrapperModelJson, "/responseHeader/status==0");
|
||||
final SolrQuery query = new SolrQuery();
|
||||
query.setQuery("{!func}pow(popularity,2)");
|
||||
query.add("rows", "3");
|
||||
query.add("fl", "*,score");
|
||||
query.add("rq", "{!ltr reRankDocs=3 model=nestedWrapper}");
|
||||
assertJQ("/query" + query.toQueryString(),
|
||||
"/response/docs/[0]/id==\"3\"", "/response/docs/[0]/score==2.0",
|
||||
"/response/docs/[1]/id==\"4\"", "/response/docs/[1]/score==1.0",
|
||||
"/response/docs/[2]/id==\"5\"", "/response/docs/[2]/score==0.0");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLoadModelFromUnknownResource() throws Exception {
|
||||
String wrapperModelJson = getDefaultWrapperModelInJson("unknownWrapper",
|
||||
new String[0],
|
||||
"{\"resource\":\"unknownModel.json\"}");
|
||||
assertJPut(ManagedModelStore.REST_END_POINT, wrapperModelJson,
|
||||
"/responseHeader/status==400",
|
||||
"/error/msg==\"org.apache.solr.ltr.model.ModelException: "
|
||||
+ "Failed to fetch the wrapper model from given resource (unknownModel.json)\"");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLoadModelWithEmptyParams() throws Exception {
|
||||
String wrapperModelJson = getDefaultWrapperModelInJson("invalidWrapper",
|
||||
new String[0],
|
||||
"{}");
|
||||
assertJPut(ManagedModelStore.REST_END_POINT, wrapperModelJson,
|
||||
"/responseHeader/status==400",
|
||||
"/error/msg==\"org.apache.solr.ltr.model.ModelException: "
|
||||
+ "no resource configured for model invalidWrapper\"");
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,290 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.solr.ltr.model;
|
||||
|
||||
import java.lang.reflect.Method;
|
||||
import java.lang.reflect.Modifier;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import org.apache.solr.ltr.TestRerankBase;
|
||||
import org.apache.solr.ltr.feature.Feature;
|
||||
import org.apache.solr.ltr.feature.ValueFeature;
|
||||
import org.apache.solr.ltr.norm.IdentityNormalizer;
|
||||
import org.apache.solr.ltr.norm.Normalizer;
|
||||
import org.apache.solr.ltr.store.FeatureStore;
|
||||
import org.junit.Test;
|
||||
import org.mockito.Mockito;
|
||||
|
||||
public class TestWrapperModel extends TestRerankBase {
|
||||
|
||||
private static class StubWrapperModel extends WrapperModel {
|
||||
|
||||
private StubWrapperModel(String name) {
|
||||
this(name, Collections.emptyList(), Collections.emptyList());
|
||||
}
|
||||
|
||||
private StubWrapperModel(String name, List<Feature> features, List<Normalizer> norms) {
|
||||
super(name, features, norms, FeatureStore.DEFAULT_FEATURE_STORE_NAME, features, Collections.emptyMap());
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map<String, Object> fetchModelMap() throws ModelException {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
private static LTRScoringModel createMockWrappedModel(String featureStoreName,
|
||||
List<Feature> features, List<Normalizer> norms) {
|
||||
LTRScoringModel wrappedModel = Mockito.mock(LTRScoringModel.class);
|
||||
Mockito.doReturn(featureStoreName).when(wrappedModel).getFeatureStoreName();
|
||||
Mockito.doReturn(features).when(wrappedModel).getFeatures();
|
||||
Mockito.doReturn(norms).when(wrappedModel).getNorms();
|
||||
return wrappedModel;
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testValidate() throws Exception {
|
||||
WrapperModel wrapperModel = new StubWrapperModel("testModel");
|
||||
try {
|
||||
wrapperModel.validate();
|
||||
} catch (ModelException e) {
|
||||
fail("Validation must succeed if no wrapped model is set");
|
||||
}
|
||||
|
||||
// wrapper model with features
|
||||
WrapperModel wrapperModelWithFeatures = new StubWrapperModel("testModel",
|
||||
Collections.singletonList(new ValueFeature("val", Collections.emptyMap())), Collections.emptyList());
|
||||
try {
|
||||
wrapperModelWithFeatures.validate();
|
||||
fail("Validation must fail if features of the wrapper model isn't empty");
|
||||
} catch (ModelException e) {
|
||||
assertEquals("features must be empty for the wrapper model testModel", e.getMessage());
|
||||
}
|
||||
|
||||
// wrapper model with norms
|
||||
WrapperModel wrapperModelWithNorms = new StubWrapperModel("testModel",
|
||||
Collections.emptyList(), Collections.singletonList(IdentityNormalizer.INSTANCE));
|
||||
try {
|
||||
wrapperModelWithNorms.validate();
|
||||
fail("Validation must fail if norms of the wrapper model isn't empty");
|
||||
} catch (ModelException e) {
|
||||
assertEquals("norms must be empty for the wrapper model testModel", e.getMessage());
|
||||
}
|
||||
|
||||
assumeWorkingMockito();
|
||||
|
||||
// update valid model
|
||||
{
|
||||
LTRScoringModel wrappedModel =
|
||||
createMockWrappedModel(FeatureStore.DEFAULT_FEATURE_STORE_NAME,
|
||||
Arrays.asList(
|
||||
new ValueFeature("v1", Collections.emptyMap()),
|
||||
new ValueFeature("v2", Collections.emptyMap())),
|
||||
Arrays.asList(
|
||||
IdentityNormalizer.INSTANCE,
|
||||
IdentityNormalizer.INSTANCE)
|
||||
);
|
||||
try {
|
||||
wrapperModel.updateModel(wrappedModel);
|
||||
} catch (ModelException e) {
|
||||
fail("Validation must succeed if the wrapped model is valid");
|
||||
}
|
||||
}
|
||||
|
||||
// update invalid model (feature store mismatch)
|
||||
{
|
||||
LTRScoringModel wrappedModel =
|
||||
createMockWrappedModel("wrappedFeatureStore",
|
||||
Arrays.asList(
|
||||
new ValueFeature("v1", Collections.emptyMap()),
|
||||
new ValueFeature("v2", Collections.emptyMap())),
|
||||
Arrays.asList(
|
||||
IdentityNormalizer.INSTANCE,
|
||||
IdentityNormalizer.INSTANCE)
|
||||
);
|
||||
try {
|
||||
wrapperModel.updateModel(wrappedModel);
|
||||
fail("Validation must fail if wrapped model feature store differs from wrapper model feature store");
|
||||
} catch (ModelException e) {
|
||||
assertEquals("wrapper feature store name (_DEFAULT_) must match the wrapped feature store name (wrappedFeatureStore)", e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
// update invalid model (no features)
|
||||
{
|
||||
LTRScoringModel wrappedModel =
|
||||
createMockWrappedModel(FeatureStore.DEFAULT_FEATURE_STORE_NAME,
|
||||
Collections.emptyList(),
|
||||
Arrays.asList(
|
||||
IdentityNormalizer.INSTANCE,
|
||||
IdentityNormalizer.INSTANCE)
|
||||
);
|
||||
try {
|
||||
wrapperModel.updateModel(wrappedModel);
|
||||
fail("Validation must fail if the wrapped model is invalid");
|
||||
} catch (ModelException e) {
|
||||
assertEquals("no features declared for model testModel", e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
// update invalid model (no norms)
|
||||
{
|
||||
LTRScoringModel wrappedModel =
|
||||
createMockWrappedModel(FeatureStore.DEFAULT_FEATURE_STORE_NAME,
|
||||
Arrays.asList(
|
||||
new ValueFeature("v1", Collections.emptyMap()),
|
||||
new ValueFeature("v2", Collections.emptyMap())),
|
||||
Collections.emptyList()
|
||||
);
|
||||
try {
|
||||
wrapperModel.updateModel(wrappedModel);
|
||||
fail("Validation must fail if the wrapped model is invalid");
|
||||
} catch (ModelException e) {
|
||||
assertEquals("counted 2 features and 0 norms in model testModel", e.getMessage());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMethodOverridesAndDelegation() throws Exception {
|
||||
assumeWorkingMockito();
|
||||
final int overridableMethodCount = testOverwrittenMethods();
|
||||
final int methodCount = testDelegateMethods();
|
||||
assertEquals("method count mismatch", overridableMethodCount, methodCount);
|
||||
}
|
||||
|
||||
private int testOverwrittenMethods() throws Exception {
|
||||
int overridableMethodCount = 0;
|
||||
for (final Method superClassMethod : LTRScoringModel.class.getDeclaredMethods()) {
|
||||
final int modifiers = superClassMethod.getModifiers();
|
||||
if (Modifier.isFinal(modifiers)) continue;
|
||||
if (Modifier.isStatic(modifiers)) continue;
|
||||
|
||||
++overridableMethodCount;
|
||||
if (Arrays.asList(
|
||||
"getName", // the wrapper model's name is its own name i.e. _not_ the name of the wrapped model
|
||||
"getFeatureStoreName", // wrapper and wrapped model feature store should match, so need not override
|
||||
"getParams" // the wrapper model's params are its own params i.e. _not_ the params of the wrapped model
|
||||
).contains(superClassMethod.getName())) {
|
||||
try {
|
||||
final Method subClassMethod = WrapperModel.class.getDeclaredMethod(
|
||||
superClassMethod.getName(),
|
||||
superClassMethod.getParameterTypes());
|
||||
fail(WrapperModel.class + " need not override\n'" + superClassMethod + "'"
|
||||
+ " but it does override\n'" + subClassMethod + "'");
|
||||
} catch (NoSuchMethodException e) {
|
||||
// ok
|
||||
}
|
||||
} else {
|
||||
try {
|
||||
final Method subClassMethod = WrapperModel.class.getDeclaredMethod(
|
||||
superClassMethod.getName(),
|
||||
superClassMethod.getParameterTypes());
|
||||
assertEquals("getReturnType() difference",
|
||||
superClassMethod.getReturnType(),
|
||||
subClassMethod.getReturnType());
|
||||
} catch (NoSuchMethodException e) {
|
||||
fail(WrapperModel.class + " needs to override '" + superClassMethod + "'");
|
||||
}
|
||||
}
|
||||
}
|
||||
return overridableMethodCount;
|
||||
}
|
||||
|
||||
private int testDelegateMethods() throws Exception {
|
||||
int methodCount = 0;
|
||||
WrapperModel wrapperModel = Mockito.spy(new StubWrapperModel("testModel"));
|
||||
|
||||
// ignore validate in this test case
|
||||
Mockito.doNothing().when(wrapperModel).validate();
|
||||
++methodCount;
|
||||
|
||||
LTRScoringModel wrappedModel = Mockito.mock(LTRScoringModel.class);
|
||||
wrapperModel.updateModel(wrappedModel);
|
||||
|
||||
// cannot be stubbed or verified
|
||||
++methodCount; // toString
|
||||
++methodCount; // hashCode
|
||||
++methodCount; // equals
|
||||
|
||||
// getFeatureStoreName : not delegate
|
||||
Mockito.reset(wrappedModel);
|
||||
wrapperModel.getFeatureStoreName();
|
||||
++methodCount;
|
||||
Mockito.verify(wrappedModel, Mockito.times(0)).getFeatureStoreName();
|
||||
|
||||
// getName : not delegate
|
||||
Mockito.reset(wrappedModel);
|
||||
wrapperModel.getName();
|
||||
++methodCount;
|
||||
Mockito.verify(wrappedModel, Mockito.times(0)).getName();
|
||||
|
||||
// getParams : not delegate
|
||||
Mockito.reset(wrappedModel);
|
||||
wrapperModel.getParams();
|
||||
++methodCount;
|
||||
Mockito.verify(wrappedModel, Mockito.times(0)).getParams();
|
||||
|
||||
// getNorms : delegate
|
||||
Mockito.reset(wrappedModel);
|
||||
wrapperModel.getNorms();
|
||||
++methodCount;
|
||||
Mockito.verify(wrappedModel, Mockito.times(1)).getNorms();
|
||||
|
||||
// getFeatures : delegate
|
||||
Mockito.reset(wrappedModel);
|
||||
wrapperModel.getFeatures();
|
||||
++methodCount;
|
||||
Mockito.verify(wrappedModel, Mockito.times(1)).getFeatures();
|
||||
|
||||
// getAllFeatures : delegate
|
||||
Mockito.reset(wrappedModel);
|
||||
wrapperModel.getAllFeatures();
|
||||
++methodCount;
|
||||
Mockito.verify(wrappedModel, Mockito.times(1)).getAllFeatures();
|
||||
|
||||
// score : delegate
|
||||
Mockito.reset(wrappedModel);
|
||||
wrapperModel.score(null);
|
||||
++methodCount;
|
||||
Mockito.verify(wrappedModel, Mockito.times(1)).score(null);
|
||||
|
||||
// normalizeFeaturesInPlace : delegate
|
||||
Mockito.reset(wrappedModel);
|
||||
wrapperModel.normalizeFeaturesInPlace(null);
|
||||
++methodCount;
|
||||
Mockito.verify(wrappedModel, Mockito.times(1)).normalizeFeaturesInPlace(null);
|
||||
|
||||
// getNormalizerExplanation : delegate
|
||||
Mockito.reset(wrappedModel);
|
||||
wrapperModel.getNormalizerExplanation(null, 0);
|
||||
++methodCount;
|
||||
Mockito.verify(wrappedModel, Mockito.times(1)).getNormalizerExplanation(null, 0);
|
||||
|
||||
// explain : delegate
|
||||
Mockito.reset(wrappedModel);
|
||||
wrapperModel.explain(null, 0, 0.0f, null);
|
||||
++methodCount;
|
||||
Mockito.verify(wrappedModel, Mockito.times(1)).explain(null, 0, 0.0f, null);
|
||||
|
||||
return methodCount;
|
||||
}
|
||||
}
|
|
@ -16,12 +16,19 @@
|
|||
*/
|
||||
package org.apache.solr.ltr.store.rest;
|
||||
|
||||
import java.io.BufferedWriter;
|
||||
import java.io.File;
|
||||
import java.io.FileOutputStream;
|
||||
import java.io.OutputStreamWriter;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Map;
|
||||
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import org.apache.solr.ltr.TestRerankBase;
|
||||
import org.apache.solr.ltr.feature.FieldValueFeature;
|
||||
import org.apache.solr.ltr.feature.ValueFeature;
|
||||
import org.apache.solr.ltr.model.DefaultWrapperModel;
|
||||
import org.apache.solr.ltr.model.LinearModel;
|
||||
import org.apache.solr.ltr.store.FeatureStore;
|
||||
import org.junit.BeforeClass;
|
||||
|
@ -185,4 +192,73 @@ public class TestModelManagerPersistence extends TestRerankBase {
|
|||
assertJQ(ManagedFeatureStore.REST_END_POINT + "/" + FeatureStore.DEFAULT_FEATURE_STORE_NAME,
|
||||
"/features/==[]");
|
||||
}
|
||||
|
||||
private static void doWrapperModelPersistenceChecks(String modelName,
|
||||
String featureStoreName, String baseModelFileName) throws Exception {
|
||||
// note that the wrapper and the wrapped model always have the same name
|
||||
assertJQ(ManagedModelStore.REST_END_POINT,
|
||||
// the wrapped model shouldn't be registered
|
||||
"!/models/[1]/name=='"+modelName+"'",
|
||||
// but the wrapper model should be registered
|
||||
"/models/[0]/name=='"+modelName+"'",
|
||||
"/models/[0]/class=='" + DefaultWrapperModel.class.getCanonicalName() + "'",
|
||||
"/models/[0]/store=='" + featureStoreName + "'",
|
||||
// the wrapper model shouldn't contain the definitions of the wrapped model
|
||||
"/models/[0]/features/==[]",
|
||||
// but only its own parameters
|
||||
"/models/[0]/params=={resource:'"+baseModelFileName+"'}");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testWrapperModelPersistence() throws Exception {
|
||||
final String modelName = "linear";
|
||||
final String FS_NAME = "testWrapper";
|
||||
|
||||
// check whether models and features are empty
|
||||
assertJQ(ManagedModelStore.REST_END_POINT,
|
||||
"/models/==[]");
|
||||
assertJQ(ManagedFeatureStore.REST_END_POINT + "/" + FS_NAME,
|
||||
"/features/==[]");
|
||||
|
||||
// setup features
|
||||
loadFeature("popularity", FieldValueFeature.class.getCanonicalName(), FS_NAME, "{\"field\":\"popularity\"}");
|
||||
loadFeature("const", ValueFeature.class.getCanonicalName(), FS_NAME, "{\"value\":5}");
|
||||
|
||||
// setup base model
|
||||
String baseModelJson = getModelInJson(modelName, LinearModel.class.getCanonicalName(),
|
||||
new String[] {"popularity", "const"}, FS_NAME,
|
||||
"{\"weights\":{\"popularity\":-1.0, \"const\":1.0}}");
|
||||
File baseModelFile = new File(tmpConfDir, "baseModelForPersistence.json");
|
||||
try (BufferedWriter writer = new BufferedWriter(
|
||||
new OutputStreamWriter(new FileOutputStream(baseModelFile), StandardCharsets.UTF_8))) {
|
||||
writer.write(baseModelJson);
|
||||
}
|
||||
baseModelFile.deleteOnExit();
|
||||
|
||||
// setup wrapper model
|
||||
String wrapperModelJson = getModelInJson(modelName, DefaultWrapperModel.class.getCanonicalName(),
|
||||
new String[0], FS_NAME,
|
||||
"{\"resource\":\"" + baseModelFile.getName() + "\"}");
|
||||
assertJPut(ManagedModelStore.REST_END_POINT, wrapperModelJson, "/responseHeader/status==0");
|
||||
doWrapperModelPersistenceChecks(modelName, FS_NAME, baseModelFile.getName());
|
||||
|
||||
// check persistence after reload
|
||||
restTestHarness.reload();
|
||||
doWrapperModelPersistenceChecks(modelName, FS_NAME, baseModelFile.getName());
|
||||
|
||||
// check persistence after restart
|
||||
jetty.stop();
|
||||
jetty.start();
|
||||
doWrapperModelPersistenceChecks(modelName, FS_NAME, baseModelFile.getName());
|
||||
|
||||
// delete test settings
|
||||
restTestHarness.delete(ManagedModelStore.REST_END_POINT + "/" + modelName);
|
||||
restTestHarness.delete(ManagedFeatureStore.REST_END_POINT + "/" + FS_NAME);
|
||||
assertJQ(ManagedModelStore.REST_END_POINT,
|
||||
"/models/==[]");
|
||||
assertJQ(ManagedFeatureStore.REST_END_POINT + "/" + FS_NAME,
|
||||
"/features/==[]");
|
||||
|
||||
// NOTE: we don't test the persistence of the deletion here because it's tested in testFilePersistence
|
||||
}
|
||||
}
|
||||
|
|
|
@ -87,6 +87,7 @@ Feature selection and model training take place offline and outside Solr. The lt
|
|||
|General form |Class |Specific examples
|
||||
|Linear |{solr-javadocs}/solr-ltr/org/apache/solr/ltr/model/LinearModel.html[LinearModel] |RankSVM, Pranking
|
||||
|Multiple Additive Trees |{solr-javadocs}/solr-ltr/org/apache/solr/ltr/model/MultipleAdditiveTreesModel.html[MultipleAdditiveTreesModel] |LambdaMART, Gradient Boosted Regression Trees (GBRT)
|
||||
|(wrapper) |{solr-javadocs}/solr-ltr/org/apache/solr/ltr/model/DefaultWrapperModel.html[DefaultWrapperModel] |(not applicable)
|
||||
|(custom) |(custom class extending {solr-javadocs}/solr-ltr/org/apache/solr/ltr/model/LTRScoringModel.html[LTRScoringModel]) |(not applicable)
|
||||
|===
|
||||
|
||||
|
@ -509,6 +510,54 @@ To delete the `currentFeatureStore` feature store:
|
|||
curl -XDELETE 'http://localhost:8983/solr/techproducts/schema/feature-store/currentFeatureStore'
|
||||
----
|
||||
|
||||
==== Using large models
|
||||
|
||||
With SolrCloud, large models may fail to upload due to the limitation of ZooKeeper's buffer. In this case, `DefaultWrapperModel` may help you to separate the model definition from uploaded file.
|
||||
|
||||
Assuming that you consider to use a large model placed at `/path/to/models/myModel.json` through `DefaultWrapperModel`.
|
||||
|
||||
[source,json]
|
||||
----
|
||||
{
|
||||
"store" : "largeModelsFeatureStore",
|
||||
"name" : "myModel",
|
||||
"class" : ...,
|
||||
"features" : [
|
||||
...
|
||||
],
|
||||
"params" : {
|
||||
...
|
||||
}
|
||||
}
|
||||
----
|
||||
|
||||
First, add the directory to Solr's resource paths by <<lib-directives-in-solrconfig.adoc#lib-directives-in-solrconfig,Lib Directives>>:
|
||||
|
||||
[source,xml]
|
||||
----
|
||||
<lib dir="/path/to" regex="models" />
|
||||
----
|
||||
|
||||
Then, configure `DefaultWrapperModel` to wrap `myModel.json`:
|
||||
|
||||
[source,json]
|
||||
----
|
||||
{
|
||||
"store" : "largeModelsFeatureStore",
|
||||
"name" : "myWrapperModel",
|
||||
"class" : "org.apache.solr.ltr.model.DefaultWrapperModel",
|
||||
"params" : {
|
||||
"resource" : "myModel.json"
|
||||
}
|
||||
}
|
||||
----
|
||||
|
||||
`myModel.json` will be loaded during the initialization and be able to use by specifying `model=myWrapperModel`.
|
||||
|
||||
NOTE: No `"features"` are configured in `myWrapperModel` because the features of the wrapped model (`myModel`) will be used; also note that the `"store"` configured for the wrapper model must match that of the wrapped model i.e. in this example the feature store called `largeModelsFeatureStore` is used.
|
||||
|
||||
CAUTION: `<lib dir="/path/to/models" regex=".*\.json" />` doesn't work as expected in this case, because `SolrResourceLoader` considers given resources as JAR if `<lib />` indicates files.
|
||||
|
||||
=== Applying Changes
|
||||
|
||||
The feature store and the model store are both <<managed-resources.adoc#managed-resources,Managed Resources>>. Changes made to managed resources are not applied to the active Solr components until the Solr collection (or Solr core in single server mode) is reloaded.
|
||||
|
|
Loading…
Reference in New Issue